1 use crate::*;
2 use arrayref::array_ref;
3 use core::cmp;
4 
// The maximum number of parallel input streams any implementation in this
// crate can compress at once: 4 on x86/x86_64 (the AVX2 compress4_loop),
// 1 everywhere else (portable only).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub const MAX_DEGREE: usize = 4;

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub const MAX_DEGREE: usize = 1;
10 
// Variants other than Portable are unreachable in no_std, unless CPU features
// are explicitly enabled for the build with e.g. RUSTFLAGS="-C target-feature=avx2".
// This might change in the future if is_x86_feature_detected moves into libcore.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Platform {
    /// The pure-Rust fallback, available on every target.
    Portable,
    /// x86/x86_64 SIMD backend using SSE4.1 intrinsics.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    SSE41,
    /// x86/x86_64 SIMD backend using AVX2 intrinsics.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    AVX2,
}
23 
/// A handle to one of the compression backends (portable, SSE4.1, or AVX2).
/// Obtain one with [`Implementation::detect`] (best available) or
/// [`Implementation::portable`] (always works).
#[derive(Clone, Copy, Debug)]
pub struct Implementation(Platform);
26 
impl Implementation {
    /// Select the fastest implementation available on the current machine.
    ///
    /// Checks AVX2 first, then SSE4.1, then falls back to portable. On
    /// non-x86 targets this always returns the portable implementation.
    pub fn detect() -> Self {
        // Try the different implementations in order of how fast/modern they
        // are. Currently on non-x86, everything just uses portable.
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(avx2_impl) = Self::avx2_if_supported() {
                return avx2_impl;
            }
        }
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if let Some(sse41_impl) = Self::sse41_if_supported() {
                return sse41_impl;
            }
        }
        Self::portable()
    }

    /// The portable, pure-Rust implementation. Always available.
    pub fn portable() -> Self {
        Implementation(Platform::Portable)
    }

    /// The SSE4.1 implementation, or `None` if SSE4.1 support can't be
    /// confirmed. Support is known statically when the build enables the
    /// `sse4.1` target feature; otherwise it's detected at runtime, which
    /// requires `std`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[allow(unreachable_code)]
    pub fn sse41_if_supported() -> Option<Self> {
        // Check whether SSE4.1 support is assumed by the build.
        #[cfg(target_feature = "sse4.1")]
        {
            return Some(Implementation(Platform::SSE41));
        }
        // Otherwise dynamically check for support if we can.
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("sse4.1") {
                return Some(Implementation(Platform::SSE41));
            }
        }
        None
    }

    /// The AVX2 implementation, or `None` if AVX2 support can't be
    /// confirmed. Same static/dynamic detection logic as
    /// [`Implementation::sse41_if_supported`].
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[allow(unreachable_code)]
    pub fn avx2_if_supported() -> Option<Self> {
        // Check whether AVX2 support is assumed by the build.
        #[cfg(target_feature = "avx2")]
        {
            return Some(Implementation(Platform::AVX2));
        }
        // Otherwise dynamically check for support if we can.
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                return Some(Implementation(Platform::AVX2));
            }
        }
        None
    }

    /// The number of input streams this implementation can compress in
    /// parallel (see `MAX_DEGREE` for the crate-wide maximum).
    pub fn degree(&self) -> usize {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 => avx2::DEGREE,
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::SSE41 => sse41::DEGREE,
            Platform::Portable => 1,
        }
    }

    /// Compress a single input stream, updating `words` in place.
    ///
    /// `count` is the block counter value going into the first block,
    /// `stride` controls the gap between consecutive blocks (serial vs.
    /// BLAKE2bp/sp-style parallel layout), and `finalize` marks whether the
    /// end of `input` is the end of the whole message.
    pub fn compress1_loop(
        &self,
        input: &[u8],
        words: &mut [Word; 8],
        count: Count,
        last_node: LastNode,
        finalize: Finalize,
        stride: Stride,
    ) {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 => unsafe {
                avx2::compress1_loop(input, words, count, last_node, finalize, stride);
            },
            // Note that there's an SSE version of compress1 in the official C
            // implementation, but I haven't ported it yet.
            _ => {
                portable::compress1_loop(input, words, count, last_node, finalize, stride);
            }
        }
    }

    /// Compress two independent input streams at once.
    ///
    /// # Panics
    ///
    /// Panics on the portable implementation; callers are expected to check
    /// `degree()` (or use a SIMD-capable `Implementation`) first.
    pub fn compress2_loop(&self, jobs: &mut [Job; 2], finalize: Finalize, stride: Stride) {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 | Platform::SSE41 => unsafe {
                sse41::compress2_loop(jobs, finalize, stride)
            },
            _ => panic!("unsupported"),
        }
    }

    /// Compress four independent input streams at once.
    ///
    /// # Panics
    ///
    /// Panics unless this is the AVX2 implementation; callers are expected
    /// to check `degree()` first.
    pub fn compress4_loop(&self, jobs: &mut [Job; 4], finalize: Finalize, stride: Stride) {
        match self.0 {
            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
            Platform::AVX2 => unsafe { avx2::compress4_loop(jobs, finalize, stride) },
            _ => panic!("unsupported"),
        }
    }
}
136 
/// One input stream for the multi-stream loops (`compress2_loop` /
/// `compress4_loop`): the input bytes, the state words to update in place,
/// and the per-stream block counter and last-node flag.
pub struct Job<'a, 'b> {
    pub input: &'a [u8],
    pub words: &'b mut [Word; 8],
    pub count: Count,
    pub last_node: LastNode,
}
143 
144 impl<'a, 'b> core::fmt::Debug for Job<'a, 'b> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result145     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
146         // NB: Don't print the words. Leaking them would allow length extension.
147         write!(
148             f,
149             "Job {{ input_len: {}, count: {}, last_node: {} }}",
150             self.input.len(),
151             self.count,
152             self.last_node.yes(),
153         )
154     }
155 }
156 
// Finalize could just be a bool, but this is easier to read at callsites.
/// Whether a compression call includes the final block of the message.
#[derive(Clone, Copy, Debug)]
pub enum Finalize {
    Yes,
    No,
}
163 
164 impl Finalize {
yes(&self) -> bool165     pub fn yes(&self) -> bool {
166         match self {
167             Finalize::Yes => true,
168             Finalize::No => false,
169         }
170     }
171 }
172 
// Like Finalize, this is easier to read at callsites.
/// Whether this stream is the last node of a BLAKE2 tree (sets the
/// last-node finalization flag).
#[derive(Clone, Copy, Debug)]
pub enum LastNode {
    Yes,
    No,
}
179 
180 impl LastNode {
yes(&self) -> bool181     pub fn yes(&self) -> bool {
182         match self {
183             LastNode::Yes => true,
184             LastNode::No => false,
185         }
186     }
187 }
188 
/// The spacing between consecutive blocks of one stream within the input.
#[derive(Clone, Copy, Debug)]
pub enum Stride {
    Serial,   // BLAKE2b/BLAKE2s
    Parallel, // BLAKE2bp/BLAKE2sp
}
194 
195 impl Stride {
padded_blockbytes(&self) -> usize196     pub fn padded_blockbytes(&self) -> usize {
197         match self {
198             Stride::Serial => BLOCKBYTES,
199             Stride::Parallel => blake2bp::DEGREE * BLOCKBYTES,
200         }
201     }
202 }
203 
count_low(count: Count) -> Word204 pub(crate) fn count_low(count: Count) -> Word {
205     count as Word
206 }
207 
count_high(count: Count) -> Word208 pub(crate) fn count_high(count: Count) -> Word {
209     (count >> 8 * size_of::<Word>()) as Word
210 }
211 
assemble_count(low: Word, high: Word) -> Count212 pub(crate) fn assemble_count(low: Word, high: Word) -> Count {
213     low as Count + ((high as Count) << 8 * size_of::<Word>())
214 }
215 
flag_word(flag: bool) -> Word216 pub(crate) fn flag_word(flag: bool) -> Word {
217     if flag {
218         !0
219     } else {
220         0
221     }
222 }
223 
// Pull an array reference at the given offset straight from the input, if
// there's a full block of input available. If there's only a partial block,
// copy it into the provided buffer, and return an array reference to that.
// Along with the array, return the number of bytes of real input, and whether
// the input can be finalized (i.e. whether there aren't any more bytes after
// this block). Note that this is written so that the optimizer can elide
// bounds checks, see: https://godbolt.org/z/0hH2bC
pub fn final_block<'a>(
    input: &'a [u8],
    offset: usize,
    buffer: &'a mut [u8; BLOCKBYTES],
    stride: Stride,
) -> (&'a [u8; BLOCKBYTES], usize, bool) {
    // Clamp the offset so the slice below can't panic; a past-the-end offset
    // behaves like an empty final block.
    let capped_offset = cmp::min(offset, input.len());
    let offset_slice = &input[capped_offset..];
    if offset_slice.len() >= BLOCKBYTES {
        let block = array_ref!(offset_slice, 0, BLOCKBYTES);
        // Finalizable iff no further block of this stream starts within the
        // remaining input at this stride.
        let should_finalize = offset_slice.len() <= stride.padded_blockbytes();
        (block, BLOCKBYTES, should_finalize)
    } else {
        // Copy the final block to the front of the block buffer. The rest of
        // the buffer is assumed to be initialized to zero.
        buffer[..offset_slice.len()].copy_from_slice(offset_slice);
        (buffer, offset_slice.len(), true)
    }
}
250 
input_debug_asserts(input: &[u8], finalize: Finalize)251 pub fn input_debug_asserts(input: &[u8], finalize: Finalize) {
252     // If we're not finalizing, the input must not be empty, and it must be an
253     // even multiple of the block size.
254     if !finalize.yes() {
255         debug_assert!(!input.is_empty());
256         debug_assert_eq!(0, input.len() % BLOCKBYTES);
257     }
258 }
259 
#[cfg(test)]
mod test {
    use super::*;
    use arrayvec::ArrayVec;
    use core::mem::size_of;

    // Make sure detect() agrees with the individual *_if_supported helpers
    // and with runtime feature detection (where std is available).
    #[test]
    fn test_detection() {
        assert_eq!(Platform::Portable, Implementation::portable().0);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                assert_eq!(Platform::AVX2, Implementation::detect().0);
                assert_eq!(
                    Platform::AVX2,
                    Implementation::avx2_if_supported().unwrap().0
                );
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else if is_x86_feature_detected!("sse4.1") {
                assert_eq!(Platform::SSE41, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else {
                assert_eq!(Platform::Portable, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert!(Implementation::sse41_if_supported().is_none());
            }
        }
    }

    // TODO: Move all of these case tests into the implementation files.
    // Drive `f` through the cross product of interesting strides, input
    // lengths, last-node flags, finalization flags, and counter values,
    // skipping combinations that violate the loops' preconditions.
    fn exercise_cases<F>(mut f: F)
    where
        F: FnMut(Stride, usize, LastNode, Finalize, Count),
    {
        // Counts chosen to hit the relevant overflow cases: zero, a count
        // about to carry into the high word, and a count about to wrap.
        let counts = &[
            (0 as Count),
            ((1 as Count) << (8 * size_of::<Word>())) - BLOCKBYTES as Count,
            (0 as Count).wrapping_sub(BLOCKBYTES as Count),
        ];
        for &stride in &[Stride::Serial, Stride::Parallel] {
            // Lengths straddling every block and stride boundary.
            let lengths = [
                0,
                1,
                BLOCKBYTES - 1,
                BLOCKBYTES,
                BLOCKBYTES + 1,
                2 * BLOCKBYTES - 1,
                2 * BLOCKBYTES,
                2 * BLOCKBYTES + 1,
                stride.padded_blockbytes() - 1,
                stride.padded_blockbytes(),
                stride.padded_blockbytes() + 1,
                2 * stride.padded_blockbytes() - 1,
                2 * stride.padded_blockbytes(),
                2 * stride.padded_blockbytes() + 1,
            ];
            for &length in &lengths {
                for &last_node in &[LastNode::No, LastNode::Yes] {
                    for &finalize in &[Finalize::No, Finalize::Yes] {
                        if !finalize.yes() && (length == 0 || length % BLOCKBYTES != 0) {
                            // Skip these cases, they're invalid.
                            continue;
                        }
                        for &count in counts {
                            // eprintln!("\ncase -----");
                            // dbg!(stride);
                            // dbg!(length);
                            // dbg!(last_node);
                            // dbg!(finalize);
                            // dbg!(count);

                            f(stride, length, last_node, finalize, count);
                        }
                    }
                }
            }
        }
    }

    // Starting state words for the stream at `input_index`. Using the index
    // as the node offset gives each parallel stream distinct initial words.
    fn initial_test_words(input_index: usize) -> [Word; 8] {
        crate::Params::new()
            .node_offset(input_index as u64)
            .to_words()
    }

    // Use the portable implementation, one block at a time, to compute the
    // final state words expected for a given test case.
    fn reference_compression(
        input: &[u8],
        stride: Stride,
        last_node: LastNode,
        finalize: Finalize,
        mut count: Count,
        input_index: usize,
    ) -> [Word; 8] {
        let mut words = initial_test_words(input_index);
        let mut offset = 0;
        // The `offset == 0` clause makes sure even empty input gets one
        // (finalizing) compression of an empty block.
        while offset == 0 || offset < input.len() {
            let block_size = cmp::min(BLOCKBYTES, input.len() - offset);
            // Only the last block of the stream inherits the caller's
            // finalize flag; earlier blocks never finalize.
            let maybe_finalize = if offset + stride.padded_blockbytes() < input.len() {
                Finalize::No
            } else {
                finalize
            };
            portable::compress1_loop(
                &input[offset..][..block_size],
                &mut words,
                count,
                last_node,
                maybe_finalize,
                Stride::Serial,
            );
            offset += stride.padded_blockbytes();
            count = count.wrapping_add(BLOCKBYTES as Count);
        }
        words
    }

    // For various loop lengths and finalization parameters, make sure that the
    // implementation gives the same answer as the portable implementation does
    // when invoked one block at a time. (So even the portable implementation
    // itself is being tested here, to make sure its loop is correct.) Note
    // that this doesn't include any fixed test vectors; those are taken from
    // the blake2-kat.json file (copied from upstream) and tested elsewhere.
    fn exercise_compress1_loop(implementation: Implementation) {
        let mut input = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input);

        exercise_cases(|stride, length, last_node, finalize, count| {
            let reference_words =
                reference_compression(&input[..length], stride, last_node, finalize, count, 0);

            let mut test_words = initial_test_words(0);
            implementation.compress1_loop(
                &input[..length],
                &mut test_words,
                count,
                last_node,
                finalize,
                stride,
            );
            assert_eq!(reference_words, test_words);
        });
    }

    #[test]
    fn test_compress1_loop_portable() {
        exercise_compress1_loop(Implementation::portable());
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_sse41() {
        // Currently this just falls back to portable, but we test it anyway.
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    // I use ArrayVec everywhere in here because currently these tests pass
    // under no_std. I might decide that's not worth maintaining at some point,
    // since really all we care about with no_std is that the library builds,
    // but for now it's here. Everything is keyed off of this N constant so
    // that it's easy to copy the code to exercise_compress4_loop.
    fn exercise_compress2_loop(implementation: Implementation) {
        const N: usize = 2;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        // Offset each stream's input by its index so the streams differ.
        let mut inputs = ArrayVec::<[_; N]>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            // Expected words for each stream, computed serially.
            let mut reference_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = ArrayVec::<[_; N]>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress2_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress2_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress2_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress2_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress2_loop(imp);
        }
    }

    // Copied from exercise_compress2_loop, with a different value of N and an
    // interior call to compress4_loop.
    fn exercise_compress4_loop(implementation: Implementation) {
        const N: usize = 4;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        let mut inputs = ArrayVec::<[_; N]>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = ArrayVec::<[_; N]>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress4_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    // The count_low/count_high/assemble_count helpers assume Count is exactly
    // two Words wide; verify that here.
    #[test]
    fn sanity_check_count_size() {
        assert_eq!(size_of::<Count>(), 2 * size_of::<Word>());
    }
}
566