//! The ChaCha20 core function. Defined in RFC 8439 Section 2.3.
//!
//! <https://tools.ietf.org/html/rfc8439#section-2.3>
//!
//! AVX2-optimized implementation for x86/x86-64 CPUs adapted from the SUPERCOP
//! `goll_gueron` backend (public domain) described in:
//!
//! Goll, M., and Gueron, S.: Vectorization of ChaCha Stream Cipher. Cryptology ePrint Archive,
//! Report 2013/759, November 2013, <https://eprint.iacr.org/2013/759.pdf>
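//!
//! A minimal sketch of how this backend is driven (illustrative only: `R20` is
//! assumed to be the 20-round marker type from `crate::rounds`, and the buffer
//! handling shown here is simplified):
//!
//! ```ignore
//! let key = [0u8; KEY_SIZE];
//! let iv = [0u8; IV_SIZE];
//! let core = Core::<R20>::new(&key, iv);
//!
//! // Write keystream starting at block counter 0 into `buffer`
//! let mut buffer = [0u8; BUFFER_SIZE];
//! core.generate(0, &mut buffer);
//! ```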

use super::autodetect::BUFFER_SIZE;
use crate::{rounds::Rounds, BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE};
use core::{convert::TryInto, marker::PhantomData};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
pub(crate) struct Core<R: Rounds> {
    v0: __m256i,
    v1: __m256i,
    v2: __m256i,
    iv: [i32; 2],
    rounds: PhantomData<R>,
}

impl<R: Rounds> Core<R> {
    /// Initialize the core function with the given key and IV; the number of
    /// rounds is selected via the `R` type parameter
    #[inline]
    pub fn new(key: &[u8; KEY_SIZE], iv: [u8; IV_SIZE]) -> Self {
        let (v0, v1, v2) = unsafe { key_setup(key) };
        let iv = [
            i32::from_le_bytes(iv[4..].try_into().unwrap()),
            i32::from_le_bytes(iv[..4].try_into().unwrap()),
        ];

        Self {
            v0,
            v1,
            v2,
            iv,
            rounds: PhantomData,
        }
    }

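    /// Write keystream for the given 64-bit block counter into `output`,
    /// which is expected to be `BUFFER_SIZE` bytes long (two ChaCha blocks
    /// are produced per call, one per 128-bit lane)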
    #[inline]
    pub fn generate(&self, counter: u64, output: &mut [u8]) {
        unsafe {
            let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
            let mut v3 = iv_setup(self.iv, counter);
            self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
            store(v0, v1, v2, v3, output);
        }
    }

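    /// XOR keystream for the given 64-bit block counter into `output` in
    /// place; the buffer must be `BUFFER_SIZE` bytes long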
    #[inline]
    #[cfg(feature = "cipher")]
    #[allow(clippy::cast_ptr_alignment)] // loadu/storeu support unaligned loads/stores
    pub fn apply_keystream(&self, counter: u64, output: &mut [u8]) {
        debug_assert_eq!(output.len(), BUFFER_SIZE);

        unsafe {
            let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
            let mut v3 = iv_setup(self.iv, counter);
            self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);

            for (chunk, a) in output[..BLOCK_SIZE].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
                let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
                let out = _mm_xor_si128(_mm256_castsi256_si128(*a), b);
                _mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, out);
            }

            for (chunk, a) in output[BLOCK_SIZE..].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
                let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
                let out = _mm_xor_si128(_mm256_extractf128_si256(*a, 1), b);
                _mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, out);
            }
        }
    }

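    /// Run `R::COUNT` ChaCha rounds over the working state, then add the
    /// original state back in as per RFC 8439 Section 2.3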
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn rounds(
        &self,
        v0: &mut __m256i,
        v1: &mut __m256i,
        v2: &mut __m256i,
        v3: &mut __m256i,
    ) {
        let v3_orig = *v3;

        for _ in 0..(R::COUNT / 2) {
            double_quarter_round(v0, v1, v2, v3);
        }

        *v0 = _mm256_add_epi32(*v0, self.v0);
        *v1 = _mm256_add_epi32(*v1, self.v1);
        *v2 = _mm256_add_epi32(*v2, self.v2);
        *v3 = _mm256_add_epi32(*v3, v3_orig);
    }
}

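/// Load the ChaCha constants and the two 128-bit key halves, broadcasting each
/// into both lanes of a 256-bit register so that two blocks share rows 0-2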
#[inline]
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_ptr_alignment)] // loadu supports unaligned loads
unsafe fn key_setup(key: &[u8; KEY_SIZE]) -> (__m256i, __m256i, __m256i) {
    let v0 = _mm_loadu_si128(CONSTANTS.as_ptr() as *const __m128i);
    let v1 = _mm_loadu_si128(key.as_ptr().offset(0x00) as *const __m128i);
    let v2 = _mm_loadu_si128(key.as_ptr().offset(0x10) as *const __m128i);

    (
        _mm256_broadcastsi128_si256(v0),
        _mm256_broadcastsi128_si256(v1),
        _mm256_broadcastsi128_si256(v2),
    )
}

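/// Build row 3 (block counter and nonce) for both lanes, then bump the counter
/// in the upper 128-bit lane so the two lanes encode consecutive blocks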
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn iv_setup(iv: [i32; 2], counter: u64) -> __m256i {
    let s3 = _mm_set_epi32(
        iv[0],
        iv[1],
        ((counter >> 32) & 0xffff_ffff) as i32,
        (counter & 0xffff_ffff) as i32,
    );

    _mm256_add_epi64(
        _mm256_broadcastsi128_si256(s3),
        _mm256_set_epi64x(0, 1, 0, 0),
    )
}

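/// Serialize the finished state to `output`: the lower 128-bit lanes form the
/// first block, the upper lanes the second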
#[inline]
#[target_feature(enable = "avx2")]
#[allow(clippy::cast_ptr_alignment)] // storeu supports unaligned stores
unsafe fn store(v0: __m256i, v1: __m256i, v2: __m256i, v3: __m256i, output: &mut [u8]) {
    debug_assert_eq!(output.len(), BUFFER_SIZE);

    for (chunk, v) in output[..BLOCK_SIZE].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
        _mm_storeu_si128(
            chunk.as_mut_ptr() as *mut __m128i,
            _mm256_castsi256_si128(*v),
        );
    }

    for (chunk, v) in output[BLOCK_SIZE..].chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
        _mm_storeu_si128(
            chunk.as_mut_ptr() as *mut __m128i,
            _mm256_extractf128_si256(*v, 1),
        );
    }
}

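/// One ChaCha "double round": a column round, a diagonalizing shuffle, a
/// diagonal round, and the inverse shuffle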
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn double_quarter_round(
    v0: &mut __m256i,
    v1: &mut __m256i,
    v2: &mut __m256i,
    v3: &mut __m256i,
) {
    add_xor_rot(v0, v1, v2, v3);
    rows_to_cols(v0, v1, v2, v3);
    add_xor_rot(v0, v1, v2, v3);
    cols_to_rows(v0, v1, v2, v3);
}

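/// Rotate the words of rows `b`, `c`, and `d` so the next set of quarter
/// rounds operates on the diagonals of the state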
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn rows_to_cols(_v0: &mut __m256i, v1: &mut __m256i, v2: &mut __m256i, v3: &mut __m256i) {
    // b = ROR256_V1(b); c = ROR256_V2(c); d = ROR256_V3(d);
    *v1 = _mm256_shuffle_epi32(*v1, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    *v2 = _mm256_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm256_shuffle_epi32(*v3, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
}

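/// Inverse of `rows_to_cols`: restore the row layout after the diagonal round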
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn cols_to_rows(_v0: &mut __m256i, v1: &mut __m256i, v2: &mut __m256i, v3: &mut __m256i) {
    // b = ROR256_V3(b); c = ROR256_V2(c); d = ROR256_V1(d);
    *v1 = _mm256_shuffle_epi32(*v1, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    *v2 = _mm256_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm256_shuffle_epi32(*v3, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
}

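/// Four vectorized quarter rounds: add/XOR/rotate with rotations of 16, 12, 8,
/// and 7 bits; the 16- and 8-bit rotations are implemented as byte shuffles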
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn add_xor_rot(v0: &mut __m256i, v1: &mut __m256i, v2: &mut __m256i, v3: &mut __m256i) {
    // a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d);
    *v0 = _mm256_add_epi32(*v0, *v1);
    *v3 = _mm256_xor_si256(*v3, *v0);
    *v3 = _mm256_shuffle_epi8(
        *v3,
        _mm256_set_epi8(
            13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11, 10, 5,
            4, 7, 6, 1, 0, 3, 2,
        ),
    );

    // c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b);
    *v2 = _mm256_add_epi32(*v2, *v3);
    *v1 = _mm256_xor_si256(*v1, *v2);
    *v1 = _mm256_xor_si256(_mm256_slli_epi32(*v1, 12), _mm256_srli_epi32(*v1, 20));

    // a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d);
    *v0 = _mm256_add_epi32(*v0, *v1);
    *v3 = _mm256_xor_si256(*v3, *v0);
    *v3 = _mm256_shuffle_epi8(
        *v3,
        _mm256_set_epi8(
            14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8, 11, 6,
            5, 4, 7, 2, 1, 0, 3,
        ),
    );

    // c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b);
    *v2 = _mm256_add_epi32(*v2, *v3);
    *v1 = _mm256_xor_si256(*v1, *v2);
    *v1 = _mm256_xor_si256(_mm256_slli_epi32(*v1, 7), _mm256_srli_epi32(*v1, 25));
}