#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

use crate::{
    counter_high, counter_low, CVWords, IncrementCounter, BLOCK_LEN, IV, MSG_SCHEDULE, OUT_LEN,
};
use arrayref::{array_mut_ref, mut_array_refs};

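// The SIMD degree of this implementation: each 256-bit AVX2 vector holds eight 32-bit words, so
// eight inputs are hashed in parallel.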
pub const DEGREE: usize = 8;

#[inline(always)]
unsafe fn loadu(src: *const u8) -> __m256i {
    // This is an unaligned load, so the pointer cast is allowed.
    _mm256_loadu_si256(src as *const __m256i)
}

#[inline(always)]
unsafe fn storeu(src: __m256i, dest: *mut u8) {
    // This is an unaligned store, so the pointer cast is allowed.
    _mm256_storeu_si256(dest as *mut __m256i, src)
}

#[inline(always)]
unsafe fn add(a: __m256i, b: __m256i) -> __m256i {
    _mm256_add_epi32(a, b)
}

#[inline(always)]
unsafe fn xor(a: __m256i, b: __m256i) -> __m256i {
    _mm256_xor_si256(a, b)
}

#[inline(always)]
unsafe fn set1(x: u32) -> __m256i {
    _mm256_set1_epi32(x as i32)
}

#[inline(always)]
unsafe fn set8(a: u32, b: u32, c: u32, d: u32, e: u32, f: u32, g: u32, h: u32) -> __m256i {
    _mm256_setr_epi32(
        a as i32, b as i32, c as i32, d as i32, e as i32, f as i32, g as i32, h as i32,
    )
}

// These rotations are the "simple/shifts version". For the
// "complicated/shuffles version", see
// https://github.com/sneves/blake2-avx2/blob/b3723921f668df09ece52dcd225a36d4a4eea1d9/blake2s-common.h#L63-L66.
// For a discussion of the tradeoffs, see
// https://github.com/sneves/blake2-avx2/pull/5. Due to an LLVM bug
// (https://bugs.llvm.org/show_bug.cgi?id=44379), this version performs better
// on recent x86 chips.

#[inline(always)]
unsafe fn rot16(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 16), _mm256_slli_epi32(x, 32 - 16))
}

#[inline(always)]
unsafe fn rot12(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 12), _mm256_slli_epi32(x, 32 - 12))
}

#[inline(always)]
unsafe fn rot8(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 8), _mm256_slli_epi32(x, 32 - 8))
}

#[inline(always)]
unsafe fn rot7(x: __m256i) -> __m256i {
    _mm256_or_si256(_mm256_srli_epi32(x, 7), _mm256_slli_epi32(x, 32 - 7))
}

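// One round of the compression function, applied to eight transposed states at once: v[i] holds
// state word i from all eight states. The first half of the round mixes the columns and the
// second half mixes the diagonals, using the message schedule for round `r`.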
#[inline(always)]
unsafe fn round(v: &mut [__m256i; 16], m: &[__m256i; 16], r: usize) {
    v[0] = add(v[0], m[MSG_SCHEDULE[r][0] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][2] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][4] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][6] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[15] = rot16(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot12(v[4]);
    v[5] = rot12(v[5]);
    v[6] = rot12(v[6]);
    v[7] = rot12(v[7]);
    v[0] = add(v[0], m[MSG_SCHEDULE[r][1] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][3] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][5] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][7] as usize]);
    v[0] = add(v[0], v[4]);
    v[1] = add(v[1], v[5]);
    v[2] = add(v[2], v[6]);
    v[3] = add(v[3], v[7]);
    v[12] = xor(v[12], v[0]);
    v[13] = xor(v[13], v[1]);
    v[14] = xor(v[14], v[2]);
    v[15] = xor(v[15], v[3]);
    v[12] = rot8(v[12]);
    v[13] = rot8(v[13]);
    v[14] = rot8(v[14]);
    v[15] = rot8(v[15]);
    v[8] = add(v[8], v[12]);
    v[9] = add(v[9], v[13]);
    v[10] = add(v[10], v[14]);
    v[11] = add(v[11], v[15]);
    v[4] = xor(v[4], v[8]);
    v[5] = xor(v[5], v[9]);
    v[6] = xor(v[6], v[10]);
    v[7] = xor(v[7], v[11]);
    v[4] = rot7(v[4]);
    v[5] = rot7(v[5]);
    v[6] = rot7(v[6]);
    v[7] = rot7(v[7]);

    v[0] = add(v[0], m[MSG_SCHEDULE[r][8] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][10] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][12] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][14] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot16(v[15]);
    v[12] = rot16(v[12]);
    v[13] = rot16(v[13]);
    v[14] = rot16(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot12(v[5]);
    v[6] = rot12(v[6]);
    v[7] = rot12(v[7]);
    v[4] = rot12(v[4]);
    v[0] = add(v[0], m[MSG_SCHEDULE[r][9] as usize]);
    v[1] = add(v[1], m[MSG_SCHEDULE[r][11] as usize]);
    v[2] = add(v[2], m[MSG_SCHEDULE[r][13] as usize]);
    v[3] = add(v[3], m[MSG_SCHEDULE[r][15] as usize]);
    v[0] = add(v[0], v[5]);
    v[1] = add(v[1], v[6]);
    v[2] = add(v[2], v[7]);
    v[3] = add(v[3], v[4]);
    v[15] = xor(v[15], v[0]);
    v[12] = xor(v[12], v[1]);
    v[13] = xor(v[13], v[2]);
    v[14] = xor(v[14], v[3]);
    v[15] = rot8(v[15]);
    v[12] = rot8(v[12]);
    v[13] = rot8(v[13]);
    v[14] = rot8(v[14]);
    v[10] = add(v[10], v[15]);
    v[11] = add(v[11], v[12]);
    v[8] = add(v[8], v[13]);
    v[9] = add(v[9], v[14]);
    v[5] = xor(v[5], v[10]);
    v[6] = xor(v[6], v[11]);
    v[7] = xor(v[7], v[8]);
    v[4] = xor(v[4], v[9]);
    v[5] = rot7(v[5]);
    v[6] = rot7(v[6]);
    v[7] = rot7(v[7]);
    v[4] = rot7(v[4]);
}

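// Interleave the 128-bit halves of `a` and `b`: the first result holds both low halves, and the
// second holds both high halves.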
#[inline(always)]
unsafe fn interleave128(a: __m256i, b: __m256i) -> (__m256i, __m256i) {
    (
        _mm256_permute2x128_si256(a, b, 0x20),
        _mm256_permute2x128_si256(a, b, 0x31),
    )
}

// There are several ways to do a transposition. We could do it naively, with 8 separate
// _mm256_set_epi32 instructions, referencing each of the 32 words explicitly. Or we could copy
// the vecs into contiguous storage and then use gather instructions. The third approach, used
// here, is a series of unpack instructions that interleave the vectors. In my benchmarks,
// interleaving is the fastest approach. To test this, run `cargo +nightly bench --bench libtest load_8`
// in the https://github.com/oconnor663/bao_experiments repo.
#[inline(always)]
unsafe fn transpose_vecs(vecs: &mut [__m256i; DEGREE]) {
    // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high is 22/33/66/77.
    let ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
    let ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
    let cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
    let cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
    let ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
    let ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
    let gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
    let gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);

    // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is 11/33.
    let abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
    let abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
    let abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
    let abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
    let efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
    let efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
    let efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
    let efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);

    // Interleave 128-bit lanes.
    let (abcdefgh_0, abcdefgh_4) = interleave128(abcd_04, efgh_04);
    let (abcdefgh_1, abcdefgh_5) = interleave128(abcd_15, efgh_15);
    let (abcdefgh_2, abcdefgh_6) = interleave128(abcd_26, efgh_26);
    let (abcdefgh_3, abcdefgh_7) = interleave128(abcd_37, efgh_37);

    vecs[0] = abcdefgh_0;
    vecs[1] = abcdefgh_1;
    vecs[2] = abcdefgh_2;
    vecs[3] = abcdefgh_3;
    vecs[4] = abcdefgh_4;
    vecs[5] = abcdefgh_5;
    vecs[6] = abcdefgh_6;
    vecs[7] = abcdefgh_7;
}

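// Load one message block from each of the eight inputs and transpose the words into sixteen
// vectors, so that vector i holds message word i from all eight inputs. Also prefetch upcoming
// input data while we're at it.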
#[inline(always)]
unsafe fn transpose_msg_vecs(inputs: &[*const u8; DEGREE], block_offset: usize) -> [__m256i; 16] {
    let mut vecs = [
        loadu(inputs[0].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[1].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[2].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[3].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[4].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[5].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[6].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[7].add(block_offset + 0 * 4 * DEGREE)),
        loadu(inputs[0].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[1].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[2].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[3].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[4].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[5].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[6].add(block_offset + 1 * 4 * DEGREE)),
        loadu(inputs[7].add(block_offset + 1 * 4 * DEGREE)),
    ];
    for i in 0..DEGREE {
        _mm_prefetch(inputs[i].add(block_offset + 256) as *const i8, _MM_HINT_T0);
    }
    let squares = mut_array_refs!(&mut vecs, DEGREE, DEGREE);
    transpose_vecs(squares.0);
    transpose_vecs(squares.1);
    vecs
}

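// Build the per-lane low and high counter words. When `increment_counter` is yes, lane i gets
// counter + i; otherwise the mask zeroes the offsets and every lane gets the same counter.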
#[inline(always)]
unsafe fn load_counters(counter: u64, increment_counter: IncrementCounter) -> (__m256i, __m256i) {
    let mask = if increment_counter.yes() { !0 } else { 0 };
    (
        set8(
            counter_low(counter + (mask & 0)),
            counter_low(counter + (mask & 1)),
            counter_low(counter + (mask & 2)),
            counter_low(counter + (mask & 3)),
            counter_low(counter + (mask & 4)),
            counter_low(counter + (mask & 5)),
            counter_low(counter + (mask & 6)),
            counter_low(counter + (mask & 7)),
        ),
        set8(
            counter_high(counter + (mask & 0)),
            counter_high(counter + (mask & 1)),
            counter_high(counter + (mask & 2)),
            counter_high(counter + (mask & 3)),
            counter_high(counter + (mask & 4)),
            counter_high(counter + (mask & 5)),
            counter_high(counter + (mask & 6)),
            counter_high(counter + (mask & 7)),
        ),
    )
}

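// Hash `blocks` full blocks from each of eight inputs in parallel, keeping the chaining values in
// transposed form throughout, and write the eight 32-byte outputs contiguously to `out`.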
#[target_feature(enable = "avx2")]
pub unsafe fn hash8(
    inputs: &[*const u8; DEGREE],
    blocks: usize,
    key: &CVWords,
    counter: u64,
    increment_counter: IncrementCounter,
    flags: u8,
    flags_start: u8,
    flags_end: u8,
    out: &mut [u8; DEGREE * OUT_LEN],
) {
    let mut h_vecs = [
        set1(key[0]),
        set1(key[1]),
        set1(key[2]),
        set1(key[3]),
        set1(key[4]),
        set1(key[5]),
        set1(key[6]),
        set1(key[7]),
    ];
    let (counter_low_vec, counter_high_vec) = load_counters(counter, increment_counter);
    let mut block_flags = flags | flags_start;

    for block in 0..blocks {
        if block + 1 == blocks {
            block_flags |= flags_end;
        }
        let block_len_vec = set1(BLOCK_LEN as u32); // full blocks only
        let block_flags_vec = set1(block_flags as u32);
        let msg_vecs = transpose_msg_vecs(inputs, block * BLOCK_LEN);

        // The transposed compression function. Note that inlining this
        // manually here improves compile times by a lot, compared to factoring
        // it out into its own function and making it #[inline(always)]. Just
        // guessing, it might have something to do with loop unrolling.
        let mut v = [
            h_vecs[0],
            h_vecs[1],
            h_vecs[2],
            h_vecs[3],
            h_vecs[4],
            h_vecs[5],
            h_vecs[6],
            h_vecs[7],
            set1(IV[0]),
            set1(IV[1]),
            set1(IV[2]),
            set1(IV[3]),
            counter_low_vec,
            counter_high_vec,
            block_len_vec,
            block_flags_vec,
        ];
        round(&mut v, &msg_vecs, 0);
        round(&mut v, &msg_vecs, 1);
        round(&mut v, &msg_vecs, 2);
        round(&mut v, &msg_vecs, 3);
        round(&mut v, &msg_vecs, 4);
        round(&mut v, &msg_vecs, 5);
        round(&mut v, &msg_vecs, 6);
        h_vecs[0] = xor(v[0], v[8]);
        h_vecs[1] = xor(v[1], v[9]);
        h_vecs[2] = xor(v[2], v[10]);
        h_vecs[3] = xor(v[3], v[11]);
        h_vecs[4] = xor(v[4], v[12]);
        h_vecs[5] = xor(v[5], v[13]);
        h_vecs[6] = xor(v[6], v[14]);
        h_vecs[7] = xor(v[7], v[15]);

        block_flags = flags;
    }

    transpose_vecs(&mut h_vecs);
    storeu(h_vecs[0], out.as_mut_ptr().add(0 * 4 * DEGREE));
    storeu(h_vecs[1], out.as_mut_ptr().add(1 * 4 * DEGREE));
    storeu(h_vecs[2], out.as_mut_ptr().add(2 * 4 * DEGREE));
    storeu(h_vecs[3], out.as_mut_ptr().add(3 * 4 * DEGREE));
    storeu(h_vecs[4], out.as_mut_ptr().add(4 * 4 * DEGREE));
    storeu(h_vecs[5], out.as_mut_ptr().add(5 * 4 * DEGREE));
    storeu(h_vecs[6], out.as_mut_ptr().add(6 * 4 * DEGREE));
    storeu(h_vecs[7], out.as_mut_ptr().add(7 * 4 * DEGREE));
}

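// Hash any number of whole inputs: process them eight at a time with hash8, then hand the
// remainder off to the SSE4.1 implementation.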
#[target_feature(enable = "avx2")]
pub unsafe fn hash_many<const N: usize>(
    mut inputs: &[&[u8; N]],
    key: &CVWords,
    mut counter: u64,
    increment_counter: IncrementCounter,
    flags: u8,
    flags_start: u8,
    flags_end: u8,
    mut out: &mut [u8],
) {
    debug_assert!(out.len() >= inputs.len() * OUT_LEN, "out too short");
    while inputs.len() >= DEGREE && out.len() >= DEGREE * OUT_LEN {
        // Safe because the layout of arrays is guaranteed, and because the
        // `blocks` count is determined statically from the argument type.
        let input_ptrs: &[*const u8; DEGREE] = &*(inputs.as_ptr() as *const [*const u8; DEGREE]);
        let blocks = N / BLOCK_LEN;
        hash8(
            input_ptrs,
            blocks,
            key,
            counter,
            increment_counter,
            flags,
            flags_start,
            flags_end,
            array_mut_ref!(out, 0, DEGREE * OUT_LEN),
        );
        if increment_counter.yes() {
            counter += DEGREE as u64;
        }
        inputs = &inputs[DEGREE..];
        out = &mut out[DEGREE * OUT_LEN..];
    }
    crate::sse41::hash_many(
        inputs,
        key,
        counter,
        increment_counter,
        flags,
        flags_start,
        flags_end,
        out,
    );
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_transpose() {
        if !crate::platform::avx2_detected() {
            return;
        }

        #[target_feature(enable = "avx2")]
        unsafe fn transpose_wrapper(vecs: &mut [__m256i; DEGREE]) {
            transpose_vecs(vecs);
        }

        let mut matrix = [[0 as u32; DEGREE]; DEGREE];
        for i in 0..DEGREE {
            for j in 0..DEGREE {
                matrix[i][j] = (i * DEGREE + j) as u32;
            }
        }

        unsafe {
            let mut vecs: [__m256i; DEGREE] = core::mem::transmute(matrix);
            transpose_wrapper(&mut vecs);
            matrix = core::mem::transmute(vecs);
        }

        for i in 0..DEGREE {
            for j in 0..DEGREE {
                // Reversed indexes from above.
                assert_eq!(matrix[j][i], (i * DEGREE + j) as u32);
            }
        }
    }

    #[test]
    fn test_hash_many() {
        if !crate::platform::avx2_detected() {
            return;
        }
        crate::test::test_hash_many_fn(hash_many, hash_many);
    }
}