1 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4 // option. This file may not be copied, modified, or distributed
5 // except according to those terms.
6 
7 use buffer::{BufferResult, RefReadBuffer, RefWriteBuffer};
8 use symmetriccipher::{Encryptor, Decryptor, SynchronousStreamCipher, SymmetricCipherError};
9 use cryptoutil::{read_u32_le, symm_enc_or_dec, write_u32_le, xor_keystream};
10 use simd::u32x4;
11 
12 use std::cmp;
13 
14 #[derive(Clone, Copy)]
15 struct SalsaState {
16   a: u32x4,
17   b: u32x4,
18   c: u32x4,
19   d: u32x4
20 }
21 
22 #[derive(Copy)]
23 pub struct Salsa20 {
24     state: SalsaState,
25     output: [u8; 64],
26     offset: usize,
27 }
28 
clone(&self) -> Salsa2029 impl Clone for Salsa20 { fn clone(&self) -> Salsa20 { *self } }
30 
31 const S7:u32x4 = u32x4(7, 7, 7, 7);
32 const S9:u32x4 = u32x4(9, 9, 9, 9);
33 const S13:u32x4 = u32x4(13, 13, 13, 13);
34 const S18:u32x4 = u32x4(18, 18, 18, 18);
35 const S32:u32x4 = u32x4(32, 32, 32, 32);
36 
// Permutes the lanes of three state vectors in place before `rowround`:
//
//  * `$a`: every lane moves one position towards the end, the last lane
//    wraps around to the front;
//  * `$b`: the two halves are swapped;
//  * `$c`: every lane moves one position towards the front, the first lane
//    wraps around to the end.
//
// `prepare_columnround!` applies the inverse permutation.
macro_rules! prepare_rowround {
    ($a: expr, $b: expr, $c: expr) => {{
        let u32x4(l0, l1, l2, l3) = $a;
        $a = u32x4(l3, l0, l1, l2);

        let u32x4(l0, l1, l2, l3) = $b;
        $b = u32x4(l2, l3, l0, l1);

        let u32x4(l0, l1, l2, l3) = $c;
        $c = u32x4(l1, l2, l3, l0);
    }}
}
47 
// Permutes the lanes of three state vectors in place before `columnround`,
// undoing what `prepare_rowround!` did:
//
//  * `$a`: every lane moves one position towards the front, the first lane
//    wraps around to the end;
//  * `$b`: the two halves are swapped;
//  * `$c`: every lane moves one position towards the end, the last lane
//    wraps around to the front.
macro_rules! prepare_columnround {
    ($a: expr, $b: expr, $c: expr) => {{
        let u32x4(l0, l1, l2, l3) = $a;
        $a = u32x4(l1, l2, l3, l0);

        let u32x4(l0, l1, l2, l3) = $b;
        $b = u32x4(l2, l3, l0, l1);

        let u32x4(l0, l1, l2, l3) = $c;
        $c = u32x4(l3, l0, l1, l2);
    }}
}
58 
// Computes `$dst ^= rotate_left($a + $b, $shift)` lane by lane.
//
// The rotation is assembled from a left shift by `$shift` and a right shift
// by `S32 - $shift`; the two halves never overlap, so XOR-ing them together
// yields the rotated value.
macro_rules! add_rotate_xor {
    ($dst: expr, $a: expr, $b: expr, $shift: expr) => {{
        let sum = $a + $b;
        let high = sum << $shift;
        let low = sum >> (S32 - $shift);
        $dst = $dst ^ high ^ low
    }}
}
67 
columnround(state: &mut SalsaState) -> ()68 fn columnround(state: &mut SalsaState) -> () {
69     add_rotate_xor!(state.a, state.d, state.c, S7);
70     add_rotate_xor!(state.b, state.a, state.d, S9);
71     add_rotate_xor!(state.c, state.b, state.a, S13);
72     add_rotate_xor!(state.d, state.c, state.b, S18);
73 }
74 
rowround(state: &mut SalsaState) -> ()75 fn rowround(state: &mut SalsaState) -> () {
76     add_rotate_xor!(state.c, state.d, state.a, S7);
77     add_rotate_xor!(state.b, state.c, state.d, S9);
78     add_rotate_xor!(state.a, state.c, state.b, S13);
79     add_rotate_xor!(state.d, state.a, state.b, S18);
80 }
81 
82 impl Salsa20 {
new(key: &[u8], nonce: &[u8]) -> Salsa2083     pub fn new(key: &[u8], nonce: &[u8]) -> Salsa20 {
84         assert!(key.len() == 16 || key.len() == 32);
85         assert!(nonce.len() == 8);
86         Salsa20 { state: Salsa20::expand(key, nonce), output: [0; 64], offset: 64 }
87     }
88 
new_xsalsa20(key: &[u8], nonce: &[u8]) -> Salsa2089     pub fn new_xsalsa20(key: &[u8], nonce: &[u8]) -> Salsa20 {
90         assert!(key.len() == 32);
91         assert!(nonce.len() == 24);
92         let mut xsalsa20 = Salsa20 { state: Salsa20::expand(key, &nonce[0..16]), output: [0; 64], offset: 64 };
93 
94         let mut new_key = [0; 32];
95         xsalsa20.hsalsa20_hash(&mut new_key);
96         xsalsa20.state = Salsa20::expand(&new_key, &nonce[16..24]);
97 
98         xsalsa20
99     }
100 
expand(key: &[u8], nonce: &[u8]) -> SalsaState101     fn expand(key: &[u8], nonce: &[u8]) -> SalsaState {
102         let constant = match key.len() {
103             16 => b"expand 16-byte k",
104             32 => b"expand 32-byte k",
105             _  => unreachable!(),
106         };
107 
108         // The state vectors are laid out to facilitate SIMD operation,
109         // instead of the natural matrix ordering.
110         //
111         //  * Constant (x0, x5, x10, x15)
112         //  * Key (x1, x2, x3, x4, x11, x12, x13, x14)
113         //  * Input (x6, x7, x8, x9)
114 
115         let key_tail; // (x11, x12, x13, x14)
116         if key.len() == 16 {
117             key_tail = key;
118         } else {
119             key_tail = &key[16..32];
120         }
121 
122         let x8; let x9; // (x8, x9)
123         if nonce.len() == 16 {
124             // HSalsa uses the full 16 byte nonce.
125             x8 = read_u32_le(&nonce[8..12]);
126             x9 = read_u32_le(&nonce[12..16]);
127         } else {
128             x8 = 0;
129             x9 = 0;
130         }
131 
132         SalsaState {
133             a: u32x4(
134                 read_u32_le(&key[12..16]),      // x4
135                 x9,                             // x9
136                 read_u32_le(&key_tail[12..16]), // x14
137                 read_u32_le(&key[8..12]),       // x3
138             ),
139             b: u32x4(
140                 x8,                             // x8
141                 read_u32_le(&key_tail[8..12]),  // x13
142                 read_u32_le(&key[4..8]),        // x2
143                 read_u32_le(&nonce[4..8])       // x7
144             ),
145             c: u32x4(
146                 read_u32_le(&key_tail[4..8]),   // x12
147                 read_u32_le(&key[0..4]),        // x1
148                 read_u32_le(&nonce[0..4]),      // x6
149                 read_u32_le(&key_tail[0..4])    // x11
150             ),
151             d: u32x4(
152                 read_u32_le(&constant[0..4]),   // x0
153                 read_u32_le(&constant[4..8]),   // x5
154                 read_u32_le(&constant[8..12]),  // x10
155                 read_u32_le(&constant[12..16]), // x15
156             )
157         }
158     }
159 
hash(&mut self)160     fn hash(&mut self) {
161         let mut state = self.state;
162         for _ in 0..10 {
163             columnround(&mut state);
164             prepare_rowround!(state.a, state.b, state.c);
165             rowround(&mut state);
166             prepare_columnround!(state.a, state.b, state.c);
167         }
168         let u32x4(x4, x9, x14, x3) = self.state.a + state.a;
169         let u32x4(x8, x13, x2, x7) = self.state.b + state.b;
170         let u32x4(x12, x1, x6, x11) = self.state.c + state.c;
171         let u32x4(x0, x5, x10, x15) = self.state.d + state.d;
172         let lens = [
173              x0,  x1,  x2,  x3,
174              x4,  x5,  x6,  x7,
175              x8,  x9, x10, x11,
176             x12, x13, x14, x15
177         ];
178         for i in 0..lens.len() {
179             write_u32_le(&mut self.output[i*4..(i+1)*4], lens[i]);
180         }
181 
182         self.state.b = self.state.b + u32x4(1, 0, 0, 0);
183         let u32x4(_, _, _, ctr_lo) = self.state.b;
184         if ctr_lo == 0 {
185             self.state.a = self.state.a + u32x4(0, 1, 0, 0);
186         }
187 
188         self.offset = 0;
189     }
190 
hsalsa20_hash(&mut self, out: &mut [u8])191     fn hsalsa20_hash(&mut self, out: &mut [u8]) {
192         let mut state = self.state;
193         for _ in 0..10 {
194             columnround(&mut state);
195             prepare_rowround!(state.a, state.b, state.c);
196             rowround(&mut state);
197             prepare_columnround!(state.a, state.b, state.c);
198         }
199         let u32x4(_, x9, _, _) = state.a;
200         let u32x4(x8, _, _, x7) = state.b;
201         let u32x4(_, _, x6, _) = state.c;
202         let u32x4(x0, x5, x10, x15) = state.d;
203         let lens = [
204             x0, x5, x10, x15,
205             x6, x7, x8, x9
206         ];
207         for i in 0..lens.len() {
208             write_u32_le(&mut out[i*4..(i+1)*4], lens[i]);
209         }
210     }
211 }
212 
213 impl SynchronousStreamCipher for Salsa20 {
process(&mut self, input: &[u8], output: &mut [u8])214     fn process(&mut self, input: &[u8], output: &mut [u8]) {
215         assert!(input.len() == output.len());
216         let len = input.len();
217         let mut i = 0;
218         while i < len {
219             // If there is no keystream available in the output buffer,
220             // generate the next block.
221             if self.offset == 64 {
222                 self.hash();
223             }
224 
225             // Process the min(available keystream, remaining input length).
226             let count = cmp::min(64 - self.offset, len - i);
227             xor_keystream(&mut output[i..i+count], &input[i..i+count], &self.output[self.offset..]);
228             i += count;
229             self.offset += count;
230         }
231     }
232 }
233 
234 impl Encryptor for Salsa20 {
encrypt(&mut self, input: &mut RefReadBuffer, output: &mut RefWriteBuffer, _: bool) -> Result<BufferResult, SymmetricCipherError>235     fn encrypt(&mut self, input: &mut RefReadBuffer, output: &mut RefWriteBuffer, _: bool)
236             -> Result<BufferResult, SymmetricCipherError> {
237         symm_enc_or_dec(self, input, output)
238     }
239 }
240 
241 impl Decryptor for Salsa20 {
decrypt(&mut self, input: &mut RefReadBuffer, output: &mut RefWriteBuffer, _: bool) -> Result<BufferResult, SymmetricCipherError>242     fn decrypt(&mut self, input: &mut RefReadBuffer, output: &mut RefWriteBuffer, _: bool)
243             -> Result<BufferResult, SymmetricCipherError> {
244         symm_enc_or_dec(self, input, output)
245     }
246 }
247 
hsalsa20(key: &[u8], nonce: &[u8], out: &mut [u8])248 pub fn hsalsa20(key: &[u8], nonce: &[u8], out: &mut [u8]) {
249     assert!(key.len() == 32);
250     assert!(nonce.len() == 16);
251     let mut h = Salsa20 { state: Salsa20::expand(key, nonce), output: [0; 64], offset: 64 };
252     h.hsalsa20_hash(out);
253 }
254 
#[cfg(test)]
mod test {
    use salsa20::Salsa20;
    use symmetriccipher::SynchronousStreamCipher;

    use digest::Digest;
    use sha2::Sha256;

    /// Encrypts `len` zero bytes with `cipher` in a single `process` call,
    /// which yields the first `len` bytes of its keystream.
    fn keystream(mut cipher: Salsa20, len: usize) -> Vec<u8> {
        let zeros = vec![0u8; len];
        let mut stream = vec![0u8; len];
        cipher.process(&zeros[..], &mut stream[..]);
        stream
    }

    #[test]
    fn test_salsa20_128bit_ecrypt_set_1_vector_0() {
        let key: [u8; 16] = [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let nonce = [0u8; 8];
        let expected: [u8; 64] =
            [0x4D, 0xFA, 0x5E, 0x48, 0x1D, 0xA2, 0x3E, 0xA0,
             0x9A, 0x31, 0x02, 0x20, 0x50, 0x85, 0x99, 0x36,
             0xDA, 0x52, 0xFC, 0xEE, 0x21, 0x80, 0x05, 0x16,
             0x4F, 0x26, 0x7C, 0xB6, 0x5F, 0x5C, 0xFD, 0x7F,
             0x2B, 0x4F, 0x97, 0xE0, 0xFF, 0x16, 0x92, 0x4A,
             0x52, 0xDF, 0x26, 0x95, 0x15, 0x11, 0x0A, 0x07,
             0xF9, 0xE4, 0x60, 0xBC, 0x65, 0xEF, 0x95, 0xDA,
             0x58, 0xF7, 0x40, 0xB7, 0xD1, 0xDB, 0xB0, 0xAA];

        let stream = keystream(Salsa20::new(&key, &nonce), expected.len());
        assert!(stream[..] == expected[..]);
    }

    #[test]
    fn test_salsa20_256bit_ecrypt_set_1_vector_0() {
        let key: [u8; 32] =
            [128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let nonce = [0u8; 8];
        let expected: [u8; 64] =
            [0xE3, 0xBE, 0x8F, 0xDD, 0x8B, 0xEC, 0xA2, 0xE3,
             0xEA, 0x8E, 0xF9, 0x47, 0x5B, 0x29, 0xA6, 0xE7,
             0x00, 0x39, 0x51, 0xE1, 0x09, 0x7A, 0x5C, 0x38,
             0xD2, 0x3B, 0x7A, 0x5F, 0xAD, 0x9F, 0x68, 0x44,
             0xB2, 0x2C, 0x97, 0x55, 0x9E, 0x27, 0x23, 0xC7,
             0xCB, 0xBD, 0x3F, 0xE4, 0xFC, 0x8D, 0x9A, 0x07,
             0x44, 0x65, 0x2A, 0x83, 0xE7, 0x2A, 0x9C, 0x46,
             0x18, 0x76, 0xAF, 0x4D, 0x7E, 0xF1, 0xA1, 0x17];

        let stream = keystream(Salsa20::new(&key, &nonce), expected.len());
        assert!(stream[..] == expected[..]);
    }

    #[test]
    fn test_salsa20_256bit_nacl_vector_2() {
        let key: [u8; 32] = [
            0xdc,0x90,0x8d,0xda,0x0b,0x93,0x44,0xa9,
            0x53,0x62,0x9b,0x73,0x38,0x20,0x77,0x88,
            0x80,0xf3,0xce,0xb4,0x21,0xbb,0x61,0xb9,
            0x1c,0xbd,0x4c,0x3e,0x66,0x25,0x6c,0xe4
        ];
        let nonce: [u8; 8] = [
            0x82,0x19,0xe0,0x03,0x6b,0x7a,0x0b,0x37
        ];
        // SHA-256 of the first 4 MiB of keystream.
        let output_str = "662b9d0e3463029156069b12f918691a98f7dfb2ca0393c96bbfc6b1fbd630a2";

        let stream = keystream(Salsa20::new(&key, &nonce), 4194304);

        let mut sh = Sha256::new();
        sh.input(&stream[..]);
        let out_str = sh.result_str();
        assert!(&out_str[..] == output_str);
    }

    #[test]
    fn test_xsalsa20_cryptopp() {
        let key: [u8; 32] =
            [0x1b, 0x27, 0x55, 0x64, 0x73, 0xe9, 0x85, 0xd4,
             0x62, 0xcd, 0x51, 0x19, 0x7a, 0x9a, 0x46, 0xc7,
             0x60, 0x09, 0x54, 0x9e, 0xac, 0x64, 0x74, 0xf2,
             0x06, 0xc4, 0xee, 0x08, 0x44, 0xf6, 0x83, 0x89];
        let nonce: [u8; 24] =
            [0x69, 0x69, 0x6e, 0xe9, 0x55, 0xb6, 0x2b, 0x73,
             0xcd, 0x62, 0xbd, 0xa8, 0x75, 0xfc, 0x73, 0xd6,
             0x82, 0x19, 0xe0, 0x03, 0x6b, 0x7a, 0x0b, 0x37];
        let expected: [u8; 139] =
            [0xee, 0xa6, 0xa7, 0x25, 0x1c, 0x1e, 0x72, 0x91,
             0x6d, 0x11, 0xc2, 0xcb, 0x21, 0x4d, 0x3c, 0x25,
             0x25, 0x39, 0x12, 0x1d, 0x8e, 0x23, 0x4e, 0x65,
             0x2d, 0x65, 0x1f, 0xa4, 0xc8, 0xcf, 0xf8, 0x80,
             0x30, 0x9e, 0x64, 0x5a, 0x74, 0xe9, 0xe0, 0xa6,
             0x0d, 0x82, 0x43, 0xac, 0xd9, 0x17, 0x7a, 0xb5,
             0x1a, 0x1b, 0xeb, 0x8d, 0x5a, 0x2f, 0x5d, 0x70,
             0x0c, 0x09, 0x3c, 0x5e, 0x55, 0x85, 0x57, 0x96,
             0x25, 0x33, 0x7b, 0xd3, 0xab, 0x61, 0x9d, 0x61,
             0x57, 0x60, 0xd8, 0xc5, 0xb2, 0x24, 0xa8, 0x5b,
             0x1d, 0x0e, 0xfe, 0x0e, 0xb8, 0xa7, 0xee, 0x16,
             0x3a, 0xbb, 0x03, 0x76, 0x52, 0x9f, 0xcc, 0x09,
             0xba, 0xb5, 0x06, 0xc6, 0x18, 0xe1, 0x3c, 0xe7,
             0x77, 0xd8, 0x2c, 0x3a, 0xe9, 0xd1, 0xa6, 0xf9,
             0x72, 0xd4, 0x16, 0x02, 0x87, 0xcb, 0xfe, 0x60,
             0xbf, 0x21, 0x30, 0xfc, 0x0a, 0x6f, 0xf6, 0x04,
             0x9d, 0x0a, 0x5c, 0x8a, 0x82, 0xf4, 0x29, 0x23,
             0x1f, 0x00, 0x80];

        let stream = keystream(Salsa20::new_xsalsa20(&key, &nonce), expected.len());
        assert!(stream[..] == expected[..]);
    }
}
371 
#[cfg(all(test, feature = "with-bench"))]
mod bench {
    use test::Bencher;
    use symmetriccipher::SynchronousStreamCipher;
    use salsa20::Salsa20;

    /// Shared driver: repeatedly encrypts `input` into `output` with one
    /// cipher instance (all-zero 32-byte key, all-zero nonce) and records
    /// the number of bytes handled per iteration.
    fn run(bh: &mut Bencher, input: &[u8], output: &mut [u8]) {
        let mut cipher = Salsa20::new(&[0; 32], &[0; 8]);
        bh.iter(|| {
            cipher.process(input, &mut *output);
        });
        bh.bytes = input.len() as u64;
    }

    #[bench]
    pub fn salsa20_10(bh: &mut Bencher) {
        let input = [1u8; 10];
        let mut output = [0u8; 10];
        run(bh, &input, &mut output);
    }

    #[bench]
    pub fn salsa20_1k(bh: &mut Bencher) {
        let input = [1u8; 1024];
        let mut output = [0u8; 1024];
        run(bh, &input, &mut output);
    }

    #[bench]
    pub fn salsa20_64k(bh: &mut Bencher) {
        let input = [1u8; 65536];
        let mut output = [0u8; 65536];
        run(bh, &input, &mut output);
    }
}
411