1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import os
9
10import pytest
11
12from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons
13from cryptography.hazmat.backends.interfaces import CipherBackend
14from cryptography.hazmat.primitives.ciphers.aead import (
15    AESCCM,
16    AESGCM,
17    ChaCha20Poly1305,
18)
19
20from .utils import _load_all_params
21from ...utils import (
22    load_nist_ccm_vectors,
23    load_nist_vectors,
24    load_vectors_from_file,
25    raises_unsupported_algorithm,
26)
27
28
class FakeData(object):
    """Placeholder payload whose reported length is one byte over 2**32.

    Only ``__len__`` is implemented, so size checks can be exercised without
    allocating any real data.
    """

    _REPORTED_LENGTH = (1 << 32) + 1

    def __len__(self):
        return self._REPORTED_LENGTH
32
33
34def _aead_supported(cls):
35    try:
36        cls(b"0" * 32)
37        return True
38    except UnsupportedAlgorithm:
39        return False
40
41
@pytest.mark.skipif(
    _aead_supported(ChaCha20Poly1305),
    reason="Requires OpenSSL without ChaCha20Poly1305 support",
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
def test_chacha20poly1305_unsupported_on_older_openssl(backend):
    """Constructing ChaCha20Poly1305 must fail with UNSUPPORTED_CIPHER.

    Runs only when ``_aead_supported`` reports that the algorithm cannot be
    constructed; otherwise it is skipped.
    """
    expected_reason = _Reasons.UNSUPPORTED_CIPHER
    with raises_unsupported_algorithm(expected_reason):
        ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
50
51
@pytest.mark.skipif(
    not _aead_supported(ChaCha20Poly1305),
    reason="Does not support ChaCha20Poly1305",
)
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestChaCha20Poly1305(object):
    """Tests for ``ChaCha20Poly1305``, skipped when it cannot be built."""

    def test_data_too_large(self):
        # FakeData claims a length just past 2**32; encrypt() must refuse it
        # both as the plaintext and as the associated data.
        cipher = ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
        zero_nonce = b"0" * 12

        for plaintext, aad in ((FakeData(), b""), (b"", FakeData())):
            with pytest.raises(OverflowError):
                cipher.encrypt(zero_nonce, plaintext, aad)

    def test_generate_key(self):
        assert len(ChaCha20Poly1305.generate_key()) == 32

    def test_bad_key(self, backend):
        # A non-bytes key is a type error; a 31-byte key has the wrong size.
        cases = [(TypeError, object()), (ValueError, b"0" * 31)]
        for expected_exc, bad_key in cases:
            with pytest.raises(expected_exc):
                ChaCha20Poly1305(bad_key)

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()],
        ],
    )
    def test_params_not_bytes_encrypt(
        self, nonce, data, associated_data, backend
    ):
        # Each parametrized case swaps one argument for a non-bytes object;
        # both directions must reject it.
        cipher = ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
        for operation in (cipher.encrypt, cipher.decrypt):
            with pytest.raises(TypeError):
                operation(nonce, data, associated_data)

    def test_nonce_not_12_bytes(self, backend):
        cipher = ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
        short_nonce = b"00"
        for operation in (cipher.encrypt, cipher.decrypt):
            with pytest.raises(ValueError):
                operation(short_nonce, b"hello", b"")

    def test_decrypt_data_too_short(self, backend):
        # A one-byte input is too short to hold an authentication tag.
        cipher = ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
        with pytest.raises(InvalidTag):
            cipher.decrypt(b"0" * 12, b"0", None)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        # Passing ``None`` as associated data must behave exactly like b"".
        cipher = ChaCha20Poly1305(ChaCha20Poly1305.generate_key())
        nonce = os.urandom(12)
        ct_none, ct_empty = [
            cipher.encrypt(nonce, b"some_data", ad) for ad in (None, b"")
        ]
        assert ct_none == ct_empty
        pt_none = cipher.decrypt(nonce, ct_none, None)
        pt_empty = cipher.decrypt(nonce, ct_empty, b"")
        assert pt_none == pt_empty

    @pytest.mark.parametrize(
        "vector",
        load_vectors_from_file(
            os.path.join("ciphers", "ChaCha20Poly1305", "openssl.txt"),
            load_nist_vectors,
        ),
    )
    def test_openssl_vectors(self, vector, backend):
        fields = ("key", "iv", "aad", "tag", "plaintext", "ciphertext")
        key, nonce, aad, tag, pt, ct = (
            binascii.unhexlify(vector[name]) for name in fields
        )
        cipher = ChaCha20Poly1305(key)
        sealed = ct + tag

        # Vectors marked CIPHERFINAL_ERROR must fail authentication.
        if vector.get("result") == b"CIPHERFINAL_ERROR":
            with pytest.raises(InvalidTag):
                cipher.decrypt(nonce, sealed, aad)
            return

        assert cipher.decrypt(nonce, sealed, aad) == pt
        assert cipher.encrypt(nonce, pt, aad) == sealed

    @pytest.mark.parametrize(
        "vector",
        load_vectors_from_file(
            os.path.join("ciphers", "ChaCha20Poly1305", "boringssl.txt"),
            load_nist_vectors,
        ),
    )
    def test_boringssl_vectors(self, vector, backend):
        def _decode(field):
            # Some values are quoted literal strings; the rest are hex.
            raw = vector[field]
            if raw.startswith(b'"'):
                return raw[1:-1]
            return binascii.unhexlify(raw)

        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["nonce"])
        aad = _decode("ad")
        tag = binascii.unhexlify(vector["tag"])
        pt = _decode("in")
        ct = binascii.unhexlify(vector["ct"].strip(b'"'))

        cipher = ChaCha20Poly1305(key)
        sealed = ct + tag
        assert cipher.decrypt(nonce, sealed, aad) == pt
        assert cipher.encrypt(nonce, pt, aad) == sealed

    def test_buffer_protocol(self, backend):
        key = ChaCha20Poly1305.generate_key()
        from_bytes = ChaCha20Poly1305(key)
        pt, ad = b"encrypt me", b"additional"
        nonce = os.urandom(12)
        ct = from_bytes.encrypt(nonce, pt, ad)
        assert from_bytes.decrypt(nonce, ct, ad) == pt

        # The same key and nonce supplied as bytearray must give the same
        # ciphertext and decrypt it back.
        from_buffer = ChaCha20Poly1305(bytearray(key))
        ct2 = from_buffer.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        assert from_buffer.decrypt(bytearray(nonce), ct2, ad) == pt
189
190
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestAESCCM(object):
    """Tests for the ``AESCCM`` AEAD construction."""

    def test_data_too_large(self):
        # FakeData reports a length of 2**32 + 1 without holding any data, so
        # this exercises the size guard on plaintext and associated data.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = b"0" * 12

        with pytest.raises(OverflowError):
            aesccm.encrypt(nonce, FakeData(), b"")

        with pytest.raises(OverflowError):
            aesccm.encrypt(nonce, b"", FakeData())

    def test_default_tag_length(self, backend):
        # Without an explicit tag_length the output carries a 16-byte tag.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = os.urandom(12)
        pt = b"hello"
        ct = aesccm.encrypt(nonce, pt, None)
        assert len(ct) == len(pt) + 16

    def test_invalid_tag_length(self, backend):
        key = AESCCM.generate_key(128)
        with pytest.raises(ValueError):
            AESCCM(key, tag_length=7)

        with pytest.raises(ValueError):
            AESCCM(key, tag_length=2)

        with pytest.raises(TypeError):
            AESCCM(key, tag_length="notanint")

    def test_invalid_nonce_length(self, backend):
        # A 14-byte nonce and a 6-byte nonce must both be rejected, by
        # encrypt() and by decrypt(), before any cryptographic work happens.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"hello"
        nonce = os.urandom(14)
        for bad_nonce in (nonce, nonce[:6]):
            with pytest.raises(ValueError):
                aesccm.encrypt(bad_nonce, pt, None)

            with pytest.raises(ValueError):
                aesccm.decrypt(bad_nonce, pt, None)

    @pytest.mark.parametrize(
        "vector",
        _load_all_params(
            os.path.join("ciphers", "AES", "CCM"),
            [
                "DVPT128.rsp",
                "DVPT192.rsp",
                "DVPT256.rsp",
                "VADT128.rsp",
                "VADT192.rsp",
                "VADT256.rsp",
                "VNT128.rsp",
                "VNT192.rsp",
                "VNT256.rsp",
                "VPT128.rsp",
                "VPT192.rsp",
                "VPT256.rsp",
            ],
            load_nist_ccm_vectors,
        ),
    )
    def test_vectors(self, vector, backend):
        key = binascii.unhexlify(vector["key"])
        nonce = binascii.unhexlify(vector["nonce"])
        # The decoded adata/payload are truncated to the vector's declared
        # lengths (alen/plen) before use.
        adata = binascii.unhexlify(vector["adata"])[: vector["alen"]]
        ct = binascii.unhexlify(vector["ct"])
        pt = binascii.unhexlify(vector["payload"])[: vector["plen"]]
        aesccm = AESCCM(key, vector["tlen"])
        if vector.get("fail"):
            with pytest.raises(InvalidTag):
                aesccm.decrypt(nonce, ct, adata)
        else:
            computed_pt = aesccm.decrypt(nonce, ct, adata)
            assert computed_pt == pt
            assert aesccm.encrypt(nonce, pt, adata) == ct

    def test_roundtrip(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesccm.encrypt(nonce, pt, ad)
        computed_pt = aesccm.decrypt(nonce, ct, ad)
        assert computed_pt == pt

    def test_nonce_too_long(self, backend):
        # Despite the name, the nonce length (13) is valid here; what fails
        # is the plaintext length, which exceeds what a 13-byte nonce allows.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me" * 6600
        # pt can be no more than 65536 bytes when nonce is 13 bytes
        nonce = os.urandom(13)
        with pytest.raises(ValueError):
            aesccm.encrypt(nonce, pt, None)

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()],
        ],
    )
    def test_params_not_bytes(self, nonce, data, associated_data, backend):
        # Each parametrized case swaps one argument for a non-bytes object;
        # both directions must reject it, matching the other AEAD tests.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        with pytest.raises(TypeError):
            aesccm.encrypt(nonce, data, associated_data)

        with pytest.raises(TypeError):
            aesccm.decrypt(nonce, data, associated_data)

    def test_bad_key(self, backend):
        with pytest.raises(TypeError):
            AESCCM(object())

        with pytest.raises(ValueError):
            AESCCM(b"0" * 31)

    def test_bad_generate_key(self, backend):
        with pytest.raises(TypeError):
            AESCCM.generate_key(object())

        with pytest.raises(ValueError):
            AESCCM.generate_key(129)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        # Passing ``None`` as associated data must behave exactly like b"".
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        nonce = os.urandom(12)
        ct1 = aesccm.encrypt(nonce, b"some_data", None)
        ct2 = aesccm.encrypt(nonce, b"some_data", b"")
        assert ct1 == ct2
        pt1 = aesccm.decrypt(nonce, ct1, None)
        pt2 = aesccm.decrypt(nonce, ct2, b"")
        assert pt1 == pt2

    def test_decrypt_data_too_short(self, backend):
        # A one-byte input is too short to hold the authentication tag.
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        with pytest.raises(InvalidTag):
            aesccm.decrypt(b"0" * 12, b"0", None)

    def test_buffer_protocol(self, backend):
        key = AESCCM.generate_key(128)
        aesccm = AESCCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesccm.encrypt(nonce, pt, ad)
        computed_pt = aesccm.decrypt(nonce, ct, ad)
        assert computed_pt == pt
        # The same key and nonce supplied as bytearray must interoperate.
        aesccm2 = AESCCM(bytearray(key))
        ct2 = aesccm2.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        computed_pt2 = aesccm2.decrypt(bytearray(nonce), ct2, ad)
        assert computed_pt2 == pt
348
349
def _load_gcm_vectors():
    """Load the NIST AES-GCM vectors that use a full 16-byte tag.

    Tags are stored as hex, so a 16-byte tag is 32 hex characters. Vectors
    with shorter (truncated) tags are dropped.
    """
    vector_files = [
        "gcmDecrypt128.rsp",
        "gcmDecrypt192.rsp",
        "gcmDecrypt256.rsp",
        "gcmEncryptExtIV128.rsp",
        "gcmEncryptExtIV192.rsp",
        "gcmEncryptExtIV256.rsp",
    ]
    all_vectors = _load_all_params(
        os.path.join("ciphers", "AES", "GCM"), vector_files, load_nist_vectors
    )
    full_tag_hex_len = 16 * 2
    return [v for v in all_vectors if len(v["tag"]) == full_tag_hex_len]
364
365
@pytest.mark.requires_backend_interface(interface=CipherBackend)
class TestAESGCM(object):
    """Tests for the ``AESGCM`` AEAD construction."""

    def test_data_too_large(self):
        # FakeData reports a length of 2**32 + 1 without holding any data, so
        # this exercises the size guard on plaintext and associated data.
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        nonce = b"0" * 12

        with pytest.raises(OverflowError):
            aesgcm.encrypt(nonce, FakeData(), b"")

        with pytest.raises(OverflowError):
            aesgcm.encrypt(nonce, b"", FakeData())

    @pytest.mark.parametrize("vector", _load_gcm_vectors())
    def test_vectors(self, backend, vector):
        nonce = binascii.unhexlify(vector["iv"])

        if len(nonce) < 8:
            pytest.skip("GCM does not support less than 64-bit IVs")

        if backend._fips_enabled and len(nonce) != 12:
            # Red Hat disables non-96-bit IV support as part of its FIPS
            # patches.
            pytest.skip("Non-96-bit IVs unsupported in FIPS mode.")

        key = binascii.unhexlify(vector["key"])
        aad = binascii.unhexlify(vector["aad"])
        ct = binascii.unhexlify(vector["ct"])
        pt = binascii.unhexlify(vector.get("pt", b""))
        tag = binascii.unhexlify(vector["tag"])
        aesgcm = AESGCM(key)
        if vector.get("fail") is True:
            with pytest.raises(InvalidTag):
                aesgcm.decrypt(nonce, ct + tag, aad)
        else:
            # encrypt() output is ciphertext followed by the 16-byte tag.
            computed_ct = aesgcm.encrypt(nonce, pt, aad)
            assert computed_ct[:-16] == ct
            assert computed_ct[-16:] == tag
            computed_pt = aesgcm.decrypt(nonce, ct + tag, aad)
            assert computed_pt == pt

    @pytest.mark.parametrize(
        ("nonce", "data", "associated_data"),
        [
            [object(), b"data", b""],
            [b"0" * 12, object(), b""],
            [b"0" * 12, b"data", object()],
        ],
    )
    def test_params_not_bytes(self, nonce, data, associated_data, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        with pytest.raises(TypeError):
            aesgcm.encrypt(nonce, data, associated_data)

        with pytest.raises(TypeError):
            aesgcm.decrypt(nonce, data, associated_data)

    @pytest.mark.parametrize("length", [7, 129])
    def test_invalid_nonce_length(self, length, backend):
        if backend._fips_enabled:
            # Red Hat disables non-96-bit IV support as part of its FIPS
            # patches.
            pytest.skip("Non-96-bit IVs unsupported in FIPS mode.")

        # Out-of-range nonces must be rejected by both encrypt() and
        # decrypt(), matching the other AEAD test classes.
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        nonce = b"\x00" * length
        with pytest.raises(ValueError):
            aesgcm.encrypt(nonce, b"hi", None)

        with pytest.raises(ValueError):
            aesgcm.decrypt(nonce, b"hi", None)

    def test_bad_key(self, backend):
        with pytest.raises(TypeError):
            AESGCM(object())

        with pytest.raises(ValueError):
            AESGCM(b"0" * 31)

    def test_bad_generate_key(self, backend):
        with pytest.raises(TypeError):
            AESGCM.generate_key(object())

        with pytest.raises(ValueError):
            AESGCM.generate_key(129)

    def test_associated_data_none_equal_to_empty_bytestring(self, backend):
        # Passing ``None`` as associated data must behave exactly like b"".
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        nonce = os.urandom(12)
        ct1 = aesgcm.encrypt(nonce, b"some_data", None)
        ct2 = aesgcm.encrypt(nonce, b"some_data", b"")
        assert ct1 == ct2
        pt1 = aesgcm.decrypt(nonce, ct1, None)
        pt2 = aesgcm.decrypt(nonce, ct2, b"")
        assert pt1 == pt2

    def test_decrypt_data_too_short(self, backend):
        # A one-byte input is too short to hold the 16-byte tag; this mirrors
        # the equivalent ChaCha20Poly1305 and AESCCM tests.
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        with pytest.raises(InvalidTag):
            aesgcm.decrypt(b"0" * 12, b"0", None)

    def test_buffer_protocol(self, backend):
        key = AESGCM.generate_key(128)
        aesgcm = AESGCM(key)
        pt = b"encrypt me"
        ad = b"additional"
        nonce = os.urandom(12)
        ct = aesgcm.encrypt(nonce, pt, ad)
        computed_pt = aesgcm.decrypt(nonce, ct, ad)
        assert computed_pt == pt
        # The same key and nonce supplied as bytearray must interoperate.
        aesgcm2 = AESGCM(bytearray(key))
        ct2 = aesgcm2.encrypt(bytearray(nonce), pt, ad)
        assert ct2 == ct
        computed_pt2 = aesgcm2.decrypt(bytearray(nonce), ct2, ad)
        assert computed_pt2 == pt
475