1 #include "blake3_impl.h"
2 
3 #include <immintrin.h>
4 
5 #define _mm_shuffle_ps2(a, b, c)                                               \
6   (_mm_castps_si128(                                                           \
7       _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))
8 
loadu_128(const uint8_t src[16])9 INLINE __m128i loadu_128(const uint8_t src[16]) {
10   return _mm_loadu_si128((const __m128i *)src);
11 }
12 
loadu_256(const uint8_t src[32])13 INLINE __m256i loadu_256(const uint8_t src[32]) {
14   return _mm256_loadu_si256((const __m256i *)src);
15 }
16 
loadu_512(const uint8_t src[64])17 INLINE __m512i loadu_512(const uint8_t src[64]) {
18   return _mm512_loadu_si512((const __m512i *)src);
19 }
20 
storeu_128(__m128i src,uint8_t dest[16])21 INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
22   _mm_storeu_si128((__m128i *)dest, src);
23 }
24 
storeu_256(__m256i src,uint8_t dest[16])25 INLINE void storeu_256(__m256i src, uint8_t dest[16]) {
26   _mm256_storeu_si256((__m256i *)dest, src);
27 }
28 
add_128(__m128i a,__m128i b)29 INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
30 
add_256(__m256i a,__m256i b)31 INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
32 
add_512(__m512i a,__m512i b)33 INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
34 
xor_128(__m128i a,__m128i b)35 INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
36 
xor_256(__m256i a,__m256i b)37 INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
38 
xor_512(__m512i a,__m512i b)39 INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
40 
set1_128(uint32_t x)41 INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
42 
set1_256(uint32_t x)43 INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
44 
set1_512(uint32_t x)45 INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
46 
set4(uint32_t a,uint32_t b,uint32_t c,uint32_t d)47 INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
48   return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
49 }
50 
rot16_128(__m128i x)51 INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
52 
rot16_256(__m256i x)53 INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
54 
rot16_512(__m512i x)55 INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
56 
rot12_128(__m128i x)57 INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
58 
rot12_256(__m256i x)59 INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
60 
rot12_512(__m512i x)61 INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
62 
rot8_128(__m128i x)63 INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
64 
rot8_256(__m256i x)65 INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
66 
rot8_512(__m512i x)67 INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
68 
rot7_128(__m128i x)69 INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
70 
rot7_256(__m256i x)71 INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
72 
rot7_512(__m512i x)73 INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
74 
75 /*
76  * ----------------------------------------------------------------------------
77  * compress_avx512
78  * ----------------------------------------------------------------------------
79  */
80 
g1(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)81 INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
82                __m128i m) {
83   *row0 = add_128(add_128(*row0, m), *row1);
84   *row3 = xor_128(*row3, *row0);
85   *row3 = rot16_128(*row3);
86   *row2 = add_128(*row2, *row3);
87   *row1 = xor_128(*row1, *row2);
88   *row1 = rot12_128(*row1);
89 }
90 
g2(__m128i * row0,__m128i * row1,__m128i * row2,__m128i * row3,__m128i m)91 INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
92                __m128i m) {
93   *row0 = add_128(add_128(*row0, m), *row1);
94   *row3 = xor_128(*row3, *row0);
95   *row3 = rot8_128(*row3);
96   *row2 = add_128(*row2, *row3);
97   *row1 = xor_128(*row1, *row2);
98   *row1 = rot7_128(*row1);
99 }
100 
101 // Note the optimization here of leaving row1 as the unrotated row, rather than
102 // row0. All the message loads below are adjusted to compensate for this. See
103 // discussion at https://github.com/sneves/blake2-avx2/pull/4
diagonalize(__m128i * row0,__m128i * row2,__m128i * row3)104 INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
105   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
106   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
107   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
108 }
109 
undiagonalize(__m128i * row0,__m128i * row2,__m128i * row3)110 INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
111   *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
112   *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
113   *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
114 }
115 
compress_pre(__m128i rows[4],const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)116 INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
117                          const uint8_t block[BLAKE3_BLOCK_LEN],
118                          uint8_t block_len, uint64_t counter, uint8_t flags) {
119   rows[0] = loadu_128((uint8_t *)&cv[0]);
120   rows[1] = loadu_128((uint8_t *)&cv[4]);
121   rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
122   rows[3] = set4(counter_low(counter), counter_high(counter),
123                  (uint32_t)block_len, (uint32_t)flags);
124 
125   __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
126   __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
127   __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
128   __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
129 
130   __m128i t0, t1, t2, t3, tt;
131 
132   // Round 1. The first round permutes the message words from the original
133   // input order, into the groups that get mixed in parallel.
134   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
135   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
136   t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
137   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
138   diagonalize(&rows[0], &rows[2], &rows[3]);
139   t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
140   t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
141   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
142   t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
143   t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
144   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
145   undiagonalize(&rows[0], &rows[2], &rows[3]);
146   m0 = t0;
147   m1 = t1;
148   m2 = t2;
149   m3 = t3;
150 
151   // Round 2. This round and all following rounds apply a fixed permutation
152   // to the message words from the round before.
153   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
154   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
155   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
156   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
157   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
158   t1 = _mm_blend_epi16(tt, t1, 0xCC);
159   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
160   diagonalize(&rows[0], &rows[2], &rows[3]);
161   t2 = _mm_unpacklo_epi64(m3, m1);
162   tt = _mm_blend_epi16(t2, m2, 0xC0);
163   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
164   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
165   t3 = _mm_unpackhi_epi32(m1, m3);
166   tt = _mm_unpacklo_epi32(m2, t3);
167   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
168   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
169   undiagonalize(&rows[0], &rows[2], &rows[3]);
170   m0 = t0;
171   m1 = t1;
172   m2 = t2;
173   m3 = t3;
174 
175   // Round 3
176   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
177   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
178   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
179   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
180   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
181   t1 = _mm_blend_epi16(tt, t1, 0xCC);
182   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
183   diagonalize(&rows[0], &rows[2], &rows[3]);
184   t2 = _mm_unpacklo_epi64(m3, m1);
185   tt = _mm_blend_epi16(t2, m2, 0xC0);
186   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
187   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
188   t3 = _mm_unpackhi_epi32(m1, m3);
189   tt = _mm_unpacklo_epi32(m2, t3);
190   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
191   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
192   undiagonalize(&rows[0], &rows[2], &rows[3]);
193   m0 = t0;
194   m1 = t1;
195   m2 = t2;
196   m3 = t3;
197 
198   // Round 4
199   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
200   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
201   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
202   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
203   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
204   t1 = _mm_blend_epi16(tt, t1, 0xCC);
205   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
206   diagonalize(&rows[0], &rows[2], &rows[3]);
207   t2 = _mm_unpacklo_epi64(m3, m1);
208   tt = _mm_blend_epi16(t2, m2, 0xC0);
209   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
210   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
211   t3 = _mm_unpackhi_epi32(m1, m3);
212   tt = _mm_unpacklo_epi32(m2, t3);
213   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
214   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
215   undiagonalize(&rows[0], &rows[2], &rows[3]);
216   m0 = t0;
217   m1 = t1;
218   m2 = t2;
219   m3 = t3;
220 
221   // Round 5
222   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
223   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
224   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
225   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
226   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
227   t1 = _mm_blend_epi16(tt, t1, 0xCC);
228   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
229   diagonalize(&rows[0], &rows[2], &rows[3]);
230   t2 = _mm_unpacklo_epi64(m3, m1);
231   tt = _mm_blend_epi16(t2, m2, 0xC0);
232   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
233   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
234   t3 = _mm_unpackhi_epi32(m1, m3);
235   tt = _mm_unpacklo_epi32(m2, t3);
236   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
237   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
238   undiagonalize(&rows[0], &rows[2], &rows[3]);
239   m0 = t0;
240   m1 = t1;
241   m2 = t2;
242   m3 = t3;
243 
244   // Round 6
245   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
246   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
247   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
248   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
249   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
250   t1 = _mm_blend_epi16(tt, t1, 0xCC);
251   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
252   diagonalize(&rows[0], &rows[2], &rows[3]);
253   t2 = _mm_unpacklo_epi64(m3, m1);
254   tt = _mm_blend_epi16(t2, m2, 0xC0);
255   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
256   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
257   t3 = _mm_unpackhi_epi32(m1, m3);
258   tt = _mm_unpacklo_epi32(m2, t3);
259   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
260   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
261   undiagonalize(&rows[0], &rows[2], &rows[3]);
262   m0 = t0;
263   m1 = t1;
264   m2 = t2;
265   m3 = t3;
266 
267   // Round 7
268   t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
269   t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
270   g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
271   t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
272   tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
273   t1 = _mm_blend_epi16(tt, t1, 0xCC);
274   g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
275   diagonalize(&rows[0], &rows[2], &rows[3]);
276   t2 = _mm_unpacklo_epi64(m3, m1);
277   tt = _mm_blend_epi16(t2, m2, 0xC0);
278   t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
279   g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
280   t3 = _mm_unpackhi_epi32(m1, m3);
281   tt = _mm_unpacklo_epi32(m2, t3);
282   t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
283   g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
284   undiagonalize(&rows[0], &rows[2], &rows[3]);
285 }
286 
blake3_compress_xof_avx512(const uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags,uint8_t out[64])287 void blake3_compress_xof_avx512(const uint32_t cv[8],
288                                 const uint8_t block[BLAKE3_BLOCK_LEN],
289                                 uint8_t block_len, uint64_t counter,
290                                 uint8_t flags, uint8_t out[64]) {
291   __m128i rows[4];
292   compress_pre(rows, cv, block, block_len, counter, flags);
293   storeu_128(xor_128(rows[0], rows[2]), &out[0]);
294   storeu_128(xor_128(rows[1], rows[3]), &out[16]);
295   storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
296   storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
297 }
298 
blake3_compress_in_place_avx512(uint32_t cv[8],const uint8_t block[BLAKE3_BLOCK_LEN],uint8_t block_len,uint64_t counter,uint8_t flags)299 void blake3_compress_in_place_avx512(uint32_t cv[8],
300                                      const uint8_t block[BLAKE3_BLOCK_LEN],
301                                      uint8_t block_len, uint64_t counter,
302                                      uint8_t flags) {
303   __m128i rows[4];
304   compress_pre(rows, cv, block, block_len, counter, flags);
305   storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
306   storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
307 }
308 
309 /*
310  * ----------------------------------------------------------------------------
311  * hash4_avx512
312  * ----------------------------------------------------------------------------
313  */
314 
round_fn4(__m128i v[16],__m128i m[16],size_t r)315 INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
316   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
317   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
318   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
319   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
320   v[0] = add_128(v[0], v[4]);
321   v[1] = add_128(v[1], v[5]);
322   v[2] = add_128(v[2], v[6]);
323   v[3] = add_128(v[3], v[7]);
324   v[12] = xor_128(v[12], v[0]);
325   v[13] = xor_128(v[13], v[1]);
326   v[14] = xor_128(v[14], v[2]);
327   v[15] = xor_128(v[15], v[3]);
328   v[12] = rot16_128(v[12]);
329   v[13] = rot16_128(v[13]);
330   v[14] = rot16_128(v[14]);
331   v[15] = rot16_128(v[15]);
332   v[8] = add_128(v[8], v[12]);
333   v[9] = add_128(v[9], v[13]);
334   v[10] = add_128(v[10], v[14]);
335   v[11] = add_128(v[11], v[15]);
336   v[4] = xor_128(v[4], v[8]);
337   v[5] = xor_128(v[5], v[9]);
338   v[6] = xor_128(v[6], v[10]);
339   v[7] = xor_128(v[7], v[11]);
340   v[4] = rot12_128(v[4]);
341   v[5] = rot12_128(v[5]);
342   v[6] = rot12_128(v[6]);
343   v[7] = rot12_128(v[7]);
344   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
345   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
346   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
347   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
348   v[0] = add_128(v[0], v[4]);
349   v[1] = add_128(v[1], v[5]);
350   v[2] = add_128(v[2], v[6]);
351   v[3] = add_128(v[3], v[7]);
352   v[12] = xor_128(v[12], v[0]);
353   v[13] = xor_128(v[13], v[1]);
354   v[14] = xor_128(v[14], v[2]);
355   v[15] = xor_128(v[15], v[3]);
356   v[12] = rot8_128(v[12]);
357   v[13] = rot8_128(v[13]);
358   v[14] = rot8_128(v[14]);
359   v[15] = rot8_128(v[15]);
360   v[8] = add_128(v[8], v[12]);
361   v[9] = add_128(v[9], v[13]);
362   v[10] = add_128(v[10], v[14]);
363   v[11] = add_128(v[11], v[15]);
364   v[4] = xor_128(v[4], v[8]);
365   v[5] = xor_128(v[5], v[9]);
366   v[6] = xor_128(v[6], v[10]);
367   v[7] = xor_128(v[7], v[11]);
368   v[4] = rot7_128(v[4]);
369   v[5] = rot7_128(v[5]);
370   v[6] = rot7_128(v[6]);
371   v[7] = rot7_128(v[7]);
372 
373   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
374   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
375   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
376   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
377   v[0] = add_128(v[0], v[5]);
378   v[1] = add_128(v[1], v[6]);
379   v[2] = add_128(v[2], v[7]);
380   v[3] = add_128(v[3], v[4]);
381   v[15] = xor_128(v[15], v[0]);
382   v[12] = xor_128(v[12], v[1]);
383   v[13] = xor_128(v[13], v[2]);
384   v[14] = xor_128(v[14], v[3]);
385   v[15] = rot16_128(v[15]);
386   v[12] = rot16_128(v[12]);
387   v[13] = rot16_128(v[13]);
388   v[14] = rot16_128(v[14]);
389   v[10] = add_128(v[10], v[15]);
390   v[11] = add_128(v[11], v[12]);
391   v[8] = add_128(v[8], v[13]);
392   v[9] = add_128(v[9], v[14]);
393   v[5] = xor_128(v[5], v[10]);
394   v[6] = xor_128(v[6], v[11]);
395   v[7] = xor_128(v[7], v[8]);
396   v[4] = xor_128(v[4], v[9]);
397   v[5] = rot12_128(v[5]);
398   v[6] = rot12_128(v[6]);
399   v[7] = rot12_128(v[7]);
400   v[4] = rot12_128(v[4]);
401   v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
402   v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
403   v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
404   v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
405   v[0] = add_128(v[0], v[5]);
406   v[1] = add_128(v[1], v[6]);
407   v[2] = add_128(v[2], v[7]);
408   v[3] = add_128(v[3], v[4]);
409   v[15] = xor_128(v[15], v[0]);
410   v[12] = xor_128(v[12], v[1]);
411   v[13] = xor_128(v[13], v[2]);
412   v[14] = xor_128(v[14], v[3]);
413   v[15] = rot8_128(v[15]);
414   v[12] = rot8_128(v[12]);
415   v[13] = rot8_128(v[13]);
416   v[14] = rot8_128(v[14]);
417   v[10] = add_128(v[10], v[15]);
418   v[11] = add_128(v[11], v[12]);
419   v[8] = add_128(v[8], v[13]);
420   v[9] = add_128(v[9], v[14]);
421   v[5] = xor_128(v[5], v[10]);
422   v[6] = xor_128(v[6], v[11]);
423   v[7] = xor_128(v[7], v[8]);
424   v[4] = xor_128(v[4], v[9]);
425   v[5] = rot7_128(v[5]);
426   v[6] = rot7_128(v[6]);
427   v[7] = rot7_128(v[7]);
428   v[4] = rot7_128(v[4]);
429 }
430 
transpose_vecs_128(__m128i vecs[4])431 INLINE void transpose_vecs_128(__m128i vecs[4]) {
432   // Interleave 32-bit lates. The low unpack is lanes 00/11 and the high is
433   // 22/33. Note that this doesn't split the vector into two lanes, as the
434   // AVX2 counterparts do.
435   __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
436   __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
437   __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
438   __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
439 
440   // Interleave 64-bit lanes.
441   __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
442   __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
443   __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
444   __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
445 
446   vecs[0] = abcd_0;
447   vecs[1] = abcd_1;
448   vecs[2] = abcd_2;
449   vecs[3] = abcd_3;
450 }
451 
transpose_msg_vecs4(const uint8_t * const * inputs,size_t block_offset,__m128i out[16])452 INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
453                                 size_t block_offset, __m128i out[16]) {
454   out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
455   out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
456   out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
457   out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
458   out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
459   out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
460   out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
461   out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
462   out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
463   out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
464   out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
465   out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
466   out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
467   out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
468   out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
469   out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
470   for (size_t i = 0; i < 4; ++i) {
471     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
472   }
473   transpose_vecs_128(&out[0]);
474   transpose_vecs_128(&out[4]);
475   transpose_vecs_128(&out[8]);
476   transpose_vecs_128(&out[12]);
477 }
478 
load_counters4(uint64_t counter,bool increment_counter,__m128i * out_lo,__m128i * out_hi)479 INLINE void load_counters4(uint64_t counter, bool increment_counter,
480                            __m128i *out_lo, __m128i *out_hi) {
481   uint64_t mask = (increment_counter ? ~0 : 0);
482   __m256i mask_vec = _mm256_set1_epi64x(mask);
483   __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
484   deltas = _mm256_and_si256(mask_vec, deltas);
485   __m256i counters =
486       _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
487   *out_lo = _mm256_cvtepi64_epi32(counters);
488   *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
489 }
490 
blake3_hash4_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)491 void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
492                          const uint32_t key[8], uint64_t counter,
493                          bool increment_counter, uint8_t flags,
494                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
495   __m128i h_vecs[8] = {
496       set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
497       set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
498   };
499   __m128i counter_low_vec, counter_high_vec;
500   load_counters4(counter, increment_counter, &counter_low_vec,
501                  &counter_high_vec);
502   uint8_t block_flags = flags | flags_start;
503 
504   for (size_t block = 0; block < blocks; block++) {
505     if (block + 1 == blocks) {
506       block_flags |= flags_end;
507     }
508     __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
509     __m128i block_flags_vec = set1_128(block_flags);
510     __m128i msg_vecs[16];
511     transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
512 
513     __m128i v[16] = {
514         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
515         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
516         set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
517         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
518     };
519     round_fn4(v, msg_vecs, 0);
520     round_fn4(v, msg_vecs, 1);
521     round_fn4(v, msg_vecs, 2);
522     round_fn4(v, msg_vecs, 3);
523     round_fn4(v, msg_vecs, 4);
524     round_fn4(v, msg_vecs, 5);
525     round_fn4(v, msg_vecs, 6);
526     h_vecs[0] = xor_128(v[0], v[8]);
527     h_vecs[1] = xor_128(v[1], v[9]);
528     h_vecs[2] = xor_128(v[2], v[10]);
529     h_vecs[3] = xor_128(v[3], v[11]);
530     h_vecs[4] = xor_128(v[4], v[12]);
531     h_vecs[5] = xor_128(v[5], v[13]);
532     h_vecs[6] = xor_128(v[6], v[14]);
533     h_vecs[7] = xor_128(v[7], v[15]);
534 
535     block_flags = flags;
536   }
537 
538   transpose_vecs_128(&h_vecs[0]);
539   transpose_vecs_128(&h_vecs[4]);
540   // The first four vecs now contain the first half of each output, and the
541   // second four vecs contain the second half of each output.
542   storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
543   storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
544   storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
545   storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
546   storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
547   storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
548   storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
549   storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
550 }
551 
552 /*
553  * ----------------------------------------------------------------------------
554  * hash8_avx512
555  * ----------------------------------------------------------------------------
556  */
557 
round_fn8(__m256i v[16],__m256i m[16],size_t r)558 INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
559   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
560   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
561   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
562   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
563   v[0] = add_256(v[0], v[4]);
564   v[1] = add_256(v[1], v[5]);
565   v[2] = add_256(v[2], v[6]);
566   v[3] = add_256(v[3], v[7]);
567   v[12] = xor_256(v[12], v[0]);
568   v[13] = xor_256(v[13], v[1]);
569   v[14] = xor_256(v[14], v[2]);
570   v[15] = xor_256(v[15], v[3]);
571   v[12] = rot16_256(v[12]);
572   v[13] = rot16_256(v[13]);
573   v[14] = rot16_256(v[14]);
574   v[15] = rot16_256(v[15]);
575   v[8] = add_256(v[8], v[12]);
576   v[9] = add_256(v[9], v[13]);
577   v[10] = add_256(v[10], v[14]);
578   v[11] = add_256(v[11], v[15]);
579   v[4] = xor_256(v[4], v[8]);
580   v[5] = xor_256(v[5], v[9]);
581   v[6] = xor_256(v[6], v[10]);
582   v[7] = xor_256(v[7], v[11]);
583   v[4] = rot12_256(v[4]);
584   v[5] = rot12_256(v[5]);
585   v[6] = rot12_256(v[6]);
586   v[7] = rot12_256(v[7]);
587   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
588   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
589   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
590   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
591   v[0] = add_256(v[0], v[4]);
592   v[1] = add_256(v[1], v[5]);
593   v[2] = add_256(v[2], v[6]);
594   v[3] = add_256(v[3], v[7]);
595   v[12] = xor_256(v[12], v[0]);
596   v[13] = xor_256(v[13], v[1]);
597   v[14] = xor_256(v[14], v[2]);
598   v[15] = xor_256(v[15], v[3]);
599   v[12] = rot8_256(v[12]);
600   v[13] = rot8_256(v[13]);
601   v[14] = rot8_256(v[14]);
602   v[15] = rot8_256(v[15]);
603   v[8] = add_256(v[8], v[12]);
604   v[9] = add_256(v[9], v[13]);
605   v[10] = add_256(v[10], v[14]);
606   v[11] = add_256(v[11], v[15]);
607   v[4] = xor_256(v[4], v[8]);
608   v[5] = xor_256(v[5], v[9]);
609   v[6] = xor_256(v[6], v[10]);
610   v[7] = xor_256(v[7], v[11]);
611   v[4] = rot7_256(v[4]);
612   v[5] = rot7_256(v[5]);
613   v[6] = rot7_256(v[6]);
614   v[7] = rot7_256(v[7]);
615 
616   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
617   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
618   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
619   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
620   v[0] = add_256(v[0], v[5]);
621   v[1] = add_256(v[1], v[6]);
622   v[2] = add_256(v[2], v[7]);
623   v[3] = add_256(v[3], v[4]);
624   v[15] = xor_256(v[15], v[0]);
625   v[12] = xor_256(v[12], v[1]);
626   v[13] = xor_256(v[13], v[2]);
627   v[14] = xor_256(v[14], v[3]);
628   v[15] = rot16_256(v[15]);
629   v[12] = rot16_256(v[12]);
630   v[13] = rot16_256(v[13]);
631   v[14] = rot16_256(v[14]);
632   v[10] = add_256(v[10], v[15]);
633   v[11] = add_256(v[11], v[12]);
634   v[8] = add_256(v[8], v[13]);
635   v[9] = add_256(v[9], v[14]);
636   v[5] = xor_256(v[5], v[10]);
637   v[6] = xor_256(v[6], v[11]);
638   v[7] = xor_256(v[7], v[8]);
639   v[4] = xor_256(v[4], v[9]);
640   v[5] = rot12_256(v[5]);
641   v[6] = rot12_256(v[6]);
642   v[7] = rot12_256(v[7]);
643   v[4] = rot12_256(v[4]);
644   v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
645   v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
646   v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
647   v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
648   v[0] = add_256(v[0], v[5]);
649   v[1] = add_256(v[1], v[6]);
650   v[2] = add_256(v[2], v[7]);
651   v[3] = add_256(v[3], v[4]);
652   v[15] = xor_256(v[15], v[0]);
653   v[12] = xor_256(v[12], v[1]);
654   v[13] = xor_256(v[13], v[2]);
655   v[14] = xor_256(v[14], v[3]);
656   v[15] = rot8_256(v[15]);
657   v[12] = rot8_256(v[12]);
658   v[13] = rot8_256(v[13]);
659   v[14] = rot8_256(v[14]);
660   v[10] = add_256(v[10], v[15]);
661   v[11] = add_256(v[11], v[12]);
662   v[8] = add_256(v[8], v[13]);
663   v[9] = add_256(v[9], v[14]);
664   v[5] = xor_256(v[5], v[10]);
665   v[6] = xor_256(v[6], v[11]);
666   v[7] = xor_256(v[7], v[8]);
667   v[4] = xor_256(v[4], v[9]);
668   v[5] = rot7_256(v[5]);
669   v[6] = rot7_256(v[6]);
670   v[7] = rot7_256(v[7]);
671   v[4] = rot7_256(v[4]);
672 }
673 
transpose_vecs_256(__m256i vecs[8])674 INLINE void transpose_vecs_256(__m256i vecs[8]) {
675   // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
676   // is 22/33/66/77.
677   __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
678   __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
679   __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
680   __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
681   __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
682   __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
683   __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
684   __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
685 
686   // Interleave 64-bit lates. The low unpack is lanes 00/22 and the high is
687   // 11/33.
688   __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
689   __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
690   __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
691   __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
692   __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
693   __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
694   __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
695   __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
696 
697   // Interleave 128-bit lanes.
698   vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
699   vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
700   vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
701   vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
702   vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
703   vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
704   vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
705   vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
706 }
707 
transpose_msg_vecs8(const uint8_t * const * inputs,size_t block_offset,__m256i out[16])708 INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
709                                 size_t block_offset, __m256i out[16]) {
710   out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
711   out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
712   out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
713   out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
714   out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
715   out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
716   out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
717   out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
718   out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
719   out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
720   out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
721   out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
722   out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
723   out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
724   out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
725   out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
726   for (size_t i = 0; i < 8; ++i) {
727     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
728   }
729   transpose_vecs_256(&out[0]);
730   transpose_vecs_256(&out[8]);
731 }
732 
load_counters8(uint64_t counter,bool increment_counter,__m256i * out_lo,__m256i * out_hi)733 INLINE void load_counters8(uint64_t counter, bool increment_counter,
734                            __m256i *out_lo, __m256i *out_hi) {
735   uint64_t mask = (increment_counter ? ~0 : 0);
736   __m512i mask_vec = _mm512_set1_epi64(mask);
737   __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
738   deltas = _mm512_and_si512(mask_vec, deltas);
739   __m512i counters =
740       _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
741   *out_lo = _mm512_cvtepi64_epi32(counters);
742   *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
743 }
744 
blake3_hash8_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)745 void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
746                          const uint32_t key[8], uint64_t counter,
747                          bool increment_counter, uint8_t flags,
748                          uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
749   __m256i h_vecs[8] = {
750       set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
751       set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
752   };
753   __m256i counter_low_vec, counter_high_vec;
754   load_counters8(counter, increment_counter, &counter_low_vec,
755                  &counter_high_vec);
756   uint8_t block_flags = flags | flags_start;
757 
758   for (size_t block = 0; block < blocks; block++) {
759     if (block + 1 == blocks) {
760       block_flags |= flags_end;
761     }
762     __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
763     __m256i block_flags_vec = set1_256(block_flags);
764     __m256i msg_vecs[16];
765     transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
766 
767     __m256i v[16] = {
768         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
769         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
770         set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
771         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
772     };
773     round_fn8(v, msg_vecs, 0);
774     round_fn8(v, msg_vecs, 1);
775     round_fn8(v, msg_vecs, 2);
776     round_fn8(v, msg_vecs, 3);
777     round_fn8(v, msg_vecs, 4);
778     round_fn8(v, msg_vecs, 5);
779     round_fn8(v, msg_vecs, 6);
780     h_vecs[0] = xor_256(v[0], v[8]);
781     h_vecs[1] = xor_256(v[1], v[9]);
782     h_vecs[2] = xor_256(v[2], v[10]);
783     h_vecs[3] = xor_256(v[3], v[11]);
784     h_vecs[4] = xor_256(v[4], v[12]);
785     h_vecs[5] = xor_256(v[5], v[13]);
786     h_vecs[6] = xor_256(v[6], v[14]);
787     h_vecs[7] = xor_256(v[7], v[15]);
788 
789     block_flags = flags;
790   }
791 
792   transpose_vecs_256(h_vecs);
793   storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
794   storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
795   storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
796   storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
797   storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
798   storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
799   storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
800   storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
801 }
802 
803 /*
804  * ----------------------------------------------------------------------------
805  * hash16_avx512
806  * ----------------------------------------------------------------------------
807  */
808 
round_fn16(__m512i v[16],__m512i m[16],size_t r)809 INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
810   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
811   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
812   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
813   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
814   v[0] = add_512(v[0], v[4]);
815   v[1] = add_512(v[1], v[5]);
816   v[2] = add_512(v[2], v[6]);
817   v[3] = add_512(v[3], v[7]);
818   v[12] = xor_512(v[12], v[0]);
819   v[13] = xor_512(v[13], v[1]);
820   v[14] = xor_512(v[14], v[2]);
821   v[15] = xor_512(v[15], v[3]);
822   v[12] = rot16_512(v[12]);
823   v[13] = rot16_512(v[13]);
824   v[14] = rot16_512(v[14]);
825   v[15] = rot16_512(v[15]);
826   v[8] = add_512(v[8], v[12]);
827   v[9] = add_512(v[9], v[13]);
828   v[10] = add_512(v[10], v[14]);
829   v[11] = add_512(v[11], v[15]);
830   v[4] = xor_512(v[4], v[8]);
831   v[5] = xor_512(v[5], v[9]);
832   v[6] = xor_512(v[6], v[10]);
833   v[7] = xor_512(v[7], v[11]);
834   v[4] = rot12_512(v[4]);
835   v[5] = rot12_512(v[5]);
836   v[6] = rot12_512(v[6]);
837   v[7] = rot12_512(v[7]);
838   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
839   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
840   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
841   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
842   v[0] = add_512(v[0], v[4]);
843   v[1] = add_512(v[1], v[5]);
844   v[2] = add_512(v[2], v[6]);
845   v[3] = add_512(v[3], v[7]);
846   v[12] = xor_512(v[12], v[0]);
847   v[13] = xor_512(v[13], v[1]);
848   v[14] = xor_512(v[14], v[2]);
849   v[15] = xor_512(v[15], v[3]);
850   v[12] = rot8_512(v[12]);
851   v[13] = rot8_512(v[13]);
852   v[14] = rot8_512(v[14]);
853   v[15] = rot8_512(v[15]);
854   v[8] = add_512(v[8], v[12]);
855   v[9] = add_512(v[9], v[13]);
856   v[10] = add_512(v[10], v[14]);
857   v[11] = add_512(v[11], v[15]);
858   v[4] = xor_512(v[4], v[8]);
859   v[5] = xor_512(v[5], v[9]);
860   v[6] = xor_512(v[6], v[10]);
861   v[7] = xor_512(v[7], v[11]);
862   v[4] = rot7_512(v[4]);
863   v[5] = rot7_512(v[5]);
864   v[6] = rot7_512(v[6]);
865   v[7] = rot7_512(v[7]);
866 
867   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
868   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
869   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
870   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
871   v[0] = add_512(v[0], v[5]);
872   v[1] = add_512(v[1], v[6]);
873   v[2] = add_512(v[2], v[7]);
874   v[3] = add_512(v[3], v[4]);
875   v[15] = xor_512(v[15], v[0]);
876   v[12] = xor_512(v[12], v[1]);
877   v[13] = xor_512(v[13], v[2]);
878   v[14] = xor_512(v[14], v[3]);
879   v[15] = rot16_512(v[15]);
880   v[12] = rot16_512(v[12]);
881   v[13] = rot16_512(v[13]);
882   v[14] = rot16_512(v[14]);
883   v[10] = add_512(v[10], v[15]);
884   v[11] = add_512(v[11], v[12]);
885   v[8] = add_512(v[8], v[13]);
886   v[9] = add_512(v[9], v[14]);
887   v[5] = xor_512(v[5], v[10]);
888   v[6] = xor_512(v[6], v[11]);
889   v[7] = xor_512(v[7], v[8]);
890   v[4] = xor_512(v[4], v[9]);
891   v[5] = rot12_512(v[5]);
892   v[6] = rot12_512(v[6]);
893   v[7] = rot12_512(v[7]);
894   v[4] = rot12_512(v[4]);
895   v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
896   v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
897   v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
898   v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
899   v[0] = add_512(v[0], v[5]);
900   v[1] = add_512(v[1], v[6]);
901   v[2] = add_512(v[2], v[7]);
902   v[3] = add_512(v[3], v[4]);
903   v[15] = xor_512(v[15], v[0]);
904   v[12] = xor_512(v[12], v[1]);
905   v[13] = xor_512(v[13], v[2]);
906   v[14] = xor_512(v[14], v[3]);
907   v[15] = rot8_512(v[15]);
908   v[12] = rot8_512(v[12]);
909   v[13] = rot8_512(v[13]);
910   v[14] = rot8_512(v[14]);
911   v[10] = add_512(v[10], v[15]);
912   v[11] = add_512(v[11], v[12]);
913   v[8] = add_512(v[8], v[13]);
914   v[9] = add_512(v[9], v[14]);
915   v[5] = xor_512(v[5], v[10]);
916   v[6] = xor_512(v[6], v[11]);
917   v[7] = xor_512(v[7], v[8]);
918   v[4] = xor_512(v[4], v[9]);
919   v[5] = rot7_512(v[5]);
920   v[6] = rot7_512(v[6]);
921   v[7] = rot7_512(v[7]);
922   v[4] = rot7_512(v[4]);
923 }
924 
925 // 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
926 #define LO_IMM8 0x88
927 
unpack_lo_128(__m512i a,__m512i b)928 INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
929   return _mm512_shuffle_i32x4(a, b, LO_IMM8);
930 }
931 
932 // 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
933 #define HI_IMM8 0xdd
934 
unpack_hi_128(__m512i a,__m512i b)935 INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
936   return _mm512_shuffle_i32x4(a, b, HI_IMM8);
937 }
938 
transpose_vecs_512(__m512i vecs[16])939 INLINE void transpose_vecs_512(__m512i vecs[16]) {
940   // Interleave 32-bit lanes. The _0 unpack is lanes
941   // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
942   // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
943   __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
944   __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
945   __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
946   __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
947   __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
948   __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
949   __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
950   __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
951   __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
952   __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
953   __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
954   __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
955   __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
956   __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
957   __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
958   __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
959 
960   // Interleave 64-bit lates. The _0 unpack is lanes
961   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
962   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
963   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
964   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
965   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
966   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
967   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
968   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
969   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
970   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
971   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
972   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
973   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
974   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
975   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
976   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
977   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
978   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
979   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
980   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
981 
982   // Interleave 128-bit lanes. The _0 unpack is
983   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
984   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
985   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
986   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
987   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
988   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
989   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
990   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
991   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
992   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
993   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
994   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
995   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
996   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
997   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
998   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
999   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
1000   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1001 
1002   // Interleave 128-bit lanes again for the final outputs.
1003   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
1004   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
1005   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
1006   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
1007   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
1008   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
1009   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
1010   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
1011   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
1012   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
1013   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
1014   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
1015   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
1016   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
1017   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
1018   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1019 }
1020 
transpose_msg_vecs16(const uint8_t * const * inputs,size_t block_offset,__m512i out[16])1021 INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
1022                                  size_t block_offset, __m512i out[16]) {
1023   out[0] = loadu_512(&inputs[0][block_offset]);
1024   out[1] = loadu_512(&inputs[1][block_offset]);
1025   out[2] = loadu_512(&inputs[2][block_offset]);
1026   out[3] = loadu_512(&inputs[3][block_offset]);
1027   out[4] = loadu_512(&inputs[4][block_offset]);
1028   out[5] = loadu_512(&inputs[5][block_offset]);
1029   out[6] = loadu_512(&inputs[6][block_offset]);
1030   out[7] = loadu_512(&inputs[7][block_offset]);
1031   out[8] = loadu_512(&inputs[8][block_offset]);
1032   out[9] = loadu_512(&inputs[9][block_offset]);
1033   out[10] = loadu_512(&inputs[10][block_offset]);
1034   out[11] = loadu_512(&inputs[11][block_offset]);
1035   out[12] = loadu_512(&inputs[12][block_offset]);
1036   out[13] = loadu_512(&inputs[13][block_offset]);
1037   out[14] = loadu_512(&inputs[14][block_offset]);
1038   out[15] = loadu_512(&inputs[15][block_offset]);
1039   for (size_t i = 0; i < 16; ++i) {
1040     _mm_prefetch(&inputs[i][block_offset + 256], _MM_HINT_T0);
1041   }
1042   transpose_vecs_512(out);
1043 }
1044 
load_counters16(uint64_t counter,bool increment_counter,__m512i * out_lo,__m512i * out_hi)1045 INLINE void load_counters16(uint64_t counter, bool increment_counter,
1046                             __m512i *out_lo, __m512i *out_hi) {
1047   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
1048   const __m512i add0 = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1049   const __m512i add1 = _mm512_and_si512(mask, add0);
1050   __m512i l = _mm512_add_epi32(_mm512_set1_epi32(counter), add1);
1051   __mmask16 carry = _mm512_cmp_epu32_mask(l, add1, _MM_CMPINT_LT);
1052   __m512i h = _mm512_mask_add_epi32(_mm512_set1_epi32(counter >> 32), carry, _mm512_set1_epi32(counter >> 32), _mm512_set1_epi32(1));
1053   *out_lo = l;
1054   *out_hi = h;
1055 }
1056 
blake3_hash16_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1057 void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
1058                           const uint32_t key[8], uint64_t counter,
1059                           bool increment_counter, uint8_t flags,
1060                           uint8_t flags_start, uint8_t flags_end,
1061                           uint8_t *out) {
1062   __m512i h_vecs[8] = {
1063       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
1064       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
1065   };
1066   __m512i counter_low_vec, counter_high_vec;
1067   load_counters16(counter, increment_counter, &counter_low_vec,
1068                   &counter_high_vec);
1069   uint8_t block_flags = flags | flags_start;
1070 
1071   for (size_t block = 0; block < blocks; block++) {
1072     if (block + 1 == blocks) {
1073       block_flags |= flags_end;
1074     }
1075     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
1076     __m512i block_flags_vec = set1_512(block_flags);
1077     __m512i msg_vecs[16];
1078     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1079 
1080     __m512i v[16] = {
1081         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
1082         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
1083         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
1084         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
1085     };
1086     round_fn16(v, msg_vecs, 0);
1087     round_fn16(v, msg_vecs, 1);
1088     round_fn16(v, msg_vecs, 2);
1089     round_fn16(v, msg_vecs, 3);
1090     round_fn16(v, msg_vecs, 4);
1091     round_fn16(v, msg_vecs, 5);
1092     round_fn16(v, msg_vecs, 6);
1093     h_vecs[0] = xor_512(v[0], v[8]);
1094     h_vecs[1] = xor_512(v[1], v[9]);
1095     h_vecs[2] = xor_512(v[2], v[10]);
1096     h_vecs[3] = xor_512(v[3], v[11]);
1097     h_vecs[4] = xor_512(v[4], v[12]);
1098     h_vecs[5] = xor_512(v[5], v[13]);
1099     h_vecs[6] = xor_512(v[6], v[14]);
1100     h_vecs[7] = xor_512(v[7], v[15]);
1101 
1102     block_flags = flags;
1103   }
1104 
1105   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
1106   // state vectors. Pad the matrix with zeros. After transposition, store the
1107   // lower half of each vector.
1108   __m512i padded[16] = {
1109       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
1110       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
1111       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1112       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1113   };
1114   transpose_vecs_512(padded);
1115   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
1116   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
1117   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
1118   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
1119   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
1120   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
1121   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
1122   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
1123   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
1124   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
1125   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
1126   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
1127   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
1128   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
1129   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
1130   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1131 }
1132 
1133 /*
1134  * ----------------------------------------------------------------------------
1135  * hash_many_avx512
1136  * ----------------------------------------------------------------------------
1137  */
1138 
hash_one_avx512(const uint8_t * input,size_t blocks,const uint32_t key[8],uint64_t counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t out[BLAKE3_OUT_LEN])1139 INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
1140                             const uint32_t key[8], uint64_t counter,
1141                             uint8_t flags, uint8_t flags_start,
1142                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
1143   uint32_t cv[8];
1144   memcpy(cv, key, BLAKE3_KEY_LEN);
1145   uint8_t block_flags = flags | flags_start;
1146   while (blocks > 0) {
1147     if (blocks == 1) {
1148       block_flags |= flags_end;
1149     }
1150     blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
1151                                     block_flags);
1152     input = &input[BLAKE3_BLOCK_LEN];
1153     blocks -= 1;
1154     block_flags = flags;
1155   }
1156   memcpy(out, cv, BLAKE3_OUT_LEN);
1157 }
1158 
blake3_hash_many_avx512(const uint8_t * const * inputs,size_t num_inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1159 void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
1160                              size_t blocks, const uint32_t key[8],
1161                              uint64_t counter, bool increment_counter,
1162                              uint8_t flags, uint8_t flags_start,
1163                              uint8_t flags_end, uint8_t *out) {
1164   while (num_inputs >= 16) {
1165     blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
1166                          flags_start, flags_end, out);
1167     if (increment_counter) {
1168       counter += 16;
1169     }
1170     inputs += 16;
1171     num_inputs -= 16;
1172     out = &out[16 * BLAKE3_OUT_LEN];
1173   }
1174   while (num_inputs >= 8) {
1175     blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
1176                         flags_start, flags_end, out);
1177     if (increment_counter) {
1178       counter += 8;
1179     }
1180     inputs += 8;
1181     num_inputs -= 8;
1182     out = &out[8 * BLAKE3_OUT_LEN];
1183   }
1184   while (num_inputs >= 4) {
1185     blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
1186                         flags_start, flags_end, out);
1187     if (increment_counter) {
1188       counter += 4;
1189     }
1190     inputs += 4;
1191     num_inputs -= 4;
1192     out = &out[4 * BLAKE3_OUT_LEN];
1193   }
1194   while (num_inputs > 0) {
1195     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1196                     flags_end, out);
1197     if (increment_counter) {
1198       counter += 1;
1199     }
1200     inputs += 1;
1201     num_inputs -= 1;
1202     out = &out[BLAKE3_OUT_LEN];
1203   }
1204 }
1205