1# Copyright 2013-2018 Donald Stufft and individual contributors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from __future__ import absolute_import, division, print_function
16
17import binascii
18import json
19import os
20import random
21
22import pytest
23
24import six
25
26from nacl._sodium import ffi
27from nacl.bindings.crypto_secretstream import (
28    crypto_secretstream_xchacha20poly1305_ABYTES,
29    crypto_secretstream_xchacha20poly1305_HEADERBYTES,
30    crypto_secretstream_xchacha20poly1305_KEYBYTES,
31    crypto_secretstream_xchacha20poly1305_STATEBYTES,
32    crypto_secretstream_xchacha20poly1305_TAG_FINAL,
33    crypto_secretstream_xchacha20poly1305_TAG_MESSAGE,
34    crypto_secretstream_xchacha20poly1305_TAG_PUSH,
35    crypto_secretstream_xchacha20poly1305_TAG_REKEY,
36    crypto_secretstream_xchacha20poly1305_init_pull,
37    crypto_secretstream_xchacha20poly1305_init_push,
38    crypto_secretstream_xchacha20poly1305_keygen,
39    crypto_secretstream_xchacha20poly1305_pull,
40    crypto_secretstream_xchacha20poly1305_push,
41    crypto_secretstream_xchacha20poly1305_rekey,
42    crypto_secretstream_xchacha20poly1305_state,
43)
44from nacl.utils import random as randombytes
45
46
def read_secretstream_vectors():
    """Load the secretstream known-answer vectors stored beside this module.

    Reads ``data/secretstream-test-vectors.json`` (relative to this file)
    and hex-decodes its binary fields.

    :return: a list with one ``[key, header, chunks]`` entry per vector;
             ``chunks`` is a list of ``[tag, ad, message, ciphertext]``
             entries in which ``ad`` is ``None`` when the JSON value is null
             and ``tag`` is passed through undecoded.
    """
    filename = "secretstream-test-vectors.json"
    here = os.path.dirname(__file__)
    with open(os.path.join(here, "data", filename), 'r') as handle:
        raw_vectors = json.load(handle)

    def decode_chunk(chunk):
        # Look the fields up in the order tag, ad, message, ciphertext.
        tag = chunk['tag']
        ad_hex = chunk['ad']
        return [
            tag,
            None if ad_hex is None else binascii.unhexlify(ad_hex),
            binascii.unhexlify(chunk['message']),
            binascii.unhexlify(chunk['ciphertext']),
        ]

    decoded = []
    for entry in raw_vectors:
        decoded.append([
            binascii.unhexlify(entry['key']),
            binascii.unhexlify(entry['header']),
            [decode_chunk(chunk) for chunk in entry['chunks']],
        ])
    return decoded
70
71
@pytest.mark.parametrize(
    ('key', 'header', 'chunks'),
    read_secretstream_vectors(),
)
def test_vectors(key, header, chunks):
    """Pull every chunk of one stored vector and compare with the expected
    plaintext and tag.

    A single pull state is initialised from the vector's header and key and
    then fed the chunk ciphertexts in order, each with its own additional
    data (which may be ``None``).
    """
    pull_state = crypto_secretstream_xchacha20poly1305_state()
    crypto_secretstream_xchacha20poly1305_init_pull(pull_state, header, key)
    for expected_tag, ad, expected_message, ciphertext in chunks:
        pulled = crypto_secretstream_xchacha20poly1305_pull(
            pull_state, ciphertext, ad)
        message, tag = pulled
        assert message == expected_message
        assert tag == expected_tag
84
85
def test_it_like_libsodium():
    """Exercise the secretstream bindings end to end on random messages.

    Presumably modelled on libsodium's own secretstream test, as the name
    suggests.  The scenario pushes three random messages (the last one
    tagged FINAL), pulls them back, and then checks in turn: rejection of
    ciphertexts pulled out of sequence or with the wrong additional data,
    rejection of too-short ciphertexts, streams without and with an
    explicit rekey, rekeying requested through ``TAG_REKEY``, and a push
    using ``TAG_PUSH``.

    Every section reuses the same ``state`` object and depends on what the
    previous sections did to it, so the sections must run in this order.
    """
    ad_len = random.randint(1, 100)
    m1_len = random.randint(1, 1000)
    m2_len = random.randint(1, 1000)
    m3_len = random.randint(1, 1000)

    ad = randombytes(ad_len)
    m1 = randombytes(m1_len)
    m2 = randombytes(m2_len)
    m3 = randombytes(m3_len)
    # reference plaintexts: m1..m3 are rebound to the pulled values below
    m1_ = m1[:]
    m2_ = m2[:]
    m3_ = m3[:]

    state = crypto_secretstream_xchacha20poly1305_state()

    k = crypto_secretstream_xchacha20poly1305_keygen()
    assert len(k) == crypto_secretstream_xchacha20poly1305_KEYBYTES

    # push

    assert len(state.statebuf) == \
        crypto_secretstream_xchacha20poly1305_STATEBYTES
    header = crypto_secretstream_xchacha20poly1305_init_push(state, k)
    assert len(header) == crypto_secretstream_xchacha20poly1305_HEADERBYTES

    c1 = crypto_secretstream_xchacha20poly1305_push(state, m1)
    assert len(c1) == m1_len + crypto_secretstream_xchacha20poly1305_ABYTES

    c2 = crypto_secretstream_xchacha20poly1305_push(state, m2, ad)
    assert len(c2) == m2_len + crypto_secretstream_xchacha20poly1305_ABYTES

    c3 = crypto_secretstream_xchacha20poly1305_push(
        state, m3, ad=ad, tag=crypto_secretstream_xchacha20poly1305_TAG_FINAL)
    assert len(c3) == m3_len + crypto_secretstream_xchacha20poly1305_ABYTES

    # pull

    crypto_secretstream_xchacha20poly1305_init_pull(state, header, k)

    m1, tag = crypto_secretstream_xchacha20poly1305_pull(state, c1)
    assert tag == crypto_secretstream_xchacha20poly1305_TAG_MESSAGE
    assert m1 == m1_

    m2, tag = crypto_secretstream_xchacha20poly1305_pull(state, c2, ad)
    assert tag == crypto_secretstream_xchacha20poly1305_TAG_MESSAGE
    assert m2 == m2_

    # c3 was pushed with ad, so pulling it without ad must fail; the same
    # chunk is then pulled again with the right ad and must succeed
    with pytest.raises(RuntimeError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, c3)
    assert str(excinfo.value) == 'Unexpected failure'
    m3, tag = crypto_secretstream_xchacha20poly1305_pull(state, c3, ad)
    assert tag == crypto_secretstream_xchacha20poly1305_TAG_FINAL
    assert m3 == m3_

    # previous with FINAL tag: c3 has already been pulled once

    with pytest.raises(RuntimeError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, c3, ad)
    assert str(excinfo.value) == 'Unexpected failure'

    # previous without a tag: c2 is out of sequence and pulled without ad

    with pytest.raises(RuntimeError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, c2, None)
    assert str(excinfo.value) == 'Unexpected failure'

    # short ciphertext: a random length below ABYTES, then zero length

    # pick the length outside the raises block so that only the pull call
    # itself can satisfy the expectation
    c2len = random.randint(
        1,
        crypto_secretstream_xchacha20poly1305_ABYTES - 1
    )
    with pytest.raises(ValueError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, c2[:c2len])
    assert str(excinfo.value) == 'Ciphertext is too short'
    with pytest.raises(ValueError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, b'')
    assert str(excinfo.value) == 'Ciphertext is too short'

    # longest too-short ciphertext: exactly one byte less than ABYTES

    with pytest.raises(ValueError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(
            state,
            c2[:crypto_secretstream_xchacha20poly1305_ABYTES - 1],
            None,
        )
    assert str(excinfo.value) == 'Ciphertext is too short'

    # without explicit rekeying

    header = crypto_secretstream_xchacha20poly1305_init_push(state, k)
    c1 = crypto_secretstream_xchacha20poly1305_push(state, m1)
    c2 = crypto_secretstream_xchacha20poly1305_push(state, m2)

    crypto_secretstream_xchacha20poly1305_init_pull(state, header, k)
    m1, _ = crypto_secretstream_xchacha20poly1305_pull(state, c1)
    assert m1 == m1_
    m2, _ = crypto_secretstream_xchacha20poly1305_pull(state, c2)
    assert m2 == m2_

    # with explicit rekeying

    header = crypto_secretstream_xchacha20poly1305_init_push(state, k)
    c1 = crypto_secretstream_xchacha20poly1305_push(state, m1)

    crypto_secretstream_xchacha20poly1305_rekey(state)

    c2 = crypto_secretstream_xchacha20poly1305_push(state, m2)

    crypto_secretstream_xchacha20poly1305_init_pull(state, header, k)
    m1, _ = crypto_secretstream_xchacha20poly1305_pull(state, c1)
    assert m1 == m1_

    # c2 was pushed after a rekey, so pulling it before rekeying the pull
    # side fails the same way as the other rejected pulls above
    with pytest.raises(RuntimeError) as excinfo:
        crypto_secretstream_xchacha20poly1305_pull(state, c2)
    assert str(excinfo.value) == 'Unexpected failure'

    crypto_secretstream_xchacha20poly1305_rekey(state)

    m2, _ = crypto_secretstream_xchacha20poly1305_pull(state, c2)
    assert m2 == m2_

    # with explicit rekeying using TAG_REKEY

    header = crypto_secretstream_xchacha20poly1305_init_push(state, k)

    # snapshot of the freshly initialised push state, restored further down
    state_save = ffi.buffer(state.statebuf)[:]

    c1 = crypto_secretstream_xchacha20poly1305_push(
        state, m1, tag=crypto_secretstream_xchacha20poly1305_TAG_REKEY)

    c2 = crypto_secretstream_xchacha20poly1305_push(state, m2)

    csave = c2[:]

    crypto_secretstream_xchacha20poly1305_init_pull(state, header, k)
    m1, tag = crypto_secretstream_xchacha20poly1305_pull(state, c1)
    assert m1 == m1_
    assert tag == crypto_secretstream_xchacha20poly1305_TAG_REKEY

    m2, tag = crypto_secretstream_xchacha20poly1305_pull(state, c2)
    assert m2 == m2_
    assert tag == crypto_secretstream_xchacha20poly1305_TAG_MESSAGE

    # avoid using from_buffer until at least cffi >= 1.10 in setup.py
    # state = ffi.from_buffer(state_save)
    for i in range(crypto_secretstream_xchacha20poly1305_STATEBYTES):
        state.statebuf[i] = six.indexbytes(state_save, i)

    # push the same two messages again from the restored state, this time
    # without TAG_REKEY: the second ciphertext must differ from the one
    # produced after the rekey
    c1 = crypto_secretstream_xchacha20poly1305_push(state, m1)

    c2 = crypto_secretstream_xchacha20poly1305_push(state, m2)
    assert csave != c2

    # New stream

    header = crypto_secretstream_xchacha20poly1305_init_push(state, k)

    c1 = crypto_secretstream_xchacha20poly1305_push(
        state, m1, tag=crypto_secretstream_xchacha20poly1305_TAG_PUSH)
    assert len(c1) == m1_len + crypto_secretstream_xchacha20poly1305_ABYTES

    # snip tests that require introspection into the state buffer
    # to test the nonce as we're using an opaque pointer
251
252
def test_max_message_size(monkeypatch):
    """Pushing a message longer than MESSAGEBYTES_MAX raises ValueError.

    The module-level limit is patched down to ``2**10 - 1`` so that a
    message one byte over the limit stays small.

    :param monkeypatch: pytest fixture used to override the limit on the
                        bindings module for the duration of the test.
    """
    import nacl.bindings.crypto_secretstream as bindings

    limit_name = 'crypto_secretstream_xchacha20poly1305_MESSAGEBYTES_MAX'
    # An oversized message is needed, but allocating one above the real
    # limit would blow out memory, so lower the limit for this test.
    monkeypatch.setattr(bindings, limit_name, 2**10 - 1)
    oversized = b'0' * (getattr(bindings, limit_name) + 1)

    push_state = crypto_secretstream_xchacha20poly1305_state()
    key = crypto_secretstream_xchacha20poly1305_keygen()
    crypto_secretstream_xchacha20poly1305_init_push(push_state, key)

    with pytest.raises(ValueError) as excinfo:
        crypto_secretstream_xchacha20poly1305_push(
            push_state, oversized, None, 0)
    assert str(excinfo.value) == 'Message is too long'
269