1 //! AES in counter mode (a.k.a. AES-CTR)
2 
3 // TODO(tarcieri): support generic CTR API
4 
5 #![allow(clippy::unreadable_literal)]
6 
7 use super::arch::*;
8 use core::mem;
9 
10 use super::{Aes128, Aes192, Aes256};
11 use crate::BLOCK_SIZE;
12 use cipher::{
13     consts::U16,
14     errors::{LoopError, OverflowError},
15     generic_array::GenericArray,
16     BlockCipher, FromBlockCipher, SeekNum, StreamCipher, StreamCipherSeek,
17 };
18 
/// Number of AES blocks handled together by the parallel (`xor_block8`) path.
const PAR_BLOCKS: usize = 8;
/// Size in bytes of one batch of `PAR_BLOCKS` blocks; used as the chunk size
/// when splitting input for the parallel path.
const PAR_BLOCKS_SIZE: usize = PAR_BLOCKS * BLOCK_SIZE;
21 
/// XORs `key` into `buf` in place, one byte at a time.
///
/// Both slices are expected to have the same length. This is only verified in
/// debug builds; otherwise processing stops at the end of the shorter slice.
#[inline(always)]
pub fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    buf.iter_mut()
        .zip(key.iter())
        .for_each(|(dst, &src)| *dst ^= src);
}
29 
30 #[inline(always)]
xor_block8(buf: &mut [u8], ctr: [__m128i; 8])31 fn xor_block8(buf: &mut [u8], ctr: [__m128i; 8]) {
32     debug_assert_eq!(buf.len(), PAR_BLOCKS_SIZE);
33 
34     // Safety: `loadu` and `storeu` support unaligned access
35     #[allow(clippy::cast_ptr_alignment)]
36     unsafe {
37         // compiler should unroll this loop
38         for i in 0..8 {
39             let ptr = buf.as_mut_ptr().offset(16 * i) as *mut __m128i;
40             let data = _mm_loadu_si128(ptr);
41             let data = _mm_xor_si128(data, ctr[i as usize]);
42             _mm_storeu_si128(ptr, data);
43         }
44     }
45 }
46 
/// Reverses the byte order inside each 64-bit half of `v`.
///
/// The two halves themselves stay in place: bytes 0..8 are reversed among
/// themselves, and so are bytes 8..16.
#[inline(always)]
fn swap_bytes(v: __m128i) -> __m128i {
    // Shuffle control bytes: output byte `i` is taken from the input byte
    // whose index is stored at position `i` of the mask.
    const UPPER_HALF: i64 = 0x08090a0b0c0d0e0f;
    const LOWER_HALF: i64 = 0x0001020304050607;

    // NOTE(review): `_mm_shuffle_epi8` is an SSSE3 instruction; presumably the
    // enclosing module is only built/used when that is available - confirm.
    unsafe { _mm_shuffle_epi8(v, _mm_set_epi64x(UPPER_HALF, LOWER_HALF)) }
}
54 
/// Adds one to the upper 64-bit lane of `v`, leaving the lower lane unchanged.
///
/// The addition wraps within the upper lane; there is no carry into the
/// lower lane.
#[inline(always)]
fn inc_be(v: __m128i) -> __m128i {
    unsafe {
        // `_mm_set_epi64x(hi, lo)`: the `1` goes into the upper lane.
        let one_in_upper_lane = _mm_set_epi64x(1, 0);
        _mm_add_epi64(v, one_in_upper_lane)
    }
}
59 
60 #[inline(always)]
load(val: &GenericArray<u8, U16>) -> __m128i61 fn load(val: &GenericArray<u8, U16>) -> __m128i {
62     // Safety: `loadu` supports unaligned loads
63     #[allow(clippy::cast_ptr_alignment)]
64     unsafe {
65         _mm_loadu_si128(val.as_ptr() as *const __m128i)
66     }
67 }
68 
// Generates a CTR-mode stream cipher type `$name` on top of the block cipher
// `$cipher`, documented with `$doc`.
//
// The counter lives in the upper 64-bit lane of `ctr` and is advanced with
// `inc_be`; `swap_bytes` is applied right before encryption to produce the
// actual cipher input block.
macro_rules! impl_ctr {
    ($name:ident, $cipher:ty, $doc:expr) => {
        #[doc=$doc]
        #[derive(Clone)]
        #[cfg_attr(docsrs, doc(cfg(feature = "ctr")))]
        pub struct $name {
            // Byte-swapped nonce, i.e. the counter value of keystream block 0.
            nonce: __m128i,
            // Counter value (in the same byte-swapped form) of the current block.
            ctr: __m128i,
            cipher: $cipher,
            // Buffered keystream of the partially consumed block (valid when `pos != 0`).
            block: [u8; BLOCK_SIZE],
            // Number of bytes of `block` that have already been used.
            pos: u8,
        }

        impl $name {
            /// Fills `self.block` with the keystream for the current counter
            /// value without advancing the counter.
            #[inline(always)]
            fn gen_block(&mut self) {
                let keystream = self.cipher.encrypt(swap_bytes(self.ctr));
                // SAFETY: `transmute` only compiles when both types have the
                // same size, and any bit pattern is a valid byte array.
                self.block =
                    unsafe { mem::transmute::<__m128i, [u8; BLOCK_SIZE]>(keystream) };
            }

            /// Returns the keystream block for the current counter value and
            /// advances the counter by one.
            #[inline(always)]
            fn next_block(&mut self) -> __m128i {
                let current = self.ctr;
                self.ctr = inc_be(current);
                self.cipher.encrypt(swap_bytes(current))
            }

            /// Returns eight consecutive keystream blocks and advances the
            /// counter by eight.
            #[inline(always)]
            fn next_block8(&mut self) -> [__m128i; 8] {
                let mut ctr = self.ctr;
                // The initial contents are irrelevant: every slot is
                // overwritten in the loop below.
                let mut block8 = [ctr; 8];
                for slot in block8.iter_mut() {
                    *slot = swap_bytes(ctr);
                    ctr = inc_be(ctr);
                }
                self.ctr = ctr;

                self.cipher.encrypt8(block8)
            }

            /// Returns how many blocks the counter has advanced from the
            /// nonce (wrapping difference of the upper 64-bit lanes).
            #[inline(always)]
            fn get_u64_ctr(&self) -> u64 {
                // SAFETY: `__m128i` and `[u64; 2]` have the same size and
                // every bit pattern is valid for both.
                let ctr_lanes = unsafe { mem::transmute::<__m128i, [u64; 2]>(self.ctr) };
                let nonce_lanes = unsafe { mem::transmute::<__m128i, [u64; 2]>(self.nonce) };
                ctr_lanes[1].wrapping_sub(nonce_lanes[1])
            }

            /// Check if provided data will not overflow counter
            #[inline(always)]
            fn check_data_len(&self, data: &[u8]) -> Result<(), LoopError> {
                let leftover_bytes = BLOCK_SIZE - self.pos as usize;
                // Data that does not reach the end of the current block never
                // advances the counter.
                let beyond_current = match data.len().checked_sub(leftover_bytes) {
                    Some(extra) => extra,
                    None => return Ok(()),
                };
                let blocks = 1 + beyond_current / BLOCK_SIZE;
                match self.get_u64_ctr().checked_add(blocks as u64) {
                    Some(_) => Ok(()),
                    None => Err(LoopError),
                }
            }
        }

        impl FromBlockCipher for $name {
            type BlockCipher = $cipher;
            type NonceSize = <$cipher as BlockCipher>::BlockSize;

            fn from_block_cipher(
                cipher: $cipher,
                nonce: &GenericArray<u8, Self::NonceSize>,
            ) -> Self {
                // The counter starts out equal to the (byte-swapped) nonce.
                let start = swap_bytes(load(nonce));
                Self {
                    nonce: start,
                    ctr: start,
                    cipher,
                    block: [0u8; BLOCK_SIZE],
                    pos: 0,
                }
            }
        }

        impl StreamCipher for $name {
            #[inline]
            fn try_apply_keystream(&mut self, mut data: &mut [u8]) -> Result<(), LoopError> {
                self.check_data_len(data)?;
                let bs = BLOCK_SIZE;
                let pos = self.pos as usize;
                debug_assert!(bs > pos);

                // Step 1: use up what is left of the buffered keystream block.
                if pos != 0 {
                    let available = bs - pos;
                    if data.len() < available {
                        let end = pos + data.len();
                        xor(data, &self.block[pos..end]);
                        self.pos = end as u8;
                        return Ok(());
                    }

                    let (head, tail) = data.split_at_mut(available);
                    xor(head, &self.block[pos..]);
                    // The buffered block is fully consumed; move to the next one.
                    self.ctr = inc_be(self.ctr);
                    data = tail;
                }

                // Step 2: eight blocks at a time.
                let mut wide = data.chunks_exact_mut(PAR_BLOCKS_SIZE);
                for chunk in wide.by_ref() {
                    xor_block8(chunk, self.next_block8());
                }
                data = wide.into_remainder();

                // Step 3: remaining whole blocks, one at a time.
                let mut single = data.chunks_exact_mut(bs);
                for chunk in single.by_ref() {
                    let keystream = self.next_block();
                    let ptr = chunk.as_mut_ptr() as *mut __m128i;

                    // SAFETY: `chunk` is exactly `bs` bytes long, and
                    // `loadu`/`storeu` support unaligned access.
                    unsafe {
                        let text = _mm_loadu_si128(ptr);
                        _mm_storeu_si128(ptr, _mm_xor_si128(keystream, text));
                    }
                }

                // Step 4: trailing partial block. Its keystream is buffered in
                // `self.block` and the counter is left pointing at it.
                let rem = single.into_remainder();
                let used = rem.len();
                self.pos = used as u8;
                if used != 0 {
                    self.gen_block();
                    xor(rem, &self.block[..used]);
                }

                Ok(())
            }
        }

        impl StreamCipherSeek for $name {
            fn try_current_pos<T: SeekNum>(&self) -> Result<T, OverflowError> {
                T::from_block_byte(self.get_u64_ctr(), self.pos, BLOCK_SIZE as u8)
            }

            fn try_seek<T: SeekNum>(&mut self, pos: T) -> Result<(), LoopError> {
                let (block_index, byte_index): (u64, u8) =
                    pos.to_block_byte(BLOCK_SIZE as u8)?;
                // The cast to `i64` only reinterprets the bits; the lane
                // addition below is a wrapping 64-bit add.
                let offset = unsafe { _mm_set_epi64x(block_index as i64, 0) };
                self.ctr = unsafe { _mm_add_epi64(self.nonce, offset) };
                self.pos = byte_index;
                // A position inside a block needs that block's keystream buffered.
                if byte_index != 0 {
                    self.gen_block();
                }
                Ok(())
            }
        }

        opaque_debug::implement!($name);
    };
}
226 
// Instantiate the CTR-mode wrapper for each supported AES key size.
impl_ctr!(Aes128Ctr, Aes128, "AES-128 in CTR mode");
impl_ctr!(Aes192Ctr, Aes192, "AES-192 in CTR mode");
impl_ctr!(Aes256Ctr, Aes256, "AES-256 in CTR mode");
230