//! The ChaCha20 block function. Defined in RFC 8439 Section 2.3.
//!
//! <https://tools.ietf.org/html/rfc8439#section-2.3>
//!
//! SSE2-optimized implementation for x86/x86-64 CPUs.
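//!
//! In this backend the 512-bit ChaCha20 state is treated as a 4x4 matrix of
//! 32-bit words, one 128-bit row per `__m128i` vector:
//!
//! - row 0: the "expand 32-byte k" constants
//! - rows 1-2: the 256-bit key
//! - row 3: 64-bit block counter (low two words) followed by the IV (high two words)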

use crate::{rounds::Rounds, BLOCK_SIZE, CONSTANTS, IV_SIZE, KEY_SIZE};
use core::{convert::TryInto, marker::PhantomData};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
/// Size of buffers passed to `generate` and `apply_keystream` for this backend
pub(crate) const BUFFER_SIZE: usize = BLOCK_SIZE;

/// The ChaCha20 block function (SSE2 accelerated implementation for x86/x86_64)
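///
/// `v0` holds the constants row, `v1` and `v2` hold the two key rows, and
/// `iv` holds the two IV (nonce) words, stored in reverse order so they can
/// be passed straight to `_mm_set_epi32` in `iv_setup`.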
// TODO(tarcieri): zeroize?
#[derive(Clone)]
pub(crate) struct Block<R: Rounds> {
    v0: __m128i,
    v1: __m128i,
    v2: __m128i,
    iv: [i32; 2],
    rounds: PhantomData<R>,
}

impl<R: Rounds> Block<R> {
    /// Initialize the block function with the given key, IV, and number of rounds
    #[inline]
    pub(crate) fn new(key: &[u8; KEY_SIZE], iv: [u8; IV_SIZE]) -> Self {
        let (v0, v1, v2) = unsafe { key_setup(key) };
        let iv = [
            i32::from_le_bytes(iv[4..].try_into().unwrap()),
            i32::from_le_bytes(iv[..4].try_into().unwrap()),
        ];

        Self {
            v0,
            v1,
            v2,
            iv,
            rounds: PhantomData,
        }
    }

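    /// Generate one keystream block for the given 64-bit block counter,
    /// writing `BUFFER_SIZE` bytes of raw keystream into `output`.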
    #[inline]
    pub(crate) fn generate(&self, counter: u64, output: &mut [u8]) {
        debug_assert_eq!(output.len(), BUFFER_SIZE);

        unsafe {
            let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
            let mut v3 = iv_setup(self.iv, counter);
            self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
            store(v0, v1, v2, v3, output)
        }
    }

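    /// Generate the keystream block for the given counter and XOR it into
    /// `output` in place, 16 bytes at a time (i.e. encrypt/decrypt one block).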
    #[inline]
    #[cfg(feature = "cipher")]
    #[allow(clippy::cast_ptr_alignment)] // loadu/storeu support unaligned loads/stores
    pub(crate) fn apply_keystream(&self, counter: u64, output: &mut [u8]) {
        debug_assert_eq!(output.len(), BUFFER_SIZE);

        unsafe {
            let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
            let mut v3 = iv_setup(self.iv, counter);
            self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);

            for (chunk, a) in output.chunks_mut(0x10).zip(&[v0, v1, v2, v3]) {
                let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
                let out = _mm_xor_si128(*a, b);
                _mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, out);
            }
        }
    }

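    /// Run `R::COUNT / 2` double rounds over the working state, then add the
    /// original input state back in (the feed-forward step from RFC 8439
    /// Section 2.3), leaving the finished block in `v0`..`v3`.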
    #[inline]
    #[target_feature(enable = "sse2")]
    unsafe fn rounds(
        &self,
        v0: &mut __m128i,
        v1: &mut __m128i,
        v2: &mut __m128i,
        v3: &mut __m128i,
    ) {
        let v3_orig = *v3;

        for _ in 0..(R::COUNT / 2) {
            double_quarter_round(v0, v1, v2, v3);
        }

        *v0 = _mm_add_epi32(*v0, self.v0);
        *v1 = _mm_add_epi32(*v1, self.v1);
        *v2 = _mm_add_epi32(*v2, self.v2);
        *v3 = _mm_add_epi32(*v3, v3_orig);
    }
}

#[inline]
#[target_feature(enable = "sse2")]
#[allow(clippy::cast_ptr_alignment)] // loadu supports unaligned loads
unsafe fn key_setup(key: &[u8; KEY_SIZE]) -> (__m128i, __m128i, __m128i) {
    let v0 = _mm_loadu_si128(CONSTANTS.as_ptr() as *const __m128i);
    let v1 = _mm_loadu_si128(key.as_ptr().offset(0x00) as *const __m128i);
    let v2 = _mm_loadu_si128(key.as_ptr().offset(0x10) as *const __m128i);
    (v0, v1, v2)
}

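/// Build the fourth state row: the 64-bit block counter occupies the two low
/// words and the IV the two high words. `_mm_set_epi32` takes its arguments
/// from the most significant lane down, hence the argument ordering.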
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn iv_setup(iv: [i32; 2], counter: u64) -> __m128i {
    _mm_set_epi32(
        iv[0],
        iv[1],
        ((counter >> 32) & 0xffff_ffff) as i32,
        (counter & 0xffff_ffff) as i32,
    )
}

#[inline]
#[target_feature(enable = "sse2")]
#[allow(clippy::cast_ptr_alignment)] // storeu supports unaligned stores
unsafe fn store(v0: __m128i, v1: __m128i, v2: __m128i, v3: __m128i, output: &mut [u8]) {
    _mm_storeu_si128(output.as_mut_ptr().offset(0x00) as *mut __m128i, v0);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x10) as *mut __m128i, v1);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x20) as *mut __m128i, v2);
    _mm_storeu_si128(output.as_mut_ptr().offset(0x30) as *mut __m128i, v3);
}

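/// One ChaCha "double round": a column round over the four columns, a
/// diagonalizing shuffle, a round over the diagonals, and the inverse shuffle
/// to restore the row layout.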
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn double_quarter_round(
    v0: &mut __m128i,
    v1: &mut __m128i,
    v2: &mut __m128i,
    v3: &mut __m128i,
) {
    add_xor_rot(v0, v1, v2, v3);
    rows_to_cols(v0, v1, v2, v3);
    add_xor_rot(v0, v1, v2, v3);
    cols_to_rows(v0, v1, v2, v3);
}

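/// Rotate the lanes of `v1`, `v2`, and `v3` so the state's diagonals line up
/// as columns, letting `add_xor_rot` compute the four diagonal quarter rounds
/// with the same column-wise code. `v0` needs no rotation.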
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn rows_to_cols(_v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v1 >>>= 32; v2 >>>= 64; v3 >>>= 96;
    *v1 = _mm_shuffle_epi32(*v1, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
    *v2 = _mm_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm_shuffle_epi32(*v3, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
}

#[inline]
#[target_feature(enable = "sse2")]
unsafe fn cols_to_rows(_v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v1 <<<= 32; v2 <<<= 64; v3 <<<= 96;
    *v1 = _mm_shuffle_epi32(*v1, 0b_10_01_00_11); // _MM_SHUFFLE(2, 1, 0, 3)
    *v2 = _mm_shuffle_epi32(*v2, 0b_01_00_11_10); // _MM_SHUFFLE(1, 0, 3, 2)
    *v3 = _mm_shuffle_epi32(*v3, 0b_00_11_10_01); // _MM_SHUFFLE(0, 3, 2, 1)
}

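/// Perform four ChaCha quarter rounds in parallel, one per 32-bit lane.
/// SSE2 has no 32-bit rotate instruction, so each rotation by `n` is emulated
/// as `(x << n) ^ (x >> (32 - n))` via paired shifts.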
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn add_xor_rot(v0: &mut __m128i, v1: &mut __m128i, v2: &mut __m128i, v3: &mut __m128i) {
    // v0 += v1; v3 ^= v0; v3 <<<= (16, 16, 16, 16);
    *v0 = _mm_add_epi32(*v0, *v1);
    *v3 = _mm_xor_si128(*v3, *v0);
    *v3 = _mm_xor_si128(_mm_slli_epi32(*v3, 16), _mm_srli_epi32(*v3, 16));

    // v2 += v3; v1 ^= v2; v1 <<<= (12, 12, 12, 12);
    *v2 = _mm_add_epi32(*v2, *v3);
    *v1 = _mm_xor_si128(*v1, *v2);
    *v1 = _mm_xor_si128(_mm_slli_epi32(*v1, 12), _mm_srli_epi32(*v1, 20));

    // v0 += v1; v3 ^= v0; v3 <<<= (8, 8, 8, 8);
    *v0 = _mm_add_epi32(*v0, *v1);
    *v3 = _mm_xor_si128(*v3, *v0);
    *v3 = _mm_xor_si128(_mm_slli_epi32(*v3, 8), _mm_srli_epi32(*v3, 24));

    // v2 += v3; v1 ^= v2; v1 <<<= (7, 7, 7, 7);
    *v2 = _mm_add_epi32(*v2, *v3);
    *v1 = _mm_xor_si128(*v1, *v2);
    *v1 = _mm_xor_si128(_mm_slli_epi32(*v1, 7), _mm_srli_epi32(*v1, 25));
}

#[cfg(all(test, target_feature = "sse2"))]
mod tests {
    use super::*;
    use crate::rounds::R20;
    use crate::{block::soft::Block as SoftBlock, BLOCK_SIZE};
    use core::convert::TryInto;

    // random inputs for testing
    const R_CNT: u64 = 0x9fe625b6d23a8fa8u64;
    const R_IV: [u8; IV_SIZE] = [0x2f, 0x96, 0xa8, 0x4a, 0xf8, 0x92, 0xbc, 0x94];
    const R_KEY: [u8; KEY_SIZE] = [
        0x11, 0xf2, 0x72, 0x99, 0xe1, 0x79, 0x6d, 0xef, 0xb, 0xdc, 0x6a, 0x58, 0x1f, 0x1, 0x58,
        0x94, 0x92, 0x19, 0x69, 0x3f, 0xe9, 0x35, 0x16, 0x72, 0x63, 0xd1, 0xd, 0x94, 0x6d, 0x31,
        0x34, 0x11,
    ];

    #[test]
    fn init_and_store() {
        unsafe {
            let (v0, v1, v2) = key_setup(&R_KEY);

            let v3 = iv_setup(
                [
                    i32::from_le_bytes(R_IV[4..].try_into().unwrap()),
                    i32::from_le_bytes(R_IV[..4].try_into().unwrap()),
                ],
                R_CNT,
            );

            let vs = [v0, v1, v2, v3];

            let mut output = [0u8; BLOCK_SIZE];
            store(vs[0], vs[1], vs[2], vs[3], &mut output);

            let expected = [
                1634760805, 857760878, 2036477234, 1797285236, 2574447121, 4016929249, 1483398155,
                2488795423, 1063852434, 1914058217, 2483933539, 288633197, 3527053224, 2682660278,
                1252562479, 2495386360,
            ];

            for (i, chunk) in output.chunks(4).enumerate() {
                assert_eq!(expected[i], u32::from_le_bytes(chunk.try_into().unwrap()));
            }
        }
    }

    #[test]
    fn init_and_double_round() {
        unsafe {
            let (mut v0, mut v1, mut v2) = key_setup(&R_KEY);

            let mut v3 = iv_setup(
                [
                    i32::from_le_bytes(R_IV[4..].try_into().unwrap()),
                    i32::from_le_bytes(R_IV[..4].try_into().unwrap()),
                ],
                R_CNT,
            );

            double_quarter_round(&mut v0, &mut v1, &mut v2, &mut v3);

            let mut output = [0u8; BLOCK_SIZE];
            store(v0, v1, v2, v3, &mut output);

            let expected = [
                562456049, 3130322832, 1534507163, 1938142593, 1427879055, 3727017100, 1549525649,
                2358041203, 1010155040, 657444539, 2865892668, 2826477124, 737507996, 3254278724,
                3376929372, 928763221,
            ];

            for (i, chunk) in output.chunks(4).enumerate() {
                assert_eq!(expected[i], u32::from_le_bytes(chunk.try_into().unwrap()));
            }
        }
    }

    #[test]
    fn generate_vs_scalar_impl() {
        let mut soft_result = [0u8; BLOCK_SIZE];
        SoftBlock::<R20>::new(&R_KEY, R_IV).generate(R_CNT, &mut soft_result);

        let mut simd_result = [0u8; BLOCK_SIZE];
        Block::<R20>::new(&R_KEY, R_IV).generate(R_CNT, &mut simd_result);

        assert_eq!(&soft_result[..], &simd_result[..])
    }
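
    // Additional sanity check (a sketch, only compiled when the `cipher`
    // feature is enabled; the test name is illustrative and not part of the
    // original suite): XOR-ing the keystream into an all-zero buffer via
    // `apply_keystream` must reproduce the raw keystream from `generate`.
    #[test]
    #[cfg(feature = "cipher")]
    fn apply_keystream_vs_generate() {
        let block = Block::<R20>::new(&R_KEY, R_IV);

        // Raw keystream block for the test counter.
        let mut keystream = [0u8; BLOCK_SIZE];
        block.generate(R_CNT, &mut keystream);

        // XOR the same keystream block into a zeroed buffer.
        let mut buf = [0u8; BLOCK_SIZE];
        block.apply_keystream(R_CNT, &mut buf);

        assert_eq!(&keystream[..], &buf[..]);
    }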
}