// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

/*!

The `aessafe` module implements the AES algorithm completely in software without using any table
lookups or other timing dependent mechanisms. This module actually contains two separate
implementations - an implementation that works on a single block at a time and a second
implementation that processes 8 blocks in parallel. Some block encryption modes really only work if
you are processing a single block at a time (CFB, OFB, and CBC encryption, for example), while other
modes are trivially parallelizable (CTR and CBC decryption). Processing more blocks at once allows
for greater efficiency, especially when using wide registers, such as the XMM registers available
on x86 processors.

## AES Algorithm

There are lots of places on the internet to go for an involved description of how AES works. For
the purposes of this description, it suffices to say that AES is just a block cipher that takes
a key of 16, 24, or 32 bytes and uses it to either encrypt or decrypt a block of 16 bytes. An
encryption or decryption operation consists of a number of rounds, each of which involves some
combination of the following 4 basic operations (a rough sketch of the round structure follows the
list):

* ShiftRows
* MixColumns
* SubBytes
* AddRoundKey
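
The overall round structure (mirrored by the `encrypt_core` function later in this file) looks
roughly like the following sketch; `State` and `RoundKey` are placeholder names here, not types
defined by this module:

```ignore
// Rough outline of AES encryption: an initial AddRoundKey, a series of full
// rounds, and a final round that omits MixColumns (see encrypt_core below).
fn encrypt(state: State, sk: &[RoundKey]) -> State {
    let mut tmp = state.add_round_key(&sk[0]);
    for rk in &sk[1..sk.len() - 1] {
        tmp = tmp.sub_bytes().shift_rows().mix_columns().add_round_key(rk);
    }
    tmp.sub_bytes().shift_rows().add_round_key(&sk[sk.len() - 1])
}
```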

## Timing problems

Most software implementations of AES use a large set of lookup tables - generally at least the
SubBytes step is implemented via lookup tables; faster implementations generally implement the
MixColumns step this way as well. This is largely a design flaw in AES itself, as it was not
realized during the NIST standardization process that table lookups can lead to security
problems [1]. The issue is that not all table lookups occur in constant time - an address that was
recently used is looked up much faster than one that hasn't been used in a while. A careful
adversary can measure the amount of time that each AES operation takes and use that information to
help determine the secret key or plaintext. More specifically, it's not table lookups themselves
that lead to these types of timing attacks - the issue is table lookups that use secret information
as part of the address to look up. A table lookup that is performed the exact same way every time,
regardless of the key or plaintext, doesn't leak any information. This implementation uses no data
dependent table lookups.
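
To illustrate the distinction (this is not code from this module, just a sketch), compare a lookup
whose address depends on a secret byte with one that touches every table entry and selects the
wanted value with a mask, so its memory access pattern is independent of the secret:

```
// Leaky: the address depends on the secret byte, so cache timing can reveal it.
fn sbox_lookup_leaky(table: &[u8; 256], secret: u8) -> u8 {
    table[secret as usize]
}

// Constant-time: scan the whole table and keep only the wanted entry via a mask.
fn sbox_lookup_constant_time(table: &[u8; 256], secret: u8) -> u8 {
    let mut result = 0u8;
    for (i, &entry) in table.iter().enumerate() {
        let x = (i as u8) ^ secret;
        // mask is 0xff when i == secret and 0x00 otherwise, computed without branching
        let mask = ((x as u16).wrapping_sub(1) >> 8) as u8;
        result |= entry & mask;
    }
    result
}
```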

## Bit Slicing

Bit Slicing is a technique that is basically a software emulation of hardware implementation
techniques. One of the earliest implementations of this technique was for a DES implementation [4].
In hardware, table lookups do not present the same timing problems as they do in software; however,
they present other problems - namely that a 256 byte S-box table takes up a huge amount of space on
a chip. Hardware implementations thus tend to avoid table lookups and instead calculate the
contents of the S-boxes as part of every operation. So, the key to an efficient Bit Sliced software
implementation is to re-arrange all of the bits of data into a form that the same calculations can
be applied to, in much the same way that they would be in hardware. It is fortunate that AES was
designed such that these types of hardware implementations could be very efficient - the contents
of the S-boxes are defined by a mathematical formula.

A hardware implementation works on single bits at a time. Unlike operations on variables in
software, which generally happen one at a time, hardware implementations are extremely parallel and
operate on many, many bits at once. Bit Slicing emulates that by moving all "equivalent" bits into
common registers and then operating on large groups of bits all at once. Calculating the S-box
value for a single bit is extremely expensive, but it's much cheaper when you can amortize that cost
over 128 bits (as in an XMM register). This implementation follows the same strategy as in [5], and
that is an excellent source for more specific details. However, a short description follows.

The input data is simply a collection of bytes. Each byte is made up of 8 bits, a low order bit
(bit 0) through a high order bit (bit 7). Bit slicing the input data simply takes all of the low
order bits (bit 0) from the input data and moves them into a single register (eg: XMM0). Next, all
of the 2nd lowest bits are moved into their own register (eg: XMM1), and so on. After completion,
we're left with 8 variables, each of which contains an equivalent set of bits. The exact order of
those bits is irrelevant for the implementation of the SubBytes step; however, it is very important
for the MixColumns step. Again, see [5] for details. Due to the design of AES, it's then possible to
execute the entire AES operation using just bitwise exclusive-ors and rotates once we have Bit
Sliced the input data. After the completion of the AES operation, we then un-Bit Slice the data
to give us our output. Clearly, the more bits that we can process at once, the faster this will go -
thus, the version that processes 8 blocks at once is roughly 8 times faster than processing just a
single block at a time.
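
As a simplified sketch of that transposition (the bit ordering here is the naive one; the
`bit_slice_*` functions below use a different ordering that the MixColumns formulas depend on),
slicing a single 16 byte block into 8 `u16` variables looks like this:

```
// Naive bit slicing of one 16 byte block: after the loop, `slices[b]` holds
// bit `b` of every input byte, one input byte per bit position of the u16.
fn naive_bit_slice(block: &[u8; 16]) -> [u16; 8] {
    let mut slices = [0u16; 8];
    for (byte_idx, &byte) in block.iter().enumerate() {
        for bit in 0..8 {
            slices[bit] |= (((byte >> bit) & 1) as u16) << byte_idx;
        }
    }
    slices
}
```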

The ShiftRows step is fairly straightforward to implement on the Bit Sliced state. The MixColumns
and especially the SubBytes steps are more complicated. This implementation draws heavily on the
formulas from [5], [6], and [7] to implement these steps.

## Implementation

Both implementations work basically the same way and share pretty much all of their code. The key
is first processed to create all of the round keys, where each round key is just a 16 byte chunk of
data that is combined into the AES state by the AddRoundKey step as part of each encryption or
decryption round. Processing the key into round keys can be expensive, so this is done once, before
encryption or decryption. Before encrypting or decrypting data, the data to be processed must be
Bit Sliced into 8 separate variables where each variable holds equivalent bits from the state. This
Bit Sliced state is stored as a Bs8State<T>, where T is the type that stores each set of bits. The
first implementation stores these bits in a u16, which holds 8 * 16 = 128 bits of data - exactly
the size of a single 16 byte block. The 2nd implementation uses u32x4s - vectors of 4 u32s - so it
can process 8 * 128 = 1024 bits at once, which corresponds exactly to 8 blocks.

The Bs8State struct implements the AesOps trait, which contains methods for each of the 4 main
steps of the AES algorithm. Each type T implements the AesBitValueOps trait, which contains the
methods necessary for processing a collection of bit values; the AesOps trait relies heavily on
this trait to perform its operations.

The Bs4State and Bs2State structs implement operations in various subfields of the full GF(2^8)
finite field, which allows for efficient computation of the AES S-boxes. See [7] for details.
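
As a usage sketch (the `crypto::...` import paths are an assumption about how this crate is laid
out; the types and trait methods themselves are defined in this module and in `symmetriccipher`):

```ignore
use crypto::aessafe::AesSafe128Encryptor;
use crypto::symmetriccipher::BlockEncryptor;

let key = [0u8; 16];
let plaintext = [0u8; 16];
let mut ciphertext = [0u8; 16];

// Key expansion (including bit slicing the round keys) happens once, in new().
let enc = AesSafe128Encryptor::new(&key);
// Each call bit slices the input, runs the rounds, and un-bit slices the output.
enc.encrypt_block(&plaintext, &mut ciphertext);
```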

## References

[1] - "Cache-Collision Timing Attacks Against AES". Joseph Bonneau and Ilya Mironov.
      http://www.jbonneau.com/doc/BM06-CHES-aes_cache_timing.pdf
[2] - "Software mitigations to hedge AES against cache-based software side channel vulnerabilities".
      Ernie Brickell, et al. http://eprint.iacr.org/2006/052.pdf.
[3] - "Cache Attacks and Countermeasures: the Case of AES (Extended Version)".
      Dag Arne Osvik, et al. http://tau.ac.il/~tromer/papers/cache.pdf.
[4] - "A Fast New DES Implementation in Software". Eli Biham.
      http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.52.5429&rep=rep1&type=pdf.
[5] - "Faster and Timing-Attack Resistant AES-GCM". Emilia Käsper and Peter Schwabe.
      http://www.chesworkshop.org/ches2009/presentations/01_Session_1/CHES2009_ekasper.pdf.
[6] - "FAST AES DECRYPTION". Vinit Azad. http://webcache.googleusercontent.com/
      search?q=cache:ld_f8pSgURcJ:csusdspace.calstate.edu/bitstream/handle/10211.9/1224/
      Vinit_Azad_MS_Report.doc%3Fsequence%3D2+&cd=4&hl=en&ct=clnk&gl=us&client=ubuntu.
[7] - "A Very Compact Rijndael S-box". D. Canright.
      http://www.dtic.mil/cgi-bin/GetTRDoc?AD=ADA434781.
*/

use std::ops::{BitAnd, BitXor, Not};
use std::default::Default;

use cryptoutil::{read_u32v_le, write_u32_le};
use simd::u32x4;
use step_by::RangeExt;
use symmetriccipher::{BlockEncryptor, BlockEncryptorX8, BlockDecryptor, BlockDecryptorX8};

const U32X4_0: u32x4 = u32x4(0, 0, 0, 0);
const U32X4_1: u32x4 = u32x4(0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff);

// Defines a struct that holds the bit sliced round keys for a single-block AES variant.
macro_rules! define_aes_struct(
    (
        $name:ident,
        $rounds:expr
    ) => (
        #[derive(Clone, Copy)]
        pub struct $name {
            sk: [Bs8State<u16>; ($rounds + 1)]
        }
    )
);

macro_rules! define_aes_impl(
    (
        $name:ident,
        $mode:ident,
        $rounds:expr,
        $key_size:expr
    ) => (
        impl $name {
            pub fn new(key: &[u8]) -> $name {
                let mut a = $name {
                    sk: [Bs8State(0, 0, 0, 0, 0, 0, 0, 0); ($rounds + 1)]
                };
                let mut tmp = [[0u32; 4]; ($rounds + 1)];
                create_round_keys(key, KeyType::$mode, &mut tmp);
                for i in 0..$rounds + 1 {
                    a.sk[i] = bit_slice_4x4_with_u16(tmp[i][0], tmp[i][1], tmp[i][2], tmp[i][3]);
                }
                a
            }
        }
    )
);

macro_rules! define_aes_enc(
    (
        $name:ident,
        $rounds:expr
    ) => (
        impl BlockEncryptor for $name {
            fn block_size(&self) -> usize { 16 }
            fn encrypt_block(&self, input: &[u8], output: &mut [u8]) {
                let mut bs = bit_slice_1x16_with_u16(input);
                bs = encrypt_core(&bs, &self.sk);
                un_bit_slice_1x16_with_u16(&bs, output);
            }
        }
    )
);

macro_rules! define_aes_dec(
    (
        $name:ident,
        $rounds:expr
    ) => (
        impl BlockDecryptor for $name {
            fn block_size(&self) -> usize { 16 }
            fn decrypt_block(&self, input: &[u8], output: &mut [u8]) {
                let mut bs = bit_slice_1x16_with_u16(input);
                bs = decrypt_core(&bs, &self.sk);
                un_bit_slice_1x16_with_u16(&bs, output);
            }
        }
    )
);

define_aes_struct!(AesSafe128Encryptor, 10);
define_aes_struct!(AesSafe128Decryptor, 10);
define_aes_impl!(AesSafe128Encryptor, Encryption, 10, 16);
define_aes_impl!(AesSafe128Decryptor, Decryption, 10, 16);
define_aes_enc!(AesSafe128Encryptor, 10);
define_aes_dec!(AesSafe128Decryptor, 10);

define_aes_struct!(AesSafe192Encryptor, 12);
define_aes_struct!(AesSafe192Decryptor, 12);
define_aes_impl!(AesSafe192Encryptor, Encryption, 12, 24);
define_aes_impl!(AesSafe192Decryptor, Decryption, 12, 24);
define_aes_enc!(AesSafe192Encryptor, 12);
define_aes_dec!(AesSafe192Decryptor, 12);

define_aes_struct!(AesSafe256Encryptor, 14);
define_aes_struct!(AesSafe256Decryptor, 14);
define_aes_impl!(AesSafe256Encryptor, Encryption, 14, 32);
define_aes_impl!(AesSafe256Decryptor, Decryption, 14, 32);
define_aes_enc!(AesSafe256Encryptor, 14);
define_aes_dec!(AesSafe256Decryptor, 14);

// Equivalent macros for the 8-block (x8) variants, which store the round keys as Bs8State<u32x4>.
macro_rules! define_aes_struct_x8(
    (
        $name:ident,
        $rounds:expr
    ) => (
        #[derive(Clone, Copy)]
        pub struct $name {
            sk: [Bs8State<u32x4>; ($rounds + 1)]
        }
    )
);

macro_rules! define_aes_impl_x8(
    (
        $name:ident,
        $mode:ident,
        $rounds:expr,
        $key_size:expr
    ) => (
        impl $name {
            pub fn new(key: &[u8]) -> $name {
                let mut a = $name {
                    sk: [
                        Bs8State(
                            U32X4_0,
                            U32X4_0,
                            U32X4_0,
                            U32X4_0,
                            U32X4_0,
                            U32X4_0,
                            U32X4_0,
                            U32X4_0);
                        ($rounds + 1)]
                };
                let mut tmp = [[0u32; 4]; ($rounds + 1)];
                create_round_keys(key, KeyType::$mode, &mut tmp);
                for i in 0..$rounds + 1 {
                    a.sk[i] = bit_slice_fill_4x4_with_u32x4(
                        tmp[i][0],
                        tmp[i][1],
                        tmp[i][2],
                        tmp[i][3]);
                }
                a
            }
        }
    )
);

macro_rules! define_aes_enc_x8(
    (
        $name:ident,
        $rounds:expr
    ) => (
        impl BlockEncryptorX8 for $name {
            fn block_size(&self) -> usize { 16 }
            fn encrypt_block_x8(&self, input: &[u8], output: &mut [u8]) {
                let bs = bit_slice_1x128_with_u32x4(input);
                let bs2 = encrypt_core(&bs, &self.sk);
                un_bit_slice_1x128_with_u32x4(bs2, output);
            }
        }
    )
);

macro_rules! define_aes_dec_x8(
    (
        $name:ident,
        $rounds:expr
    ) => (
        impl BlockDecryptorX8 for $name {
            fn block_size(&self) -> usize { 16 }
            fn decrypt_block_x8(&self, input: &[u8], output: &mut [u8]) {
                let bs = bit_slice_1x128_with_u32x4(input);
                let bs2 = decrypt_core(&bs, &self.sk);
                un_bit_slice_1x128_with_u32x4(bs2, output);
            }
        }
    )
);

define_aes_struct_x8!(AesSafe128EncryptorX8, 10);
define_aes_struct_x8!(AesSafe128DecryptorX8, 10);
define_aes_impl_x8!(AesSafe128EncryptorX8, Encryption, 10, 16);
define_aes_impl_x8!(AesSafe128DecryptorX8, Decryption, 10, 16);
define_aes_enc_x8!(AesSafe128EncryptorX8, 10);
define_aes_dec_x8!(AesSafe128DecryptorX8, 10);

define_aes_struct_x8!(AesSafe192EncryptorX8, 12);
define_aes_struct_x8!(AesSafe192DecryptorX8, 12);
define_aes_impl_x8!(AesSafe192EncryptorX8, Encryption, 12, 24);
define_aes_impl_x8!(AesSafe192DecryptorX8, Decryption, 12, 24);
define_aes_enc_x8!(AesSafe192EncryptorX8, 12);
define_aes_dec_x8!(AesSafe192DecryptorX8, 12);

define_aes_struct_x8!(AesSafe256EncryptorX8, 14);
define_aes_struct_x8!(AesSafe256DecryptorX8, 14);
define_aes_impl_x8!(AesSafe256EncryptorX8, Encryption, 14, 32);
define_aes_impl_x8!(AesSafe256DecryptorX8, Decryption, 14, 32);
define_aes_enc_x8!(AesSafe256EncryptorX8, 14);
define_aes_dec_x8!(AesSafe256DecryptorX8, 14);

// Multiplies each of the 4 bytes packed in a u32 by 2 in GF(2^8) (the AES "xtime" operation).
fn ffmulx(x: u32) -> u32 {
    let m1: u32 = 0x80808080;
    let m2: u32 = 0x7f7f7f7f;
    let m3: u32 = 0x0000001b;
    ((x & m2) << 1) ^ (((x & m1) >> 7) * m3)
}

// Applies the InvMixColumns transformation to a single column packed in a u32.
fn inv_mcol(x: u32) -> u32 {
    let f2 = ffmulx(x);
    let f4 = ffmulx(f2);
    let f8 = ffmulx(f4);
    let f9 = x ^ f8;

    f2 ^ f4 ^ f8 ^ (f2 ^ f9).rotate_right(8) ^ (f4 ^ f9).rotate_right(16) ^ f9.rotate_right(24)
}

// Applies the AES S-box to each of the 4 bytes of a word - used only when creating the round keys.
fn sub_word(x: u32) -> u32 {
    let bs = bit_slice_4x1_with_u16(x).sub_bytes();
    un_bit_slice_4x1_with_u16(&bs)
}

enum KeyType {
    Encryption,
    Decryption
}

// This array is not accessed in any key-dependent way, so there are no timing problems inherent in
// using it.
static RCON: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36];

// The round keys are created without bit-slicing the key data. The individual implementations bit
// slice the round keys returned from this function. This function, and the few functions above, are
// derived from the BouncyCastle AES implementation.
fn create_round_keys(key: &[u8], key_type: KeyType, round_keys: &mut [[u32; 4]]) {
    let (key_words, rounds) = match key.len() {
        16 => (4, 10),
        24 => (6, 12),
        32 => (8, 14),
        _ => panic!("Invalid AES key size.")
    };

    // The key is copied directly into the first few round keys
    let mut j = 0;
    for i in (0..key.len()).step_up(4) {
        round_keys[j / 4][j % 4] =
            (key[i] as u32) |
            ((key[i+1] as u32) << 8) |
            ((key[i+2] as u32) << 16) |
            ((key[i+3] as u32) << 24);
        j += 1;
    };

    // Calculate the rest of the round keys
    for i in key_words..(rounds + 1) * 4 {
        let mut tmp = round_keys[(i - 1) / 4][(i - 1) % 4];
        if (i % key_words) == 0 {
            tmp = sub_word(tmp.rotate_right(8)) ^ RCON[(i / key_words) - 1];
        } else if (key_words == 8) && ((i % key_words) == 4) {
            // This is only necessary for AES-256 keys
            tmp = sub_word(tmp);
        }
        round_keys[i / 4][i % 4] = round_keys[(i - key_words) / 4][(i - key_words) % 4] ^ tmp;
    }

    // Decryption round keys require extra processing
    match key_type {
        KeyType::Decryption => {
            for j in 1..rounds {
                for i in 0..4 {
                    round_keys[j][i] = inv_mcol(round_keys[j][i]);
                }
            }
        },
        KeyType::Encryption => { }
    }
}

// This trait defines all of the operations needed for a type to be processed as part of an AES
// encryption or decryption operation.
trait AesOps {
    fn sub_bytes(self) -> Self;
    fn inv_sub_bytes(self) -> Self;
    fn shift_rows(self) -> Self;
    fn inv_shift_rows(self) -> Self;
    fn mix_columns(self) -> Self;
    fn inv_mix_columns(self) -> Self;
    fn add_round_key(self, rk: &Self) -> Self;
}

fn encrypt_core<S: AesOps + Copy>(state: &S, sk: &[S]) -> S {
    // Round 0 - add round key
    let mut tmp = state.add_round_key(&sk[0]);

    // Remaining rounds (except last round)
    for i in 1..sk.len() - 1 {
        tmp = tmp.sub_bytes();
        tmp = tmp.shift_rows();
        tmp = tmp.mix_columns();
        tmp = tmp.add_round_key(&sk[i]);
    }

    // Last round
    tmp = tmp.sub_bytes();
    tmp = tmp.shift_rows();
    tmp = tmp.add_round_key(&sk[sk.len() - 1]);

    tmp
}

fn decrypt_core<S: AesOps + Copy>(state: &S, sk: &[S]) -> S {
    // Round 0 - add round key
    let mut tmp = state.add_round_key(&sk[sk.len() - 1]);

    // Remaining rounds (except last round)
    for i in 1..sk.len() - 1 {
        tmp = tmp.inv_sub_bytes();
        tmp = tmp.inv_shift_rows();
        tmp = tmp.inv_mix_columns();
        tmp = tmp.add_round_key(&sk[sk.len() - 1 - i]);
    }

    // Last round
    tmp = tmp.inv_sub_bytes();
    tmp = tmp.inv_shift_rows();
    tmp = tmp.add_round_key(&sk[0]);

    tmp
}

#[derive(Clone, Copy)]
struct Bs8State<T>(T, T, T, T, T, T, T, T);

impl <T: Copy> Bs8State<T> {
    fn split(self) -> (Bs4State<T>, Bs4State<T>) {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = self;
        (Bs4State(x0, x1, x2, x3), Bs4State(x4, x5, x6, x7))
    }
}

impl <T: BitXor<Output = T> + Copy> Bs8State<T> {
    fn xor(self, rhs: Bs8State<T>) -> Bs8State<T> {
        let Bs8State(a0, a1, a2, a3, a4, a5, a6, a7) = self;
        let Bs8State(b0, b1, b2, b3, b4, b5, b6, b7) = rhs;
        Bs8State(a0 ^ b0, a1 ^ b1, a2 ^ b2, a3 ^ b3, a4 ^ b4, a5 ^ b5, a6 ^ b6, a7 ^ b7)
    }

    // We need to be able to convert a Bs8State to and from a polynomial basis and a normal
    // basis. That transformation could be done via pseudocode that roughly looks like the
    // following:
    //
    // for x in 0..8 {
    //     for y in 0..8 {
    //         result.x ^= input.y & MATRIX[7 - y][x]
    //     }
    // }
    //
    // Where the MATRIX is one of the following depending on the conversion being done.
    // (The affine transformation step is included in all of these matrices):
    //
    // A2X = [
    //     [ 0,  0,  0, -1, -1,  0,  0, -1],
    //     [-1, -1,  0,  0, -1, -1, -1, -1],
    //     [ 0, -1,  0,  0, -1, -1, -1, -1],
    //     [ 0,  0,  0, -1,  0,  0, -1,  0],
    //     [-1,  0,  0, -1,  0,  0,  0,  0],
    //     [-1,  0,  0,  0,  0,  0,  0, -1],
    //     [-1,  0,  0, -1,  0, -1,  0, -1],
    //     [-1, -1, -1, -1, -1, -1, -1, -1]
    // ];
    //
    // X2A = [
    //     [ 0,  0, -1,  0,  0, -1, -1,  0],
    //     [ 0,  0,  0, -1, -1, -1, -1,  0],
    //     [ 0, -1, -1, -1,  0, -1, -1,  0],
    //     [ 0,  0, -1, -1,  0,  0,  0, -1],
    //     [ 0,  0,  0, -1,  0, -1, -1,  0],
    //     [-1,  0,  0, -1,  0, -1,  0,  0],
    //     [ 0, -1, -1, -1, -1,  0, -1, -1],
    //     [ 0,  0,  0,  0,  0, -1, -1,  0],
    // ];
    //
    // X2S = [
    //     [ 0,  0,  0, -1, -1,  0, -1,  0],
    //     [-1,  0, -1, -1,  0, -1,  0,  0],
    //     [ 0, -1, -1, -1, -1,  0,  0, -1],
    //     [-1, -1,  0, -1,  0,  0,  0,  0],
    //     [ 0,  0, -1, -1, -1,  0, -1, -1],
    //     [ 0,  0, -1,  0,  0,  0,  0,  0],
    //     [-1, -1,  0,  0,  0,  0,  0,  0],
    //     [ 0,  0, -1,  0,  0, -1,  0,  0],
    // ];
    //
    // S2X = [
    //     [ 0,  0, -1, -1,  0,  0,  0, -1],
    //     [-1,  0,  0, -1, -1, -1, -1,  0],
    //     [-1,  0, -1,  0,  0,  0,  0,  0],
    //     [-1, -1,  0, -1,  0, -1, -1, -1],
    //     [ 0, -1,  0,  0, -1,  0,  0,  0],
    //     [ 0,  0, -1,  0,  0,  0,  0,  0],
    //     [-1,  0,  0,  0, -1,  0, -1,  0],
    //     [-1, -1,  0,  0, -1,  0, -1,  0],
    // ];
    //
    // Looking at the pseudocode implementation, we see that there is no point
    // in processing any of the elements in those matrices that have zero values
    // since a logical AND with 0 will produce 0 which will have no effect when it
    // is XORed into the result.
    //
    // LLVM doesn't appear to be able to fully unroll the loops in the pseudocode
    // above and to eliminate processing of the 0 elements. So, each transformation is
    // implemented independently directly in fully unrolled form with the 0 elements
    // removed.
    //
    // As an optimization, elements that are XORed together multiple times are
    // XORed just once and then used multiple times. I wrote a simple program that
    // greedily looked for terms to combine to create the implementations below.
    // It is likely that this could be optimized more.

    fn change_basis_a2x(&self) -> Bs8State<T> {
        let t06 = self.6 ^ self.0;
        let t056 = self.5 ^ t06;
        let t0156 = t056 ^ self.1;
        let t13 = self.1 ^ self.3;

        let x0 = self.2 ^ t06 ^ t13;
        let x1 = t056;
        let x2 = self.0;
        let x3 = self.0 ^ self.4 ^ self.7 ^ t13;
        let x4 = self.7 ^ t056;
        let x5 = t0156;
        let x6 = self.4 ^ t056;
        let x7 = self.2 ^ self.7 ^ t0156;

        Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
    }

    fn change_basis_x2s(&self) -> Bs8State<T> {
        let t46 = self.4 ^ self.6;
        let t35 = self.3 ^ self.5;
        let t06 = self.0 ^ self.6;
        let t357 = t35 ^ self.7;

        let x0 = self.1 ^ t46;
        let x1 = self.1 ^ self.4 ^ self.5;
        let x2 = self.2 ^ t35 ^ t06;
        let x3 = t46 ^ t357;
        let x4 = t357;
        let x5 = t06;
        let x6 = self.3 ^ self.7;
        let x7 = t35;

        Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
    }

    fn change_basis_x2a(&self) -> Bs8State<T> {
        let t15 = self.1 ^ self.5;
        let t36 = self.3 ^ self.6;
        let t1356 = t15 ^ t36;
        let t07 = self.0 ^ self.7;

        let x0 = self.2;
        let x1 = t15;
        let x2 = self.4 ^ self.7 ^ t15;
        let x3 = self.2 ^ self.4 ^ t1356;
        let x4 = self.1 ^ self.6;
        let x5 = self.2 ^ self.5 ^ t36 ^ t07;
        let x6 = t1356 ^ t07;
        let x7 = self.1 ^ self.4;

        Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
    }

    fn change_basis_s2x(&self) -> Bs8State<T> {
        let t46 = self.4 ^ self.6;
        let t01 = self.0 ^ self.1;
        let t0146 = t01 ^ t46;

        let x0 = self.5 ^ t0146;
        let x1 = self.0 ^ self.3 ^ self.4;
        let x2 = self.2 ^ self.5 ^ self.7;
        let x3 = self.7 ^ t46;
        let x4 = self.3 ^ self.6 ^ t01;
        let x5 = t46;
        let x6 = t0146;
        let x7 = self.4 ^ self.7;

        Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
    }
}

impl <T: Not<Output = T> + Copy> Bs8State<T> {
    // The special value "x63" is used as part of the sub_bytes and inv_sub_bytes
    // steps. It is conceptually a Bs8State value where the 0th, 1st, 5th, and 6th
    // elements are all 1s and the other elements are all 0s. The only thing that
    // we do with the "x63" value is to XOR a Bs8State with it. We optimize that XOR
    // below into just inverting 4 of the elements and leaving the other 4 elements
    // untouched.
    fn xor_x63(self) -> Bs8State<T> {
        Bs8State (
            !self.0,
            !self.1,
            self.2,
            self.3,
            self.4,
            !self.5,
            !self.6,
            self.7)
    }
}

#[derive(Clone, Copy)]
struct Bs4State<T>(T, T, T, T);

impl <T: Copy> Bs4State<T> {
    fn split(self) -> (Bs2State<T>, Bs2State<T>) {
        let Bs4State(x0, x1, x2, x3) = self;
        (Bs2State(x0, x1), Bs2State(x2, x3))
    }

    fn join(self, rhs: Bs4State<T>) -> Bs8State<T> {
        let Bs4State(a0, a1, a2, a3) = self;
        let Bs4State(b0, b1, b2, b3) = rhs;
        Bs8State(a0, a1, a2, a3, b0, b1, b2, b3)
    }
}

impl <T: BitXor<Output = T> + Copy> Bs4State<T> {
    fn xor(self, rhs: Bs4State<T>) -> Bs4State<T> {
        let Bs4State(a0, a1, a2, a3) = self;
        let Bs4State(b0, b1, b2, b3) = rhs;
        Bs4State(a0 ^ b0, a1 ^ b1, a2 ^ b2, a3 ^ b3)
    }
}

#[derive(Clone, Copy)]
struct Bs2State<T>(T, T);

impl <T> Bs2State<T> {
    fn split(self) -> (T, T) {
        let Bs2State(x0, x1) = self;
        (x0, x1)
    }

    fn join(self, rhs: Bs2State<T>) -> Bs4State<T> {
        let Bs2State(a0, a1) = self;
        let Bs2State(b0, b1) = rhs;
        Bs4State(a0, a1, b0, b1)
    }
}

impl <T: BitXor<Output = T> + Copy> Bs2State<T> {
    fn xor(self, rhs: Bs2State<T>) -> Bs2State<T> {
        let Bs2State(a0, a1) = self;
        let Bs2State(b0, b1) = rhs;
        Bs2State(a0 ^ b0, a1 ^ b1)
    }
}

// Bit Slice data in the form of 4 u32s in column-major order
fn bit_slice_4x4_with_u16(a: u32, b: u32, c: u32, d: u32) -> Bs8State<u16> {
    fn pb(x: u32, bit: u32, shift: u32) -> u16 {
        (((x >> bit) & 1) as u16) << shift
    }

    fn construct(a: u32, b: u32, c: u32, d: u32, bit: u32) -> u16 {
        pb(a, bit, 0)       | pb(b, bit, 1)       | pb(c, bit, 2)       | pb(d, bit, 3)       |
        pb(a, bit + 8, 4)   | pb(b, bit + 8, 5)   | pb(c, bit + 8, 6)   | pb(d, bit + 8, 7)   |
        pb(a, bit + 16, 8)  | pb(b, bit + 16, 9)  | pb(c, bit + 16, 10) | pb(d, bit + 16, 11) |
        pb(a, bit + 24, 12) | pb(b, bit + 24, 13) | pb(c, bit + 24, 14) | pb(d, bit + 24, 15)
    }

    let x0 = construct(a, b, c, d, 0);
    let x1 = construct(a, b, c, d, 1);
    let x2 = construct(a, b, c, d, 2);
    let x3 = construct(a, b, c, d, 3);
    let x4 = construct(a, b, c, d, 4);
    let x5 = construct(a, b, c, d, 5);
    let x6 = construct(a, b, c, d, 6);
    let x7 = construct(a, b, c, d, 7);

    Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
}

// Bit slice a single u32 value - this is used to calculate the SubBytes step when creating the
// round keys.
fn bit_slice_4x1_with_u16(a: u32) -> Bs8State<u16> {
    bit_slice_4x4_with_u16(a, 0, 0, 0)
}

// Bit slice a 16 byte array in column major order
fn bit_slice_1x16_with_u16(data: &[u8]) -> Bs8State<u16> {
    let mut n = [0u32; 4];
    read_u32v_le(&mut n, data);

    let a = n[0];
    let b = n[1];
    let c = n[2];
    let d = n[3];

    bit_slice_4x4_with_u16(a, b, c, d)
}

// Un Bit Slice into a set of 4 u32s
fn un_bit_slice_4x4_with_u16(bs: &Bs8State<u16>) -> (u32, u32, u32, u32) {
    fn pb(x: u16, bit: u32, shift: u32) -> u32 {
        (((x >> bit) & 1) as u32) << shift
    }

    fn deconstruct(bs: &Bs8State<u16>, bit: u32) -> u32 {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = *bs;

        pb(x0, bit, 0) | pb(x1, bit, 1) | pb(x2, bit, 2) | pb(x3, bit, 3) |
        pb(x4, bit, 4) | pb(x5, bit, 5) | pb(x6, bit, 6) | pb(x7, bit, 7) |

        pb(x0, bit + 4, 8)  | pb(x1, bit + 4, 9)  | pb(x2, bit + 4, 10) | pb(x3, bit + 4, 11) |
        pb(x4, bit + 4, 12) | pb(x5, bit + 4, 13) | pb(x6, bit + 4, 14) | pb(x7, bit + 4, 15) |

        pb(x0, bit + 8, 16) | pb(x1, bit + 8, 17) | pb(x2, bit + 8, 18) | pb(x3, bit + 8, 19) |
        pb(x4, bit + 8, 20) | pb(x5, bit + 8, 21) | pb(x6, bit + 8, 22) | pb(x7, bit + 8, 23) |

        pb(x0, bit + 12, 24) | pb(x1, bit + 12, 25) | pb(x2, bit + 12, 26) | pb(x3, bit + 12, 27) |
        pb(x4, bit + 12, 28) | pb(x5, bit + 12, 29) | pb(x6, bit + 12, 30) | pb(x7, bit + 12, 31)
    }

    let a = deconstruct(bs, 0);
    let b = deconstruct(bs, 1);
    let c = deconstruct(bs, 2);
    let d = deconstruct(bs, 3);

    (a, b, c, d)
}

// Un Bit Slice into a single u32. This is used when creating the round keys.
fn un_bit_slice_4x1_with_u16(bs: &Bs8State<u16>) -> u32 {
    let (a, _, _, _) = un_bit_slice_4x4_with_u16(bs);
    a
}

// Un Bit Slice into a 16 byte array
fn un_bit_slice_1x16_with_u16(bs: &Bs8State<u16>, output: &mut [u8]) {
    let (a, b, c, d) = un_bit_slice_4x4_with_u16(bs);

    write_u32_le(&mut output[0..4], a);
    write_u32_le(&mut output[4..8], b);
    write_u32_le(&mut output[8..12], c);
    write_u32_le(&mut output[12..16], d);
}

// Bit Slice a 128 byte array of eight 16 byte blocks. Each block is in column major order.
fn bit_slice_1x128_with_u32x4(data: &[u8]) -> Bs8State<u32x4> {
    let bit0 = u32x4(0x01010101, 0x01010101, 0x01010101, 0x01010101);
    let bit1 = u32x4(0x02020202, 0x02020202, 0x02020202, 0x02020202);
    let bit2 = u32x4(0x04040404, 0x04040404, 0x04040404, 0x04040404);
    let bit3 = u32x4(0x08080808, 0x08080808, 0x08080808, 0x08080808);
    let bit4 = u32x4(0x10101010, 0x10101010, 0x10101010, 0x10101010);
    let bit5 = u32x4(0x20202020, 0x20202020, 0x20202020, 0x20202020);
    let bit6 = u32x4(0x40404040, 0x40404040, 0x40404040, 0x40404040);
    let bit7 = u32x4(0x80808080, 0x80808080, 0x80808080, 0x80808080);

    fn read_row_major(data: &[u8]) -> u32x4 {
        u32x4(
            (data[0] as u32) |
            ((data[4] as u32) << 8) |
            ((data[8] as u32) << 16) |
            ((data[12] as u32) << 24),
            (data[1] as u32) |
            ((data[5] as u32) << 8) |
            ((data[9] as u32) << 16) |
            ((data[13] as u32) << 24),
            (data[2] as u32) |
            ((data[6] as u32) << 8) |
            ((data[10] as u32) << 16) |
            ((data[14] as u32) << 24),
            (data[3] as u32) |
            ((data[7] as u32) << 8) |
            ((data[11] as u32) << 16) |
            ((data[15] as u32) << 24))
    }

    let t0 = read_row_major(&data[0..16]);
    let t1 = read_row_major(&data[16..32]);
    let t2 = read_row_major(&data[32..48]);
    let t3 = read_row_major(&data[48..64]);
    let t4 = read_row_major(&data[64..80]);
    let t5 = read_row_major(&data[80..96]);
    let t6 = read_row_major(&data[96..112]);
    let t7 = read_row_major(&data[112..128]);

    let x0 = (t0 & bit0) | (t1.lsh(1) & bit1) | (t2.lsh(2) & bit2) | (t3.lsh(3) & bit3) |
        (t4.lsh(4) & bit4) | (t5.lsh(5) & bit5) | (t6.lsh(6) & bit6) | (t7.lsh(7) & bit7);
    let x1 = (t0.rsh(1) & bit0) | (t1 & bit1) | (t2.lsh(1) & bit2) | (t3.lsh(2) & bit3) |
        (t4.lsh(3) & bit4) | (t5.lsh(4) & bit5) | (t6.lsh(5) & bit6) | (t7.lsh(6) & bit7);
    let x2 = (t0.rsh(2) & bit0) | (t1.rsh(1) & bit1) | (t2 & bit2) | (t3.lsh(1) & bit3) |
        (t4.lsh(2) & bit4) | (t5.lsh(3) & bit5) | (t6.lsh(4) & bit6) | (t7.lsh(5) & bit7);
    let x3 = (t0.rsh(3) & bit0) | (t1.rsh(2) & bit1) | (t2.rsh(1) & bit2) | (t3 & bit3) |
        (t4.lsh(1) & bit4) | (t5.lsh(2) & bit5) | (t6.lsh(3) & bit6) | (t7.lsh(4) & bit7);
    let x4 = (t0.rsh(4) & bit0) | (t1.rsh(3) & bit1) | (t2.rsh(2) & bit2) | (t3.rsh(1) & bit3) |
        (t4 & bit4) | (t5.lsh(1) & bit5) | (t6.lsh(2) & bit6) | (t7.lsh(3) & bit7);
    let x5 = (t0.rsh(5) & bit0) | (t1.rsh(4) & bit1) | (t2.rsh(3) & bit2) | (t3.rsh(2) & bit3) |
        (t4.rsh(1) & bit4) | (t5 & bit5) | (t6.lsh(1) & bit6) | (t7.lsh(2) & bit7);
    let x6 = (t0.rsh(6) & bit0) | (t1.rsh(5) & bit1) | (t2.rsh(4) & bit2) | (t3.rsh(3) & bit3) |
        (t4.rsh(2) & bit4) | (t5.rsh(1) & bit5) | (t6 & bit6) | (t7.lsh(1) & bit7);
    let x7 = (t0.rsh(7) & bit0) | (t1.rsh(6) & bit1) | (t2.rsh(5) & bit2) | (t3.rsh(4) & bit3) |
        (t4.rsh(3) & bit4) | (t5.rsh(2) & bit5) | (t6.rsh(1) & bit6) | (t7 & bit7);

    Bs8State(x0, x1, x2, x3, x4, x5, x6, x7)
}

// Bit slice a set of 4 u32s by filling a full 128 byte data block with those repeated values. This
// is used as part of bit slicing the round keys.
fn bit_slice_fill_4x4_with_u32x4(a: u32, b: u32, c: u32, d: u32) -> Bs8State<u32x4> {
    let mut tmp = [0u8; 128];
    for i in 0..8 {
        write_u32_le(&mut tmp[i * 16..i * 16 + 4], a);
        write_u32_le(&mut tmp[i * 16 + 4..i * 16 + 8], b);
        write_u32_le(&mut tmp[i * 16 + 8..i * 16 + 12], c);
        write_u32_le(&mut tmp[i * 16 + 12..i * 16 + 16], d);
    }
    bit_slice_1x128_with_u32x4(&tmp)
}

// Un bit slice into a 128 byte buffer.
fn un_bit_slice_1x128_with_u32x4(bs: Bs8State<u32x4>, output: &mut [u8]) {
    let Bs8State(t0, t1, t2, t3, t4, t5, t6, t7) = bs;

    let bit0 = u32x4(0x01010101, 0x01010101, 0x01010101, 0x01010101);
    let bit1 = u32x4(0x02020202, 0x02020202, 0x02020202, 0x02020202);
    let bit2 = u32x4(0x04040404, 0x04040404, 0x04040404, 0x04040404);
    let bit3 = u32x4(0x08080808, 0x08080808, 0x08080808, 0x08080808);
    let bit4 = u32x4(0x10101010, 0x10101010, 0x10101010, 0x10101010);
    let bit5 = u32x4(0x20202020, 0x20202020, 0x20202020, 0x20202020);
    let bit6 = u32x4(0x40404040, 0x40404040, 0x40404040, 0x40404040);
    let bit7 = u32x4(0x80808080, 0x80808080, 0x80808080, 0x80808080);

    // decode the individual blocks, in row-major order
    // TODO: this is identical to the same block in bit_slice_1x128_with_u32x4
    let x0 = (t0 & bit0) | (t1.lsh(1) & bit1) | (t2.lsh(2) & bit2) | (t3.lsh(3) & bit3) |
        (t4.lsh(4) & bit4) | (t5.lsh(5) & bit5) | (t6.lsh(6) & bit6) | (t7.lsh(7) & bit7);
    let x1 = (t0.rsh(1) & bit0) | (t1 & bit1) | (t2.lsh(1) & bit2) | (t3.lsh(2) & bit3) |
        (t4.lsh(3) & bit4) | (t5.lsh(4) & bit5) | (t6.lsh(5) & bit6) | (t7.lsh(6) & bit7);
    let x2 = (t0.rsh(2) & bit0) | (t1.rsh(1) & bit1) | (t2 & bit2) | (t3.lsh(1) & bit3) |
        (t4.lsh(2) & bit4) | (t5.lsh(3) & bit5) | (t6.lsh(4) & bit6) | (t7.lsh(5) & bit7);
    let x3 = (t0.rsh(3) & bit0) | (t1.rsh(2) & bit1) | (t2.rsh(1) & bit2) | (t3 & bit3) |
        (t4.lsh(1) & bit4) | (t5.lsh(2) & bit5) | (t6.lsh(3) & bit6) | (t7.lsh(4) & bit7);
    let x4 = (t0.rsh(4) & bit0) | (t1.rsh(3) & bit1) | (t2.rsh(2) & bit2) | (t3.rsh(1) & bit3) |
        (t4 & bit4) | (t5.lsh(1) & bit5) | (t6.lsh(2) & bit6) | (t7.lsh(3) & bit7);
    let x5 = (t0.rsh(5) & bit0) | (t1.rsh(4) & bit1) | (t2.rsh(3) & bit2) | (t3.rsh(2) & bit3) |
        (t4.rsh(1) & bit4) | (t5 & bit5) | (t6.lsh(1) & bit6) | (t7.lsh(2) & bit7);
    let x6 = (t0.rsh(6) & bit0) | (t1.rsh(5) & bit1) | (t2.rsh(4) & bit2) | (t3.rsh(3) & bit3) |
        (t4.rsh(2) & bit4) | (t5.rsh(1) & bit5) | (t6 & bit6) | (t7.lsh(1) & bit7);
    let x7 = (t0.rsh(7) & bit0) | (t1.rsh(6) & bit1) | (t2.rsh(5) & bit2) | (t3.rsh(4) & bit3) |
        (t4.rsh(3) & bit4) | (t5.rsh(2) & bit5) | (t6.rsh(1) & bit6) | (t7 & bit7);

    fn write_row_major(block: u32x4, output: &mut [u8]) {
        let u32x4(a0, a1, a2, a3) = block;
        output[0] = a0 as u8;
        output[1] = a1 as u8;
        output[2] = a2 as u8;
        output[3] = a3 as u8;
        output[4] = (a0 >> 8) as u8;
        output[5] = (a1 >> 8) as u8;
        output[6] = (a2 >> 8) as u8;
        output[7] = (a3 >> 8) as u8;
        output[8] = (a0 >> 16) as u8;
        output[9] = (a1 >> 16) as u8;
        output[10] = (a2 >> 16) as u8;
        output[11] = (a3 >> 16) as u8;
        output[12] = (a0 >> 24) as u8;
        output[13] = (a1 >> 24) as u8;
        output[14] = (a2 >> 24) as u8;
        output[15] = (a3 >> 24) as u8;
    }

    write_row_major(x0, &mut output[0..16]);
    write_row_major(x1, &mut output[16..32]);
    write_row_major(x2, &mut output[32..48]);
    write_row_major(x3, &mut output[48..64]);
    write_row_major(x4, &mut output[64..80]);
    write_row_major(x5, &mut output[80..96]);
    write_row_major(x6, &mut output[96..112]);
    write_row_major(x7, &mut output[112..128])
}

// The Gf2Ops, Gf4Ops, and Gf8Ops traits specify the functions needed to calculate the AES S-box
// values. This particular implementation of those S-box values is taken from [7], so that is where
// to look for details on how it all works. This includes the transformation matrices defined
// above for the change_basis operations on the u16 and u32x4 types.

// Operations in GF(2^2) using normal basis (Omega^2,Omega)
trait Gf2Ops {
    // multiply
    fn mul(self, y: Self) -> Self;

    // scale by N = Omega^2
    fn scl_n(self) -> Self;

    // scale by N^2 = Omega
    fn scl_n2(self) -> Self;

    // square
    fn sq(self) -> Self;

    // same as square
    fn inv(self) -> Self;
}

impl <T: BitXor<Output = T> + BitAnd<Output = T> + Copy> Gf2Ops for Bs2State<T> {
    fn mul(self, y: Bs2State<T>) -> Bs2State<T> {
        let (b, a) = self.split();
        let (d, c) = y.split();
        let e = (a ^ b) & (c ^ d);
        let p = (a & c) ^ e;
        let q = (b & d) ^ e;
        Bs2State(q, p)
    }

    fn scl_n(self) -> Bs2State<T> {
        let (b, a) = self.split();
        let q = a ^ b;
        Bs2State(q, b)
    }

    fn scl_n2(self) -> Bs2State<T> {
        let (b, a) = self.split();
        let p = a ^ b;
        let q = a;
        Bs2State(q, p)
    }

    fn sq(self) -> Bs2State<T> {
        let (b, a) = self.split();
        Bs2State(a, b)
    }

    fn inv(self) -> Bs2State<T> {
        self.sq()
    }
}

// Operations in GF(2^4) using normal basis (alpha^8,alpha^2)
trait Gf4Ops {
    // multiply
    fn mul(self, y: Self) -> Self;

    // square & scale by nu
    // nu = beta^8 = N^2*alpha^2, N = w^2
    fn sq_scl(self) -> Self;

    // inverse
    fn inv(self) -> Self;
}

impl <T: BitXor<Output = T> + BitAnd<Output = T> + Copy> Gf4Ops for Bs4State<T> {
    fn mul(self, y: Bs4State<T>) -> Bs4State<T> {
        let (b, a) = self.split();
        let (d, c) = y.split();
        let f = c.xor(d);
        let e = a.xor(b).mul(f).scl_n();
        let p = a.mul(c).xor(e);
        let q = b.mul(d).xor(e);
        q.join(p)
    }

    fn sq_scl(self) -> Bs4State<T> {
        let (b, a) = self.split();
        let p = a.xor(b).sq();
        let q = b.sq().scl_n2();
        q.join(p)
    }

    fn inv(self) -> Bs4State<T> {
        let (b, a) = self.split();
        let c = a.xor(b).sq().scl_n();
        let d = a.mul(b);
        let e = c.xor(d).inv();
        let p = e.mul(b);
        let q = e.mul(a);
        q.join(p)
    }
}

// Operations in GF(2^8) using normal basis (d^16,d)
trait Gf8Ops {
    // inverse
    fn inv(&self) -> Self;
}

impl <T: BitXor<Output = T> + BitAnd<Output = T> + Copy + Default> Gf8Ops for Bs8State<T> {
    fn inv(&self) -> Bs8State<T> {
        let (b, a) = self.split();
        let c = a.xor(b).sq_scl();
        let d = a.mul(b);
        let e = c.xor(d).inv();
        let p = e.mul(b);
        let q = e.mul(a);
        q.join(p)
    }
}

impl <T: AesBitValueOps + Copy + 'static> AesOps for Bs8State<T> {
    fn sub_bytes(self) -> Bs8State<T> {
        let nb: Bs8State<T> = self.change_basis_a2x();
        let inv = nb.inv();
        let nb2: Bs8State<T> = inv.change_basis_x2s();
        nb2.xor_x63()
    }

    fn inv_sub_bytes(self) -> Bs8State<T> {
        let t = self.xor_x63();
        let nb: Bs8State<T> = t.change_basis_s2x();
        let inv = nb.inv();
        inv.change_basis_x2a()
    }

    fn shift_rows(self) -> Bs8State<T> {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = self;
        Bs8State(
            x0.shift_row(),
            x1.shift_row(),
            x2.shift_row(),
            x3.shift_row(),
            x4.shift_row(),
            x5.shift_row(),
            x6.shift_row(),
            x7.shift_row())
    }

    fn inv_shift_rows(self) -> Bs8State<T> {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = self;
        Bs8State(
            x0.inv_shift_row(),
            x1.inv_shift_row(),
            x2.inv_shift_row(),
            x3.inv_shift_row(),
            x4.inv_shift_row(),
            x5.inv_shift_row(),
            x6.inv_shift_row(),
            x7.inv_shift_row())
    }

    // Formula from [5]
    fn mix_columns(self) -> Bs8State<T> {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = self;

        let x0out = x7 ^ x7.ror1() ^ x0.ror1() ^ (x0 ^ x0.ror1()).ror2();
        let x1out = x0 ^ x0.ror1() ^ x7 ^ x7.ror1() ^ x1.ror1() ^ (x1 ^ x1.ror1()).ror2();
        let x2out = x1 ^ x1.ror1() ^ x2.ror1() ^ (x2 ^ x2.ror1()).ror2();
        let x3out = x2 ^ x2.ror1() ^ x7 ^ x7.ror1() ^ x3.ror1() ^ (x3 ^ x3.ror1()).ror2();
        let x4out = x3 ^ x3.ror1() ^ x7 ^ x7.ror1() ^ x4.ror1() ^ (x4 ^ x4.ror1()).ror2();
        let x5out = x4 ^ x4.ror1() ^ x5.ror1() ^ (x5 ^ x5.ror1()).ror2();
        let x6out = x5 ^ x5.ror1() ^ x6.ror1() ^ (x6 ^ x6.ror1()).ror2();
        let x7out = x6 ^ x6.ror1() ^ x7.ror1() ^ (x7 ^ x7.ror1()).ror2();

        Bs8State(x0out, x1out, x2out, x3out, x4out, x5out, x6out, x7out)
    }

    // Formula from [6]
    fn inv_mix_columns(self) -> Bs8State<T> {
        let Bs8State(x0, x1, x2, x3, x4, x5, x6, x7) = self;

        let x0out = x5 ^ x6 ^ x7 ^
            (x5 ^ x7 ^ x0).ror1() ^
            (x0 ^ x5 ^ x6).ror2() ^
            (x5 ^ x0).ror3();
        let x1out = x5 ^ x0 ^
            (x6 ^ x5 ^ x0 ^ x7 ^ x1).ror1() ^
            (x1 ^ x7 ^ x5).ror2() ^
            (x6 ^ x5 ^ x1).ror3();
        let x2out = x6 ^ x0 ^ x1 ^
            (x7 ^ x6 ^ x1 ^ x2).ror1() ^
            (x0 ^ x2 ^ x6).ror2() ^
            (x7 ^ x6 ^ x2).ror3();
        let x3out = x0 ^ x5 ^ x1 ^ x6 ^ x2 ^
            (x0 ^ x5 ^ x2 ^ x3).ror1() ^
            (x0 ^ x1 ^ x3 ^ x5 ^ x6 ^ x7).ror2() ^
            (x0 ^ x5 ^ x7 ^ x3).ror3();
        let x4out = x1 ^ x5 ^ x2 ^ x3 ^
            (x1 ^ x6 ^ x5 ^ x3 ^ x7 ^ x4).ror1() ^
            (x1 ^ x2 ^ x4 ^ x5 ^ x7).ror2() ^
            (x1 ^ x5 ^ x6 ^ x4).ror3();
        let x5out = x2 ^ x6 ^ x3 ^ x4 ^
            (x2 ^ x7 ^ x6 ^ x4 ^ x5).ror1() ^
            (x2 ^ x3 ^ x5 ^ x6).ror2() ^
            (x2 ^ x6 ^ x7 ^ x5).ror3();
        let x6out = x3 ^ x7 ^ x4 ^ x5 ^
            (x3 ^ x7 ^ x5 ^ x6).ror1() ^
            (x3 ^ x4 ^ x6 ^ x7).ror2() ^
            (x3 ^ x7 ^ x6).ror3();
        let x7out = x4 ^ x5 ^ x6 ^
            (x4 ^ x6 ^ x7).ror1() ^
            (x4 ^ x5 ^ x7).ror2() ^
            (x4 ^ x7).ror3();

        Bs8State(x0out, x1out, x2out, x3out, x4out, x5out, x6out, x7out)
    }

    fn add_round_key(self, rk: &Bs8State<T>) -> Bs8State<T> {
        self.xor(*rk)
    }
}

trait AesBitValueOps: BitXor<Output = Self> + BitAnd<Output = Self> + Not<Output = Self> + Default + Sized {
    fn shift_row(self) -> Self;
    fn inv_shift_row(self) -> Self;
    fn ror1(self) -> Self;
    fn ror2(self) -> Self;
    fn ror3(self) -> Self;
}

impl AesBitValueOps for u16 {
    fn shift_row(self) -> u16 {
        // first 4 bits represent first row - don't shift
        (self & 0x000f) |
        // next 4 bits represent 2nd row - left rotate 1 bit
        ((self & 0x00e0) >> 1) | ((self & 0x0010) << 3) |
        // next 4 bits represent 3rd row - left rotate 2 bits
        ((self & 0x0c00) >> 2) | ((self & 0x0300) << 2) |
        // next 4 bits represent 4th row - left rotate 3 bits
        ((self & 0x8000) >> 3) | ((self & 0x7000) << 1)
    }

    fn inv_shift_row(self) -> u16 {
        // first 4 bits represent first row - don't shift
        (self & 0x000f) |
        // next 4 bits represent 2nd row - right rotate 1 bit
        ((self & 0x0080) >> 3) | ((self & 0x0070) << 1) |
        // next 4 bits represent 3rd row - right rotate 2 bits
        ((self & 0x0c00) >> 2) | ((self & 0x0300) << 2) |
        // next 4 bits represent 4th row - right rotate 3 bits
        ((self & 0xe000) >> 1) | ((self & 0x1000) << 3)
    }

    fn ror1(self) -> u16 {
        self >> 4 | self << 12
    }

    fn ror2(self) -> u16 {
        self >> 8 | self << 8
    }

    fn ror3(self) -> u16 {
        self >> 12 | self << 4
    }
}

impl u32x4 {
    fn lsh(self, s: u32) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(
            a0 << s,
            (a1 << s) | (a0 >> (32 - s)),
            (a2 << s) | (a1 >> (32 - s)),
            (a3 << s) | (a2 >> (32 - s)))
    }

    fn rsh(self, s: u32) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(
            (a0 >> s) | (a1 << (32 - s)),
            (a1 >> s) | (a2 << (32 - s)),
            (a2 >> s) | (a3 << (32 - s)),
            a3 >> s)
    }
}

impl Not for u32x4 {
    type Output = u32x4;

    fn not(self) -> u32x4 {
        self ^ U32X4_1
    }
}

impl Default for u32x4 {
    fn default() -> u32x4 {
        u32x4(0, 0, 0, 0)
    }
}

impl AesBitValueOps for u32x4 {
    fn shift_row(self) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(a0, a1 >> 8 | a1 << 24, a2 >> 16 | a2 << 16, a3 >> 24 | a3 << 8)
    }

    fn inv_shift_row(self) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(a0, a1 >> 24 | a1 << 8, a2 >> 16 | a2 << 16, a3 >> 8 | a3 << 24)
    }

    fn ror1(self) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(a1, a2, a3, a0)
    }

    fn ror2(self) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(a2, a3, a0, a1)
    }

    fn ror3(self) -> u32x4 {
        let u32x4(a0, a1, a2, a3) = self;
        u32x4(a3, a0, a1, a2)
    }
}