1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package chacha20
6
7import (
8	"bytes"
9	"encoding/hex"
10	"fmt"
11	"math/rand"
12	"testing"
13)
14
15func _() {
16	// Assert that bufSize is a multiple of blockSize.
17	var b [1]byte
18	_ = b[bufSize%blockSize]
19}
20
// hexDecode converts the hexadecimal string s to raw bytes. It is used only
// for hard-coded test data, so malformed input is treated as a programming
// error and causes a panic rather than an error return.
func hexDecode(s string) []byte {
	decoded, err := hex.DecodeString(s)
	if err == nil {
		return decoded
	}
	msg := fmt.Sprintf("cannot decode input %#v: %v", s, err)
	panic(msg)
}
28
29// Run the test cases with the input and output in different buffers.
30func TestNoOverlap(t *testing.T) {
31	for _, c := range testVectors {
32		s, _ := NewUnauthenticatedCipher(hexDecode(c.key), hexDecode(c.nonce))
33		input := hexDecode(c.input)
34		output := make([]byte, len(input))
35		s.XORKeyStream(output, input)
36		got := hex.EncodeToString(output)
37		if got != c.output {
38			t.Errorf("length=%v: got %#v, want %#v", len(input), got, c.output)
39		}
40	}
41}
42
43// Run the test cases with the input and output overlapping entirely.
44func TestOverlap(t *testing.T) {
45	for _, c := range testVectors {
46		s, _ := NewUnauthenticatedCipher(hexDecode(c.key), hexDecode(c.nonce))
47		data := hexDecode(c.input)
48		s.XORKeyStream(data, data)
49		got := hex.EncodeToString(data)
50		if got != c.output {
51			t.Errorf("length=%v: got %#v, want %#v", len(data), got, c.output)
52		}
53	}
54}
55
56// Run the test cases with various source and destination offsets.
57func TestUnaligned(t *testing.T) {
58	const max = 8 // max offset (+1) to test
59	for _, c := range testVectors {
60		data := hexDecode(c.input)
61		input := make([]byte, len(data)+max)
62		output := make([]byte, len(data)+max)
63		for i := 0; i < max; i++ { // input offsets
64			for j := 0; j < max; j++ { // output offsets
65				s, _ := NewUnauthenticatedCipher(hexDecode(c.key), hexDecode(c.nonce))
66
67				input := input[i : i+len(data)]
68				output := output[j : j+len(data)]
69
70				copy(input, data)
71				s.XORKeyStream(output, input)
72				got := hex.EncodeToString(output)
73				if got != c.output {
74					t.Errorf("length=%v: got %#v, want %#v", len(data), got, c.output)
75				}
76			}
77		}
78	}
79}
80
81// Run the test cases by calling XORKeyStream multiple times.
82func TestStep(t *testing.T) {
83	// wide range of step sizes to try and hit edge cases
84	steps := [...]int{1, 3, 4, 7, 8, 17, 24, 30, 64, 256}
85	rnd := rand.New(rand.NewSource(123))
86	for _, c := range testVectors {
87		s, _ := NewUnauthenticatedCipher(hexDecode(c.key), hexDecode(c.nonce))
88		input := hexDecode(c.input)
89		output := make([]byte, len(input))
90
91		// step through the buffers
92		i, step := 0, steps[rnd.Intn(len(steps))]
93		for i+step < len(input) {
94			s.XORKeyStream(output[i:i+step], input[i:i+step])
95			if i+step < len(input) && output[i+step] != 0 {
96				t.Errorf("length=%v, i=%v, step=%v: output overwritten", len(input), i, step)
97			}
98			i += step
99			step = steps[rnd.Intn(len(steps))]
100		}
101		// finish the encryption
102		s.XORKeyStream(output[i:], input[i:])
103		// ensure we tolerate a call with an empty input
104		s.XORKeyStream(output[len(output):], input[len(input):])
105
106		got := hex.EncodeToString(output)
107		if got != c.output {
108			t.Errorf("length=%v: got %#v, want %#v", len(input), got, c.output)
109		}
110	}
111}
112
113func TestSetCounter(t *testing.T) {
114	newCipher := func() *Cipher {
115		s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
116		return s
117	}
118	s := newCipher()
119	src := bytes.Repeat([]byte("test"), 32) // two 64-byte blocks
120	dst1 := make([]byte, len(src))
121	s.XORKeyStream(dst1, src)
122	// advance counter to 1 and xor second block
123	s = newCipher()
124	s.SetCounter(1)
125	dst2 := make([]byte, len(src))
126	s.XORKeyStream(dst2[64:], src[64:])
127	if !bytes.Equal(dst1[64:], dst2[64:]) {
128		t.Error("failed to produce identical output using SetCounter")
129	}
130
131	// test again with unaligned blocks; SetCounter should reset the buffer
132	s = newCipher()
133	s.XORKeyStream(dst1[:70], src[:70])
134	s = newCipher()
135	s.XORKeyStream([]byte{0}, []byte{0})
136	s.SetCounter(1)
137	s.XORKeyStream(dst2[64:70], src[64:70])
138	if !bytes.Equal(dst1[64:70], dst2[64:70]) {
139		t.Error("SetCounter did not reset buffer")
140	}
141
142	// advancing to a lower counter value should cause a panic
143	panics := func(fn func()) (p bool) {
144		defer func() { p = recover() != nil }()
145		fn()
146		return
147	}
148	if !panics(func() { s.SetCounter(0) }) {
149		t.Error("counter decreasing should trigger a panic")
150	}
151}
152
153func TestLastBlock(t *testing.T) {
154	panics := func(fn func()) (p bool) {
155		defer func() { p = recover() != nil }()
156		fn()
157		return
158	}
159
160	checkLastBlock := func(b []byte) {
161		t.Helper()
162		// Hardcoded result to check all implementations generate the same output.
163		lastBlock := "ace4cd09e294d1912d4ad205d06f95d9c2f2bfcf453e8753f128765b62215f4d" +
164			"92c74f2f626c6a640c0b1284d839ec81f1696281dafc3e684593937023b58b1d"
165		if got := hex.EncodeToString(b); got != lastBlock {
166			t.Errorf("wrong output for the last block, got %q, want %q", got, lastBlock)
167		}
168	}
169
170	// setting the counter to 0xffffffff and crypting multiple blocks should
171	// trigger a panic
172	s, _ := NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
173	s.SetCounter(0xffffffff)
174	blocks := make([]byte, blockSize*2)
175	if !panics(func() { s.XORKeyStream(blocks, blocks) }) {
176		t.Error("crypting multiple blocks should trigger a panic")
177	}
178
179	// setting the counter to 0xffffffff - 1 and crypting two blocks should not
180	// trigger a panic
181	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
182	s.SetCounter(0xffffffff - 1)
183	if panics(func() { s.XORKeyStream(blocks, blocks) }) {
184		t.Error("crypting the last blocks should not trigger a panic")
185	}
186	checkLastBlock(blocks[blockSize:])
187	// once all the keystream is spent, setting the counter should panic
188	if !panics(func() { s.SetCounter(0xffffffff) }) {
189		t.Error("setting the counter after overflow should trigger a panic")
190	}
191	// crypting a subsequent block *should* panic
192	block := make([]byte, blockSize)
193	if !panics(func() { s.XORKeyStream(block, block) }) {
194		t.Error("crypting after overflow should trigger a panic")
195	}
196
197	// if we crypt less than a full block, we should be able to crypt the rest
198	// in a subsequent call without panicking
199	s, _ = NewUnauthenticatedCipher(make([]byte, KeySize), make([]byte, NonceSize))
200	s.SetCounter(0xffffffff)
201	if panics(func() { s.XORKeyStream(block[:7], block[:7]) }) {
202		t.Error("crypting part of the last block should not trigger a panic")
203	}
204	if panics(func() { s.XORKeyStream(block[7:], block[7:]) }) {
205		t.Error("crypting part of the last block should not trigger a panic")
206	}
207	checkLastBlock(block)
208	// as before, a third call should trigger a panic because all keystream is spent
209	if !panics(func() { s.XORKeyStream(block[:1], block[:1]) }) {
210		t.Error("crypting after overflow should trigger a panic")
211	}
212}
213
214func benchmarkChaCha20(b *testing.B, step, count int) {
215	tot := step * count
216	src := make([]byte, tot)
217	dst := make([]byte, tot)
218	key := make([]byte, KeySize)
219	nonce := make([]byte, NonceSize)
220	b.SetBytes(int64(tot))
221	b.ResetTimer()
222	for i := 0; i < b.N; i++ {
223		c, _ := NewUnauthenticatedCipher(key, nonce)
224		for i := 0; i < tot; i += step {
225			c.XORKeyStream(dst[i:], src[i:i+step])
226		}
227	}
228}
229
230func BenchmarkChaCha20(b *testing.B) {
231	b.Run("64", func(b *testing.B) {
232		benchmarkChaCha20(b, 64, 1)
233	})
234	b.Run("256", func(b *testing.B) {
235		benchmarkChaCha20(b, 256, 1)
236	})
237	b.Run("10x25", func(b *testing.B) {
238		benchmarkChaCha20(b, 10, 25)
239	})
240	b.Run("4096", func(b *testing.B) {
241		benchmarkChaCha20(b, 4096, 1)
242	})
243	b.Run("100x40", func(b *testing.B) {
244		benchmarkChaCha20(b, 100, 40)
245	})
246	b.Run("65536", func(b *testing.B) {
247		benchmarkChaCha20(b, 65536, 1)
248	})
249	b.Run("1000x65", func(b *testing.B) {
250		benchmarkChaCha20(b, 1000, 65)
251	})
252}
253
254func TestHChaCha20(t *testing.T) {
255	// See draft-irtf-cfrg-xchacha-00, Section 2.2.1.
256	key := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
257		0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
258		0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
259		0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f}
260	nonce := []byte{0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x4a,
261		0x00, 0x00, 0x00, 0x00, 0x31, 0x41, 0x59, 0x27}
262	expected := []byte{0x82, 0x41, 0x3b, 0x42, 0x27, 0xb2, 0x7b, 0xfe,
263		0xd3, 0x0e, 0x42, 0x50, 0x8a, 0x87, 0x7d, 0x73,
264		0xa0, 0xf9, 0xe4, 0xd5, 0x8a, 0x74, 0xa8, 0x53,
265		0xc1, 0x2e, 0xc4, 0x13, 0x26, 0xd3, 0xec, 0xdc,
266	}
267	result, err := HChaCha20(key[:], nonce[:])
268	if err != nil {
269		t.Fatal(err)
270	}
271	if !bytes.Equal(expected, result) {
272		t.Errorf("want %x, got %x", expected, result)
273	}
274}
275