1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package chacha20 implements the ChaCha20 and XChaCha20 encryption algorithms
6// as specified in RFC 8439 and draft-irtf-cfrg-xchacha-01.
7package chacha20
8
9import (
10	"crypto/cipher"
11	"encoding/binary"
12	"errors"
13	"math/bits"
14
15	"golang.org/x/crypto/internal/subtle"
16)
17
// Sizes, in bytes, of the key and nonces accepted by NewUnauthenticatedCipher.
const (
	// KeySize is the length of a ChaCha20 or XChaCha20 key, in bytes.
	KeySize = 32

	// NonceSize is the length, in bytes, of the nonce used by standard
	// (RFC 8439) ChaCha20.
	//
	// Note that this is too short to be safely generated at random if the
	// same key is reused more than 2³² times.
	NonceSize = 12

	// NonceSizeX is the length, in bytes, of the extended nonce used by the
	// XChaCha20 variant.
	NonceSizeX = 24
)
33
// Cipher is a stateful instance of ChaCha20 or XChaCha20 using a particular key
// and nonce. A *Cipher implements the cipher.Stream interface. Create one with
// NewUnauthenticatedCipher.
//
// Its methods update the receiver's fields without any synchronization, so a
// single Cipher must not be used by multiple goroutines at the same time.
type Cipher struct {
	// The ChaCha20 state is 16 words: 4 constant, 8 of key, 1 of counter
	// (incremented after each block), and 3 of nonce. key and nonce are
	// filled in by newUnauthenticatedCipher. counter is the index of the next
	// block to be generated; while unused key stream is still buffered in buf
	// it is ahead of the next block actually handed out to callers.
	key     [8]uint32
	counter uint32
	nonce   [3]uint32

	// The last len bytes of buf are leftover key stream bytes from the previous
	// XORKeyStream invocation; they are consumed before any new block is
	// generated. The size of buf depends on how many blocks are computed at a
	// time by xorKeyStreamBlocks (bufSize is platform-specific and declared
	// outside this file).
	buf [bufSize]byte
	len int

	// overflow is set once the block for counter value 2³²-1 has been
	// generated. Leftover key stream in buf may still be drained, but any
	// XORKeyStream call that needs a new block, and any SetCounter call,
	// panics.
	overflow bool

	// In the first column round, the quarter-rounds over columns 1, 2 and 3
	// depend only on the key and nonce, not on the counter (which sits in
	// column 0). xorKeyStreamBlocksGeneric computes them once and caches the
	// results here; precompDone records whether p1..p15 are valid.
	precompDone      bool
	p1, p5, p9, p13  uint32
	p2, p6, p10, p14 uint32
	p3, p7, p11, p15 uint32
}

// Compile-time check that *Cipher satisfies cipher.Stream.
var _ cipher.Stream = (*Cipher)(nil)
62
63// NewUnauthenticatedCipher creates a new ChaCha20 stream cipher with the given
64// 32 bytes key and a 12 or 24 bytes nonce. If a nonce of 24 bytes is provided,
65// the XChaCha20 construction will be used. It returns an error if key or nonce
66// have any other length.
67//
68// Note that ChaCha20, like all stream ciphers, is not authenticated and allows
69// attackers to silently tamper with the plaintext. For this reason, it is more
70// appropriate as a building block than as a standalone encryption mechanism.
71// Instead, consider using package golang.org/x/crypto/chacha20poly1305.
72func NewUnauthenticatedCipher(key, nonce []byte) (*Cipher, error) {
73	// This function is split into a wrapper so that the Cipher allocation will
74	// be inlined, and depending on how the caller uses the return value, won't
75	// escape to the heap.
76	c := &Cipher{}
77	return newUnauthenticatedCipher(c, key, nonce)
78}
79
80func newUnauthenticatedCipher(c *Cipher, key, nonce []byte) (*Cipher, error) {
81	if len(key) != KeySize {
82		return nil, errors.New("chacha20: wrong key size")
83	}
84	if len(nonce) == NonceSizeX {
85		// XChaCha20 uses the ChaCha20 core to mix 16 bytes of the nonce into a
86		// derived key, allowing it to operate on a nonce of 24 bytes. See
87		// draft-irtf-cfrg-xchacha-01, Section 2.3.
88		key, _ = HChaCha20(key, nonce[0:16])
89		cNonce := make([]byte, NonceSize)
90		copy(cNonce[4:12], nonce[16:24])
91		nonce = cNonce
92	} else if len(nonce) != NonceSize {
93		return nil, errors.New("chacha20: wrong nonce size")
94	}
95
96	key, nonce = key[:KeySize], nonce[:NonceSize] // bounds check elimination hint
97	c.key = [8]uint32{
98		binary.LittleEndian.Uint32(key[0:4]),
99		binary.LittleEndian.Uint32(key[4:8]),
100		binary.LittleEndian.Uint32(key[8:12]),
101		binary.LittleEndian.Uint32(key[12:16]),
102		binary.LittleEndian.Uint32(key[16:20]),
103		binary.LittleEndian.Uint32(key[20:24]),
104		binary.LittleEndian.Uint32(key[24:28]),
105		binary.LittleEndian.Uint32(key[28:32]),
106	}
107	c.nonce = [3]uint32{
108		binary.LittleEndian.Uint32(nonce[0:4]),
109		binary.LittleEndian.Uint32(nonce[4:8]),
110		binary.LittleEndian.Uint32(nonce[8:12]),
111	}
112	return c, nil
113}
114
// j0..j3 are the four constant words that open every ChaCha20 and HChaCha20
// state: the ASCII string "expand 32-byte k" read as four little-endian
// uint32s.
const (
	j0 uint32 = 0x61707865 // "expa"
	j1 uint32 = 0x3320646e // "nd 3"
	j2 uint32 = 0x79622d32 // "2-by"
	j3 uint32 = 0x6b206574 // "te k"
)

// blockSize is the number of key stream bytes produced by one ChaCha20 block
// (16 state words of 4 bytes each).
const blockSize = 64
124
// quarterRound applies the ChaCha20 quarter-round (RFC 8439, Section 2.1) to
// the state words a, b, c and d and returns their updated values. Each ChaCha20
// round applies it four times, either to the four columns or to the four
// diagonals of the 4x4 state matrix.
func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) {
	a += b
	d = bits.RotateLeft32(d^a, 16)
	c += d
	b = bits.RotateLeft32(b^c, 12)
	a += b
	d = bits.RotateLeft32(d^a, 8)
	c += d
	b = bits.RotateLeft32(b^c, 7)
	return a, b, c, d
}
143
144// SetCounter sets the Cipher counter. The next invocation of XORKeyStream will
145// behave as if (64 * counter) bytes had been encrypted so far.
146//
147// To prevent accidental counter reuse, SetCounter panics if counter is less
148// than the current value.
149//
150// Note that the execution time of XORKeyStream is not independent of the
151// counter value.
152func (s *Cipher) SetCounter(counter uint32) {
153	// Internally, s may buffer multiple blocks, which complicates this
154	// implementation slightly. When checking whether the counter has rolled
155	// back, we must use both s.counter and s.len to determine how many blocks
156	// we have already output.
157	outputCounter := s.counter - uint32(s.len)/blockSize
158	if s.overflow || counter < outputCounter {
159		panic("chacha20: SetCounter attempted to rollback counter")
160	}
161
162	// In the general case, we set the new counter value and reset s.len to 0,
163	// causing the next call to XORKeyStream to refill the buffer. However, if
164	// we're advancing within the existing buffer, we can save work by simply
165	// setting s.len.
166	if counter < s.counter {
167		s.len = int(s.counter-counter) * blockSize
168	} else {
169		s.counter = counter
170		s.len = 0
171	}
172}
173
// XORKeyStream XORs each byte in the given slice with a byte from the
// cipher's key stream. Dst and src must overlap entirely or not at all.
//
// If len(dst) < len(src), XORKeyStream will panic. It is acceptable
// to pass a dst bigger than src, and in that case, XORKeyStream will
// only update dst[:len(src)] and will not touch the rest of dst.
//
// Multiple calls to XORKeyStream behave as if the concatenation of
// the src buffers was passed in a single run. That is, Cipher
// maintains state and does not reset at each XORKeyStream call.
//
// XORKeyStream also panics if dst and src overlap only partially, or if
// processing src would require a block beyond counter value 2³²-1. Once that
// final block has been generated, only the key stream still buffered from it
// can be used; any call that needs a new block panics.
func (s *Cipher) XORKeyStream(dst, src []byte) {
	if len(src) == 0 {
		return
	}
	if len(dst) < len(src) {
		panic("chacha20: output smaller than input")
	}
	dst = dst[:len(src)]
	if subtle.InexactOverlap(dst, src) {
		panic("chacha20: invalid buffer overlap")
	}

	// First, drain any remaining key stream from a previous XORKeyStream.
	// Unused key stream always occupies the last s.len bytes of s.buf.
	if s.len != 0 {
		keyStream := s.buf[bufSize-s.len:]
		if len(src) < len(keyStream) {
			keyStream = keyStream[:len(src)]
		}
		_ = src[len(keyStream)-1] // bounds check elimination hint
		for i, b := range keyStream {
			dst[i] = src[i] ^ b
		}
		s.len -= len(keyStream)
		dst, src = dst[len(keyStream):], src[len(keyStream):]
	}
	if len(src) == 0 {
		return
	}

	// If we'd need to let the counter overflow and keep generating output,
	// panic immediately. If instead we'd only reach the last block, remember
	// not to generate any more output after the buffer is drained.
	numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize
	if s.overflow || uint64(s.counter)+numBlocks > 1<<32 {
		panic("chacha20: counter overflow")
	} else if uint64(s.counter)+numBlocks == 1<<32 {
		s.overflow = true
	}

	// xorKeyStreamBlocks implementations expect input lengths that are a
	// multiple of bufSize. Platform-specific ones process multiple blocks at a
	// time, so have bufSizes that are a multiple of blockSize.

	full := len(src) - len(src)%bufSize
	if full > 0 {
		s.xorKeyStreamBlocks(dst[:full], src[:full])
	}
	dst, src = dst[full:], src[full:]

	// If using a multi-block xorKeyStreamBlocks would overflow, use the generic
	// one that does one block at a time.
	//
	// Only the blocks the remaining (< bufSize) input needs are generated, into
	// the tail of s.buf, so that s.len indexes the leftover key stream the same
	// way as in the normal path. s.buf is zeroed first; the padding past
	// len(src) is XORed with zeros and so ends up holding raw key stream.
	const blocksPerBuf = bufSize / blockSize
	if uint64(s.counter)+blocksPerBuf > 1<<32 {
		s.buf = [bufSize]byte{}
		numBlocks := (len(src) + blockSize - 1) / blockSize
		buf := s.buf[bufSize-numBlocks*blockSize:]
		copy(buf, src)
		s.xorKeyStreamBlocksGeneric(buf, buf)
		s.len = len(buf) - copy(dst, buf)
		return
	}

	// If we have a partial (multi-)block, pad it for xorKeyStreamBlocks, and
	// keep the leftover keystream for the next XORKeyStream invocation.
	// As above, the zero padding becomes raw key stream, and the final
	// bufSize-len(src) bytes of s.buf are recorded as unused in s.len.
	if len(src) > 0 {
		s.buf = [bufSize]byte{}
		copy(s.buf[:], src)
		s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
		s.len = bufSize - copy(dst, s.buf[:])
	}
}
255
// xorKeyStreamBlocksGeneric is the portable ChaCha20 block function. It XORs
// src with len(src)/blockSize consecutive key stream blocks, starting at block
// s.counter, writes the result to dst, and advances s.counter by the number of
// blocks processed. It panics unless len(dst) == len(src) and both are a
// multiple of blockSize.
//
// s.counter is a uint32 and is incremented without an overflow check here;
// callers (see XORKeyStream) must ensure it does not wrap.
func (s *Cipher) xorKeyStreamBlocksGeneric(dst, src []byte) {
	if len(dst) != len(src) || len(dst)%blockSize != 0 {
		panic("chacha20: internal error: wrong dst and/or src length")
	}

	// To generate each block of key stream, the initial cipher state
	// (represented below) is passed through 20 rounds of shuffling,
	// alternately applying quarterRounds by columns (like 1, 5, 9, 13)
	// or by diagonals (like 1, 6, 11, 12).
	//
	//      0:cccccccc   1:cccccccc   2:cccccccc   3:cccccccc
	//      4:kkkkkkkk   5:kkkkkkkk   6:kkkkkkkk   7:kkkkkkkk
	//      8:kkkkkkkk   9:kkkkkkkk  10:kkkkkkkk  11:kkkkkkkk
	//     12:bbbbbbbb  13:nnnnnnnn  14:nnnnnnnn  15:nnnnnnnn
	//
	//            c=constant k=key b=blockcount n=nonce
	//
	// Word 12 is not captured in a local because it changes per block; the
	// loop reads s.counter directly instead.
	var (
		c0, c1, c2, c3   = j0, j1, j2, j3
		c4, c5, c6, c7   = s.key[0], s.key[1], s.key[2], s.key[3]
		c8, c9, c10, c11 = s.key[4], s.key[5], s.key[6], s.key[7]
		_, c13, c14, c15 = s.counter, s.nonce[0], s.nonce[1], s.nonce[2]
	)

	// Three quarters of the first round don't depend on the counter, so we can
	// calculate them here, and reuse them for multiple blocks in the loop, and
	// for future XORKeyStream invocations. They depend only on the key and
	// nonce, which newUnauthenticatedCipher sets when the Cipher is created.
	if !s.precompDone {
		s.p1, s.p5, s.p9, s.p13 = quarterRound(c1, c5, c9, c13)
		s.p2, s.p6, s.p10, s.p14 = quarterRound(c2, c6, c10, c14)
		s.p3, s.p7, s.p11, s.p15 = quarterRound(c3, c7, c11, c15)
		s.precompDone = true
	}

	// A condition of len(src) > 0 would be sufficient, but this also
	// acts as a bounds check elimination hint.
	for len(src) >= 64 && len(dst) >= 64 {
		// The remainder of the first column round: column 0, which contains
		// the counter word.
		fcr0, fcr4, fcr8, fcr12 := quarterRound(c0, c4, c8, s.counter)

		// The second round (the first diagonal round), mixing the cached
		// column results with the fresh column-0 results.
		x0, x5, x10, x15 := quarterRound(fcr0, s.p5, s.p10, s.p15)
		x1, x6, x11, x12 := quarterRound(s.p1, s.p6, s.p11, fcr12)
		x2, x7, x8, x13 := quarterRound(s.p2, s.p7, fcr8, s.p13)
		x3, x4, x9, x14 := quarterRound(s.p3, fcr4, s.p9, s.p14)

		// The remaining 18 rounds (9 column/diagonal double rounds).
		for i := 0; i < 9; i++ {
			// Column round.
			x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
			x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
			x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
			x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)

			// Diagonal round.
			x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
			x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
			x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
			x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
		}

		// Add back the initial state to generate the key stream, then
		// XOR the key stream with the source and write out the result.
		// (addXor is declared outside this file.)
		addXor(dst[0:4], src[0:4], x0, c0)
		addXor(dst[4:8], src[4:8], x1, c1)
		addXor(dst[8:12], src[8:12], x2, c2)
		addXor(dst[12:16], src[12:16], x3, c3)
		addXor(dst[16:20], src[16:20], x4, c4)
		addXor(dst[20:24], src[20:24], x5, c5)
		addXor(dst[24:28], src[24:28], x6, c6)
		addXor(dst[28:32], src[28:32], x7, c7)
		addXor(dst[32:36], src[32:36], x8, c8)
		addXor(dst[36:40], src[36:40], x9, c9)
		addXor(dst[40:44], src[40:44], x10, c10)
		addXor(dst[44:48], src[44:48], x11, c11)
		addXor(dst[48:52], src[48:52], x12, s.counter)
		addXor(dst[52:56], src[52:56], x13, c13)
		addXor(dst[56:60], src[56:60], x14, c14)
		addXor(dst[60:64], src[60:64], x15, c15)

		s.counter += 1

		src, dst = src[blockSize:], dst[blockSize:]
	}
}
340
341// HChaCha20 uses the ChaCha20 core to generate a derived key from a 32 bytes
342// key and a 16 bytes nonce. It returns an error if key or nonce have any other
343// length. It is used as part of the XChaCha20 construction.
344func HChaCha20(key, nonce []byte) ([]byte, error) {
345	// This function is split into a wrapper so that the slice allocation will
346	// be inlined, and depending on how the caller uses the return value, won't
347	// escape to the heap.
348	out := make([]byte, 32)
349	return hChaCha20(out, key, nonce)
350}
351
352func hChaCha20(out, key, nonce []byte) ([]byte, error) {
353	if len(key) != KeySize {
354		return nil, errors.New("chacha20: wrong HChaCha20 key size")
355	}
356	if len(nonce) != 16 {
357		return nil, errors.New("chacha20: wrong HChaCha20 nonce size")
358	}
359
360	x0, x1, x2, x3 := j0, j1, j2, j3
361	x4 := binary.LittleEndian.Uint32(key[0:4])
362	x5 := binary.LittleEndian.Uint32(key[4:8])
363	x6 := binary.LittleEndian.Uint32(key[8:12])
364	x7 := binary.LittleEndian.Uint32(key[12:16])
365	x8 := binary.LittleEndian.Uint32(key[16:20])
366	x9 := binary.LittleEndian.Uint32(key[20:24])
367	x10 := binary.LittleEndian.Uint32(key[24:28])
368	x11 := binary.LittleEndian.Uint32(key[28:32])
369	x12 := binary.LittleEndian.Uint32(nonce[0:4])
370	x13 := binary.LittleEndian.Uint32(nonce[4:8])
371	x14 := binary.LittleEndian.Uint32(nonce[8:12])
372	x15 := binary.LittleEndian.Uint32(nonce[12:16])
373
374	for i := 0; i < 10; i++ {
375		// Diagonal round.
376		x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
377		x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
378		x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
379		x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
380
381		// Column round.
382		x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
383		x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
384		x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
385		x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
386	}
387
388	_ = out[31] // bounds check elimination hint
389	binary.LittleEndian.PutUint32(out[0:4], x0)
390	binary.LittleEndian.PutUint32(out[4:8], x1)
391	binary.LittleEndian.PutUint32(out[8:12], x2)
392	binary.LittleEndian.PutUint32(out[12:16], x3)
393	binary.LittleEndian.PutUint32(out[16:20], x12)
394	binary.LittleEndian.PutUint32(out[20:24], x13)
395	binary.LittleEndian.PutUint32(out[24:28], x14)
396	binary.LittleEndian.PutUint32(out[28:32], x15)
397	return out, nil
398}
399