//! The ChaCha20 core function. Defined in RFC 8439 Section 2.3.
//!
//! <https://tools.ietf.org/html/rfc8439#section-2.3>
//!
//! SSE2-optimized implementation for x86/x86-64 CPUs.

use super::autodetect::BUFFER_SIZE;
use crate::{rounds::Rounds, BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE};
use core::{convert::TryInto, marker::PhantomData};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// The ChaCha20 core function (SSE2 accelerated implementation for x86/x86_64)
// TODO(tarcieri): zeroize?
#[derive(Clone)]
pub struct Core<R: Rounds> {
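    /// ChaCha constants (row 0 of the initial state)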
    v0: __m128i,
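    /// First 128 bits of the key (row 1 of the initial state)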
    v1: __m128i,
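    /// Last 128 bits of the key (row 2 of the initial state)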
    v2: __m128i,
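    /// IV (nonce) as two little-endian `i32` words, stored high word first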
    iv: [i32; 2],
    rounds: PhantomData<R>,
}

impl<R: Rounds> Core<R> {
    /// Initialize the core function with the given key, IV, and number of rounds
    #[inline]
    pub fn new(key: &[u8; KEY_SIZE], iv: [u8; IV_SIZE]) -> Self {
        let (v0, v1, v2) = unsafe { key_setup(key) };
        let iv = [
            i32::from_le_bytes(iv[4..].try_into().unwrap()),
            i32::from_le_bytes(iv[..4].try_into().unwrap()),
        ];

        Self {
            v0,
            v1,
            v2,
            iv,
            rounds: PhantomData,
        }
    }

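    /// Generate `BUFFER_SIZE` bytes of keystream into `output`, one
    /// `BLOCK_SIZE` block at a time, incrementing the block counter per block.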
    #[inline]
    pub fn generate(&self, counter: u64, output: &mut [u8]) {
        debug_assert_eq!(output.len(), BUFFER_SIZE);

        for (i, chunk) in output.chunks_exact_mut(BLOCK_SIZE).enumerate() {
            unsafe {
                let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
                let mut v3 = iv_setup(self.iv, counter.checked_add(i as u64).unwrap());
                self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
                store(v0, v1, v2, v3, chunk)
            }
        }
    }

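    /// XOR the keystream into `output` in place, encrypting or decrypting a
    /// `BUFFER_SIZE`-byte buffer starting at the given block counter.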
    #[inline]
    #[cfg(feature = "cipher")]
    #[allow(clippy::cast_ptr_alignment)] // loadu/storeu support unaligned loads/stores
    pub fn apply_keystream(&self, counter: u64, output: &mut [u8]) {
        debug_assert_eq!(output.len(), BUFFER_SIZE);

        for (i, chunk) in output.chunks_exact_mut(BLOCK_SIZE).enumerate() {
            unsafe {
                let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
                let mut v3 = iv_setup(self.iv, counter.checked_add(i as u64).unwrap());
                self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);

                for (ch, a) in chunk.chunks_exact_mut(0x10).zip(&[v0, v1, v2, v3]) {
                    let b = _mm_loadu_si128(ch.as_ptr() as *const __m128i);
                    let out = _mm_xor_si128(*a, b);
                    _mm_storeu_si128(ch.as_mut_ptr() as *mut __m128i, out);
                }
            }
        }
    }

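    /// Run `R::COUNT` rounds (as `R::COUNT / 2` double rounds) over the working
    /// state, then add the original state words (the ChaCha feed-forward).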
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn rounds(
        &self,
        v0: &mut __m128i,
        v1: &mut __m128i,
        v2: &mut __m128i,
        v3: &mut __m128i,
    ) {
        let v3_orig = *v3;

        for _ in 0..(R::COUNT / 2) {
            double_quarter_round(v0, v1, v2, v3);
        }

        *v0 = _mm_add_epi32(*v0, self.v0);
        *v1 = _mm_add_epi32(*v1, self.v1);
        *v2 = _mm_add_epi32(*v2, self.v2);
        *v3 = _mm_add_epi32(*v3, v3_orig);
    }
}

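/// Load the ChaCha constants and the two 128-bit halves of the 256-bit key
/// into the first three rows of the state.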
#[inline]
#[target_feature(enable = "sse2")]
#[allow(clippy::cast_ptr_alignment)] // loadu supports unaligned loads
unsafe fn key_setup(key: &[u8; KEY_SIZE]) -> (__m128i, __m128i, __m128i) {
    let v0 = _mm_loadu_si128(CONSTANTS.as_ptr() as *const __m128i);
    let v1 = _mm_loadu_si128(key.as_ptr().offset(0x00) as *const __m128i);
    let v2 = _mm_loadu_si128(key.as_ptr().offset(0x10) as *const __m128i);
    (v0, v1, v2)
}

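/// Build the fourth row of the state: the 64-bit block counter (low word
/// first) followed by the 64-bit IV.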
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn iv_setup(iv: [i32; 2], counter: u64) -> __m128i {
    _mm_set_epi32(
        iv[0],
        iv[1],
        ((counter >> 32) & 0xffff_ffff) as i32,
        (counter & 0xffff_ffff) as i32,
    )
}

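/// Write the four 128-bit rows of the state into the 64-byte output block.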
#[inline]
#[target_feature(enable = "sse2")]
#[allow(clippy::cast_ptr_alignment)] // storeu supports unaligned stores
unsafe fn store(v0: __m128i, v1: __m128i, v2: __m128i, v3: __m128i, output: &mut [u8]) {
    _mm_storeu_si128(output.as_mut_ptr().offset(0x00) as *mut __m128i, v0);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x10) as *mut __m128i, v1);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x20) as *mut __m128i, v2);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x30) as *mut __m128i, v3);
}

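/// Perform one ChaCha double round: a column round, a shuffle onto the
/// diagonals, a second round, and a shuffle back into row order.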
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(
    v0: &mut __m128i,
    v1: &mut __m128i,
    v2: &mut __m128i,
    v3: &mut __m128i,
) {
    add_xor_rot(v0, v1, v2, v3);
    rows_to_cols(v0, v1, v2, v3);
    add_xor_rot(v0, v1, v2, v3);
    cols_to_rows(v0, v1, v2, v3);
}

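/// Rotate the lanes of rows 1-3 so the diagonal quarter-rounds line up
/// column-wise, letting the diagonal round reuse `add_xor_rot`.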
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols(_v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v1 >>>= 32; v2 >>>= 64; v3 >>>= 96;
    *v1 = _mm_shuffle_epi32(*v1, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    *v2 = _mm_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm_shuffle_epi32(*v3, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
}

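/// Undo `rows_to_cols`, restoring rows 1-3 to their original lane order.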
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows(_v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v1 <<<= 32; v2 <<<= 64; v3 <<<= 96;
    *v1 = _mm_shuffle_epi32(*v1, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    *v2 = _mm_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm_shuffle_epi32(*v3, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
}

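/// Four vectorized quarter-round steps: add, XOR, then rotate left by 16, 12,
/// 8, and 7 bits respectively (rotations emulated with shift pairs plus XOR).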
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot(v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v0 += v1; v3 ^= v0; v3 <<<= (16, 16, 16, 16);
    *v0 = _mm_add_epi32(*v0, *v1);
    *v3 = _mm_xor_si128(*v3, *v0);
    *v3 = _mm_xor_si128(_mm_slli_epi32(*v3, 16), _mm_srli_epi32(*v3, 16));

    // v2 += v3; v1 ^= v2; v1 <<<= (12, 12, 12, 12);
    *v2 = _mm_add_epi32(*v2, *v3);
    *v1 = _mm_xor_si128(*v1, *v2);
    *v1 = _mm_xor_si128(_mm_slli_epi32(*v1, 12), _mm_srli_epi32(*v1, 20));

    // v0 += v1; v3 ^= v0; v3 <<<= (8, 8, 8, 8);
    *v0 = _mm_add_epi32(*v0, *v1);
    *v3 = _mm_xor_si128(*v3, *v0);
    *v3 = _mm_xor_si128(_mm_slli_epi32(*v3, 8), _mm_srli_epi32(*v3, 24));

    // v2 += v3; v1 ^= v2; v1 <<<= (7, 7, 7, 7);
    *v2 = _mm_add_epi32(*v2, *v3);
    *v1 = _mm_xor_si128(*v1, *v2);
    *v1 = _mm_xor_si128(_mm_slli_epi32(*v1, 7), _mm_srli_epi32(*v1, 25));
}

#[cfg(all(test, target_feature = "sse2"))]
mod tests {
    use super::*;
    use crate::rounds::R20;
    use crate::{backend::soft, BLOCK_SIZE};
    use core::convert::TryInto;

    // random inputs for testing
    const R_CNT: u64 = 0x9fe625b6d23a8fa8u64;
    const R_IV: [u8; IV_SIZE] = [0x2f, 0x96, 0xa8, 0x4a, 0xf8, 0x92, 0xbc, 0x94];
    const R_KEY: [u8; KEY_SIZE] = [
        0x11, 0xf2, 0x72, 0x99, 0xe1, 0x79, 0x6d, 0xef, 0xb, 0xdc, 0x6a, 0x58, 0x1f, 0x1, 0x58,
        0x94, 0x92, 0x19, 0x69, 0x3f, 0xe9, 0x35, 0x16, 0x72, 0x63, 0xd1, 0xd, 0x94, 0x6d, 0x31,
        0x34, 0x11,
    ];

    #[test]
    fn init_and_store() {
        unsafe {
            let (v0, v1, v2) = key_setup(&R_KEY);

            let v3 = iv_setup(
                [
                    i32::from_le_bytes(R_IV[4..].try_into().unwrap()),
                    i32::from_le_bytes(R_IV[..4].try_into().unwrap()),
                ],
                R_CNT,
            );

            let vs = [v0, v1, v2, v3];

            let mut output = [0u8; BLOCK_SIZE];
            store(vs[0], vs[1], vs[2], vs[3], &mut output);

            let expected = [
                1634760805, 857760878, 2036477234, 1797285236, 2574447121, 4016929249, 1483398155,
                2488795423, 1063852434, 1914058217, 2483933539, 288633197, 3527053224, 2682660278,
                1252562479, 2495386360,
            ];

            for (i, chunk) in output.chunks(4).enumerate() {
                assert_eq!(expected[i], u32::from_le_bytes(chunk.try_into().unwrap()));
            }
        }
    }

    #[test]
    fn init_and_double_round() {
        unsafe {
            let (mut v0, mut v1, mut v2) = key_setup(&R_KEY);

            let mut v3 = iv_setup(
                [
                    i32::from_le_bytes(R_IV[4..].try_into().unwrap()),
                    i32::from_le_bytes(R_IV[..4].try_into().unwrap()),
                ],
                R_CNT,
            );

            double_quarter_round(&mut v0, &mut v1, &mut v2, &mut v3);

            let mut output = [0u8; BLOCK_SIZE];
            store(v0, v1, v2, v3, &mut output);

            let expected = [
                562456049, 3130322832, 1534507163, 1938142593, 1427879055, 3727017100, 1549525649,
                2358041203, 1010155040, 657444539, 2865892668, 2826477124, 737507996, 3254278724,
                3376929372, 928763221,
            ];

            for (i, chunk) in output.chunks(4).enumerate() {
                assert_eq!(expected[i], u32::from_le_bytes(chunk.try_into().unwrap()));
            }
        }
    }

    #[test]
    fn generate_vs_scalar_impl() {
        let mut soft_result = [0u8; soft::BUFFER_SIZE];
        soft::Core::<R20>::new(&R_KEY, R_IV).generate(R_CNT, &mut soft_result);

        let mut simd_result = [0u8; BUFFER_SIZE];
        Core::<R20>::new(&R_KEY, R_IV).generate(R_CNT, &mut simd_result);

        assert_eq!(&soft_result[..], &simd_result[..soft::BUFFER_SIZE])
    }
}