//! AES in counter mode (a.k.a. AES-CTR)
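//!
//! # Usage sketch
//!
//! A minimal sketch of driving one of the CTR types defined below through
//! the `cipher` traits used in this module. The all-zero key and nonce are
//! illustrative placeholders (not test vectors), and the `aes` paths assume
//! the parent crate's public re-exports:
//!
//! ```ignore
//! use aes::{Aes128, Aes128Ctr};
//! use cipher::{FromBlockCipher, NewBlockCipher, StreamCipher};
//! use cipher::generic_array::GenericArray;
//!
//! let key = GenericArray::from_slice(&[0u8; 16]);
//! let nonce = GenericArray::from_slice(&[0u8; 16]);
//! let mut ctr = Aes128Ctr::from_block_cipher(Aes128::new(key), nonce);
//!
//! let mut data = *b"hello world! this is my plaintext.";
//! ctr.apply_keystream(&mut data); // encrypt (or decrypt) in place
//! ```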

// TODO(tarcieri): support generic CTR API

#![allow(clippy::unreadable_literal)]

use super::arch::*;
use core::mem;

use super::{Aes128, Aes192, Aes256};
use crate::BLOCK_SIZE;
use cipher::{
    consts::U16,
    errors::{LoopError, OverflowError},
    generic_array::GenericArray,
    BlockCipher, FromBlockCipher, SeekNum, StreamCipher, StreamCipherSeek,
};

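// Number of blocks processed in parallel; the 8-block keystream path below
// issues eight independent AES instruction streams to exploit the
// instruction-level parallelism of the AES-NI units.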
const PAR_BLOCKS: usize = 8;
const PAR_BLOCKS_SIZE: usize = PAR_BLOCKS * BLOCK_SIZE;

#[inline(always)]
pub fn xor(buf: &mut [u8], key: &[u8]) {
    debug_assert_eq!(buf.len(), key.len());
    for (a, b) in buf.iter_mut().zip(key) {
        *a ^= *b;
    }
}

#[inline(always)]
fn xor_block8(buf: &mut [u8], ctr: [__m128i; 8]) {
    debug_assert_eq!(buf.len(), PAR_BLOCKS_SIZE);

    // Safety: `loadu` and `storeu` support unaligned access
    #[allow(clippy::cast_ptr_alignment)]
    unsafe {
        // compiler should unroll this loop
        for i in 0..8 {
            let ptr = buf.as_mut_ptr().offset(16 * i) as *mut __m128i;
            let data = _mm_loadu_si128(ptr);
            let data = _mm_xor_si128(data, ctr[i as usize]);
            _mm_storeu_si128(ptr, data);
        }
    }
}

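// Reverse the byte order within each 64-bit half of `v` (a PSHUFB with a
// byte-reversal mask). This converts each half of a big-endian counter block
// to/from native little-endian order, so the counter can be incremented with
// a plain 64-bit add.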
#[inline(always)]
fn swap_bytes(v: __m128i) -> __m128i {
    unsafe {
        let mask = _mm_set_epi64x(0x08090a0b0c0d0e0f, 0x0001020304050607);
        _mm_shuffle_epi8(v, mask)
    }
}

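// Increment the high 64-bit lane of `v`. In the byte-swapped layout produced
// by `swap_bytes`, this increments the low 64 bits of the big-endian counter;
// carries do not propagate into the upper half.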
#[inline(always)]
fn inc_be(v: __m128i) -> __m128i {
    unsafe { _mm_add_epi64(v, _mm_set_epi64x(1, 0)) }
}

#[inline(always)]
fn load(val: &GenericArray<u8, U16>) -> __m128i {
    // Safety: `loadu` supports unaligned loads
    #[allow(clippy::cast_ptr_alignment)]
    unsafe {
        _mm_loadu_si128(val.as_ptr() as *const __m128i)
    }
}

macro_rules! impl_ctr {
    ($name:ident, $cipher:ty, $doc:expr) => {
        #[doc=$doc]
        #[derive(Clone)]
        #[cfg_attr(docsrs, doc(cfg(feature = "ctr")))]
        pub struct $name {
            nonce: __m128i,
            ctr: __m128i,
            cipher: $cipher,
            block: [u8; BLOCK_SIZE],
            pos: u8,
        }

        impl $name {
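            // Encrypt the current counter value into the keystream buffer used
            // for partial blocks; the counter itself is not advanced here.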
            #[inline(always)]
            fn gen_block(&mut self) {
                let block = self.cipher.encrypt(swap_bytes(self.ctr));
                self.block = unsafe { mem::transmute(block) }
            }

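            // Produce the next keystream block and advance the counter by one.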
            #[inline(always)]
            fn next_block(&mut self) -> __m128i {
                let block = swap_bytes(self.ctr);
                self.ctr = inc_be(self.ctr);
                self.cipher.encrypt(block)
            }

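            // Produce the next eight keystream blocks and advance the counter
            // by eight.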
            #[inline(always)]
            fn next_block8(&mut self) -> [__m128i; 8] {
                let mut ctr = self.ctr;
                let mut block8: [__m128i; 8] = unsafe { mem::zeroed() };
                for i in 0..8 {
                    block8[i] = swap_bytes(ctr);
                    ctr = inc_be(ctr);
                }
                self.ctr = ctr;

                self.cipher.encrypt8(block8)
            }

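            // Number of blocks processed so far: the current counter value
            // minus the initial nonce, read from the byte-swapped high lane.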
            #[inline(always)]
            fn get_u64_ctr(&self) -> u64 {
                let (ctr, nonce) = unsafe {
                    (
                        mem::transmute::<__m128i, [u64; 2]>(self.ctr)[1],
                        mem::transmute::<__m128i, [u64; 2]>(self.nonce)[1],
                    )
                };
                ctr.wrapping_sub(nonce)
            }

            /// Check that the provided data will not overflow the counter
            #[inline(always)]
            fn check_data_len(&self, data: &[u8]) -> Result<(), LoopError> {
                let bs = BLOCK_SIZE;
                let leftover_bytes = bs - self.pos as usize;
                if data.len() < leftover_bytes {
                    return Ok(());
                }
                let blocks = 1 + (data.len() - leftover_bytes) / bs;
                self.get_u64_ctr()
                    .checked_add(blocks as u64)
                    .ok_or(LoopError)
                    .map(|_| ())
            }
        }

        impl FromBlockCipher for $name {
            type BlockCipher = $cipher;
            type NonceSize = <$cipher as BlockCipher>::BlockSize;

            fn from_block_cipher(
                cipher: $cipher,
                nonce: &GenericArray<u8, Self::NonceSize>,
            ) -> Self {
                let nonce = swap_bytes(load(nonce));
                Self {
                    nonce,
                    ctr: nonce,
                    cipher,
                    block: [0u8; BLOCK_SIZE],
                    pos: 0,
                }
            }
        }

        impl StreamCipher for $name {
            #[inline]
            fn try_apply_keystream(&mut self, mut data: &mut [u8]) -> Result<(), LoopError> {
                self.check_data_len(data)?;
                let bs = BLOCK_SIZE;
                let pos = self.pos as usize;
                debug_assert!(bs > pos);

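                // First use up any leftover bytes of the buffered keystream
                // block from a previous call.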
                if pos != 0 {
                    if data.len() < bs - pos {
                        let n = pos + data.len();
                        xor(data, &self.block[pos..n]);
                        self.pos = n as u8;
                        return Ok(());
                    } else {
                        let (l, r) = data.split_at_mut(bs - pos);
                        data = r;
                        xor(l, &self.block[pos..]);
                        self.ctr = inc_be(self.ctr);
                    }
                }

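                // Process eight blocks at a time via the parallel keystream
                // path.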
                let mut chunks = data.chunks_exact_mut(PAR_BLOCKS_SIZE);
                for chunk in &mut chunks {
                    xor_block8(chunk, self.next_block8());
                }
                data = chunks.into_remainder();

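                // Process the remaining full blocks one at a time.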
                let mut chunks = data.chunks_exact_mut(bs);
                for chunk in &mut chunks {
                    let block = self.next_block();

                    unsafe {
                        let t = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
                        let res = _mm_xor_si128(block, t);
                        _mm_storeu_si128(chunk.as_mut_ptr() as *mut __m128i, res);
                    }
                }

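                // Buffer a keystream block for any trailing partial block.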
                let rem = chunks.into_remainder();
                self.pos = rem.len() as u8;
                if !rem.is_empty() {
                    self.gen_block();
                    for (a, b) in rem.iter_mut().zip(&self.block) {
                        *a ^= *b;
                    }
                }

                Ok(())
            }
        }

        impl StreamCipherSeek for $name {
            fn try_current_pos<T: SeekNum>(&self) -> Result<T, OverflowError> {
                T::from_block_byte(self.get_u64_ctr(), self.pos, BLOCK_SIZE as u8)
            }

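            // Rebuild the counter from the nonce plus the block offset, and
            // regenerate the buffered keystream block when seeking mid-block.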
            fn try_seek<T: SeekNum>(&mut self, pos: T) -> Result<(), LoopError> {
                let res: (u64, u8) = pos.to_block_byte(BLOCK_SIZE as u8)?;
                self.ctr = unsafe { _mm_add_epi64(self.nonce, _mm_set_epi64x(res.0 as i64, 0)) };
                self.pos = res.1;
                if self.pos != 0 {
                    self.gen_block()
                }
                Ok(())
            }
        }

        opaque_debug::implement!($name);
    };
}

impl_ctr!(Aes128Ctr, Aes128, "AES-128 in CTR mode");
impl_ctr!(Aes192Ctr, Aes192, "AES-192 in CTR mode");
impl_ctr!(Aes256Ctr, Aes256, "AES-256 in CTR mode");