#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::{
    counter_high, counter_low, CVWords, IncrementCounter, BLOCK_LEN, IV, MSG_SCHEDULE, OUT_LEN,
};
use arrayref::{array_mut_ref, mut_array_refs};

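// AVX2 vectors are 256 bits wide and hold eight 32-bit words, so this
// implementation hashes eight inputs in parallel.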
pub const DEGREE: usize = 8;

#[inline(always)]
unsafe fn loadu(src: *const u8) -> __m256i {
    // This is an unaligned load, so the pointer cast is allowed.
    _mm256_loadu_si256(src as *const __m256i)
}

#[inline(always)]
unsafe fn storeu(src: __m256i, dest: *mut u8) {
    // This is an unaligned store, so the pointer cast is allowed.
    _mm256_storeu_si256(dest as *mut __m256i, src)
}

#[inline(always)]
unsafe fn add(a: __m256i, b: __m256i) -> __m256i {
    _mm256_add_epi32(a, b)
}

#[inline(always)]
unsafe fn xor(a: __m256i, b: __m256i) -> __m256i {
    _mm256_xor_si256(a, b)
}

#[inline(always)]
unsafe fn set1(x: u32) -> __m256i {
    _mm256_set1_epi32(x as i32)
}

#[inline(always)]
unsafe fn set8(a: u32, b: u32, c: u32, d: u32, e: u32, f: u32, g: u32, h: u32) -> __m256i {
    _mm256_setr_epi32(
        a as i32, b as i32, c as i32, d as i32, e as i32, f as i32, g as i32, h as i32,
    )
}

// These rotations are the "simple/shifts version". For the
// "complicated/shuffles version", see
// https://github.com/sneves/blake2-avx2/blob/b3723921f668df09ece52dcd225a36d4a4eea1d9/blake2s-common.h#L63-L66.
// For a discussion of the tradeoffs, see
// https://github.com/sneves/blake2-avx2/pull/5. Due to an LLVM bug
// (https://bugs.llvm.org/show_bug.cgi?id=44379), this version performs better
// on recent x86 chips.

#[inline(always)]
unsafe fn rot16(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 16), _mm256_slli_epi32(x, 32 - 16))
}

#[inline(always)]
unsafe fn rot12(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 12), _mm256_slli_epi32(x, 32 - 12))
}

#[inline(always)]
unsafe fn rot8(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 8), _mm256_slli_epi32(x, 32 - 8))
}

#[inline(always)]
unsafe fn rot7(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 7), _mm256_slli_epi32(x, 32 - 7))
}

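// One round of the compression function, applied to eight states in parallel. Each v[i] is a
// vector holding state word i from all eight inputs. The first half of the round mixes the
// columns of the 4x4 word state, and the second half mixes the diagonals, using the message
// words selected by MSG_SCHEDULE for round `r`.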
#[inline(always)]
unsafe fn round(v: &mut [__m256i; 16], m: &[__m256i; 16], r: usize) {
    v[0] = add(v[0], m[MSG_SCHEDULE[r][0] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][2] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][4] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][6] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[15] = rot16(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot12(v[4]);
    v[5] = rot12(v[5]);
    v[6] = rot12(v[6]);
    v[7] = rot12(v[7]);
    v[0] = add(v[0], m[MSG_SCHEDULE[r][1] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][3] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][5] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][7] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot8(v[12]);
    v[13] = rot8(v[13]);
    v[14] = rot8(v[14]);
    v[15] = rot8(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot7(v[4]);
    v[5] = rot7(v[5]);
    v[6] = rot7(v[6]);
    v[7] = rot7(v[7]);

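    // Mix the diagonals.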
    v[0] = add(v[0], m[MSG_SCHEDULE[r][8] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][10] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][12] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][14] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot16(v[15]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot12(v[5]);
    v[6] = rot12(v[6]);
    v[7] = rot12(v[7]);
    v[4] = rot12(v[4]);
    v[0] = add(v[0], m[MSG_SCHEDULE[r][9] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][11] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][13] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][15] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot8(v[15]);
    v[12] = rot8(v[12]);
    v[13] = rot8(v[13]);
    v[14] = rot8(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot7(v[5]);
    v[6] = rot7(v[6]);
    v[7] = rot7(v[7]);
    v[4] = rot7(v[4]);
}

#[inline(always)]
unsafe fn interleave128(a: __m256i, b: __m256i) -> (__m256i, __m256i) {
    (
        _mm256_permute2x128_si256(a, b, 0x20),
        _mm256_permute2x128_si256(a, b, 0x31),
    )
}

// There are several ways to do a transposition. We could do it naively, with 8 separate
// _mm256_set_epi32 instructions, referencing each of the 32 words explicitly. Or we could copy
// the vecs into contiguous storage and then use gather instructions. The third approach, used
// here, is a series of unpack instructions to interleave the vectors. In my benchmarks,
// interleaving is the fastest approach. To test this, run
// `cargo +nightly bench --bench libtest load_8` in the
// https://github.com/oconnor663/bao_experiments repo.
#[inline(always)]
unsafe fn transpose_vecs(vecs: &mut [__m256i; DEGREE]) {
    // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high is 22/33/66/77.
    let ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
    let ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
    let cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
    let cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
    let ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
    let ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
    let gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
    let gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);

    // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is 11/33.
    let abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
    let abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
    let abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
    let abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
    let efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
    let efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
    let efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
    let efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);

    // Interleave 128-bit lanes.
    let (abcdefgh_0, abcdefgh_4) = interleave128(abcd_04, efgh_04);
    let (abcdefgh_1, abcdefgh_5) = interleave128(abcd_15, efgh_15);
    let (abcdefgh_2, abcdefgh_6) = interleave128(abcd_26, efgh_26);
    let (abcdefgh_3, abcdefgh_7) = interleave128(abcd_37, efgh_37);

    vecs[0] = abcdefgh_0;
    vecs[1] = abcdefgh_1;
    vecs[2] = abcdefgh_2;
    vecs[3] = abcdefgh_3;
    vecs[4] = abcdefgh_4;
    vecs[5] = abcdefgh_5;
    vecs[6] = abcdefgh_6;
    vecs[7] = abcdefgh_7;
}

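// Load the 64-byte block at `block_offset` from each of the eight inputs (two 32-byte loads per
// input) and transpose, so that vecs[i] holds message word i from all eight blocks.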
#[inline(always)]
unsafe fn transpose_msg_vecs(inputs: &[*const u8; DEGREE], block_offset: usize) -> [__m256i; 16] {
    let mut vecs = [
        loadu(inputs[0].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[1].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[2].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[3].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[4].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[5].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[6].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[7].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[0].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[1].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[2].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[3].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[4].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[5].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[6].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[7].add(block_offset + 1 * 4 * DEGREE)),
    ];
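    // Prefetch each input 256 bytes (four blocks) past the current offset, to hide load latency.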
    for i in 0..DEGREE {
        _mm_prefetch(inputs[i].add(block_offset + 256) as *const i8, _MM_HINT_T0);
    }
    let squares = mut_array_refs!(&mut vecs, DEGREE, DEGREE);
    transpose_vecs(squares.0);
    transpose_vecs(squares.1);
    vecs
}

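// Split each lane's 64-bit block counter into low and high 32-bit halves. When
// `increment_counter` is off, the mask zeroes the per-lane offsets so that every lane uses the
// same counter value.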
#[inline(always)]
unsafe fn load_counters(counter: u64, increment_counter: IncrementCounter) -> (__m256i, __m256i) {
    let mask = if increment_counter.yes() { !0 } else { 0 };
    (
        set8(
            counter_low(counter + (mask & 0)),
            counter_low(counter + (mask & 1)),
            counter_low(counter + (mask & 2)),
            counter_low(counter + (mask & 3)),
            counter_low(counter + (mask & 4)),
            counter_low(counter + (mask & 5)),
            counter_low(counter + (mask & 6)),
            counter_low(counter + (mask & 7)),
        ),
        set8(
            counter_high(counter + (mask & 0)),
            counter_high(counter + (mask & 1)),
            counter_high(counter + (mask & 2)),
            counter_high(counter + (mask & 3)),
            counter_high(counter + (mask & 4)),
            counter_high(counter + (mask & 5)),
            counter_high(counter + (mask & 6)),
            counter_high(counter + (mask & 7)),
        ),
    )
}

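// Compress `blocks` consecutive full blocks from each of the eight inputs, writing the eight
// 32-byte chaining values to `out`.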
#[target_feature(enable = "avx2")]
pub unsafe fn hash8(
    inputs: &[*const u8; DEGREE],
    blocks: usize,
    key: &CVWords,
    counter: u64,
    increment_counter: IncrementCounter,
    flags: u8,
    flags_start: u8,
    flags_end: u8,
    out: &mut [u8; DEGREE * OUT_LEN],
) {
    let mut h_vecs = [
        set1(key[0]),
        set1(key[1]),
        set1(key[2]),
        set1(key[3]),
        set1(key[4]),
        set1(key[5]),
        set1(key[6]),
        set1(key[7]),
    ];
    let (counter_low_vec, counter_high_vec) = load_counters(counter, increment_counter);
    let mut block_flags = flags | flags_start;

    for block in 0..blocks {
        if block + 1 == blocks {
            block_flags |= flags_end;
        }
        let block_len_vec = set1(BLOCK_LEN as u32); // full blocks only
        let block_flags_vec = set1(block_flags as u32);
        let msg_vecs = transpose_msg_vecs(inputs, block * BLOCK_LEN);

        // The transposed compression function. Note that inlining this
        // manually here improves compile times by a lot, compared to factoring
        // it out into its own function and making it #[inline(always)]. Just
        // guessing, it might have something to do with loop unrolling.
        let mut v = [
            h_vecs[0],
            h_vecs[1],
            h_vecs[2],
            h_vecs[3],
            h_vecs[4],
            h_vecs[5],
            h_vecs[6],
            h_vecs[7],
            set1(IV[0]),
            set1(IV[1]),
            set1(IV[2]),
            set1(IV[3]),
            counter_low_vec,
            counter_high_vec,
            block_len_vec,
            block_flags_vec,
        ];
        round(&mut v, &msg_vecs, 0);
        round(&mut v, &msg_vecs, 1);
        round(&mut v, &msg_vecs, 2);
        round(&mut v, &msg_vecs, 3);
        round(&mut v, &msg_vecs, 4);
        round(&mut v, &msg_vecs, 5);
        round(&mut v, &msg_vecs, 6);
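        // The output chaining value is the XOR of the two halves of the state.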
        h_vecs[0] = xor(v[0], v[8]);
        h_vecs[1] = xor(v[1], v[9]);
        h_vecs[2] = xor(v[2], v[10]);
        h_vecs[3] = xor(v[3], v[11]);
        h_vecs[4] = xor(v[4], v[12]);
        h_vecs[5] = xor(v[5], v[13]);
        h_vecs[6] = xor(v[6], v[14]);
        h_vecs[7] = xor(v[7], v[15]);

        block_flags = flags;
    }

    transpose_vecs(&mut h_vecs);
    storeu(h_vecs[0], out.as_mut_ptr().add(0 * 4 * DEGREE));
    storeu(h_vecs[1], out.as_mut_ptr().add(1 * 4 * DEGREE));
    storeu(h_vecs[2], out.as_mut_ptr().add(2 * 4 * DEGREE));
    storeu(h_vecs[3], out.as_mut_ptr().add(3 * 4 * DEGREE));
    storeu(h_vecs[4], out.as_mut_ptr().add(4 * 4 * DEGREE));
    storeu(h_vecs[5], out.as_mut_ptr().add(5 * 4 * DEGREE));
    storeu(h_vecs[6], out.as_mut_ptr().add(6 * 4 * DEGREE));
    storeu(h_vecs[7], out.as_mut_ptr().add(7 * 4 * DEGREE));
}

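// Hash full-block inputs eight at a time with hash8, then delegate any leftover inputs to the
// SSE4.1 implementation.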
#[target_feature(enable = "avx2")]
pub unsafe fn hash_many<const N: usize>(
    mut inputs: &[&[u8; N]],
    key: &CVWords,
    mut counter: u64,
    increment_counter: IncrementCounter,
    flags: u8,
    flags_start: u8,
    flags_end: u8,
    mut out: &mut [u8],
) {
    debug_assert!(out.len() >= inputs.len() * OUT_LEN, "out too short");
    while inputs.len() >= DEGREE && out.len() >= DEGREE * OUT_LEN {
        // Safe because the layout of arrays is guaranteed, and because the
        // `blocks` count is determined statically from the argument type.
        let input_ptrs: &[*const u8; DEGREE] = &*(inputs.as_ptr() as *const [*const u8; DEGREE]);
        let blocks = N / BLOCK_LEN;
        hash8(
            input_ptrs,
            blocks,
            key,
            counter,
            increment_counter,
            flags,
            flags_start,
            flags_end,
            array_mut_ref!(out, 0, DEGREE * OUT_LEN),
        );
        if increment_counter.yes() {
            counter += DEGREE as u64;
        }
        inputs = &inputs[DEGREE..];
        out = &mut out[DEGREE * OUT_LEN..];
    }
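    // Hash any remaining inputs (fewer than DEGREE) with the narrower SSE4.1 implementation.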
    crate::sse41::hash_many(
        inputs,
        key,
        counter,
        increment_counter,
        flags,
        flags_start,
        flags_end,
        out,
    );
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_transpose() {
        if !crate::platform::avx2_detected() {
            return;
        }

        #[target_feature(enable = "avx2")]
        unsafe fn transpose_wrapper(vecs: &mut [__m256i; DEGREE]) {
            transpose_vecs(vecs);
        }

        let mut matrix = [[0 as u32; DEGREE]; DEGREE];
        for i in 0..DEGREE {
            for j in 0..DEGREE {
                matrix[i][j] = (i * DEGREE + j) as u32;
            }
        }

        unsafe {
            let mut vecs: [__m256i; DEGREE] = core::mem::transmute(matrix);
            transpose_wrapper(&mut vecs);
            matrix = core::mem::transmute(vecs);
        }

        for i in 0..DEGREE {
            for j in 0..DEGREE {
                // Reversed indexes from above.
                assert_eq!(matrix[j][i], (i * DEGREE + j) as u32);
            }
        }
    }

    #[test]
    fn test_hash_many() {
        if !crate::platform::avx2_detected() {
            return;
        }
        crate::test::test_hash_many_fn(hash_many, hash_many);
    }
}