#include "blake3_impl.h"

#include <immintrin.h>

#define _mm_shuffle_ps2(a, b, c)                                               \
  (_mm_castps_si128(                                                           \
      _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))
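// _mm_shuffle_ps2 reinterprets its integer arguments as floats so that SHUFPS
// can select two 32-bit lanes from each of the two sources in one shuffle,
// which PSHUFD alone cannot do; the result is cast back to an integer vector.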

INLINE __m128i loadu_128(const uint8_t src[16]) {
  return _mm_loadu_si128((const __m128i *)src);
}

INLINE __m256i loadu_256(const uint8_t src[32]) {
  return _mm256_loadu_si256((const __m256i *)src);
}

INLINE __m512i loadu_512(const uint8_t src[64]) {
  return _mm512_loadu_si512((const __m512i *)src);
}

INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
  _mm_storeu_si128((__m128i *)dest, src);
}

INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
  _mm256_storeu_si256((__m256i *)dest, src);
}

INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }

INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }

INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }

INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }

INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }

INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }

INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }

INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }

INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }

INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
  return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
}

INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }

INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }

INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }

INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }

INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }

INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }

INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }

INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }

INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }

INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }

INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }

INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
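// The rotate helpers above compile to single VPRORD instructions. AVX-512VL
// provides that rotate at 128- and 256-bit vector widths as well as 512, so
// the 4-way and 8-way paths in this file get hardware rotates instead of the
// shift-and-shuffle emulation the SSE4.1 and AVX2 implementations need.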

/*
 * ----------------------------------------------------------------------------
 * compress_avx512
 * ----------------------------------------------------------------------------
 */

INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
               __m128i m) {
  *row0 = add_128(add_128(*row0, m), *row1);
  *row3 = xor_128(*row3, *row0);
  *row3 = rot16_128(*row3);
  *row2 = add_128(*row2, *row3);
  *row1 = xor_128(*row1, *row2);
  *row1 = rot12_128(*row1);
}

INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
               __m128i m) {
  *row0 = add_128(add_128(*row0, m), *row1);
  *row3 = xor_128(*row3, *row0);
  *row3 = rot8_128(*row3);
  *row2 = add_128(*row2, *row3);
  *row1 = xor_128(*row1, *row2);
  *row1 = rot7_128(*row1);
}
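// g1 and g2 above are the two halves of the BLAKE3 G function, applied to all
// four columns (or, between diagonalize and undiagonalize, all four
// diagonals) of the 4x4 state at once. g1 performs the 16/12 rotations and g2
// the 8/7 rotations.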

// Note the optimization here of leaving row1 as the unrotated row, rather than
// row0. All the message loads below are adjusted to compensate for this. See
// discussion at https://github.com/sneves/blake2-avx2/pull/4
INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
  *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
  *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
  *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
}

INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
  *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
  *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
  *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
}

INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
                         const uint8_t block[BLAKE3_BLOCK_LEN],
                         uint8_t block_len, uint64_t counter, uint8_t flags) {
  rows[0] = loadu_128((uint8_t *)&cv[0]);
  rows[1] = loadu_128((uint8_t *)&cv[4]);
  rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
  rows[3] = set4(counter_low(counter), counter_high(counter),
                 (uint32_t)block_len, (uint32_t)flags);

  __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
  __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
  __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
  __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);

  __m128i t0, t1, t2, t3, tt;

  // Round 1. The first round permutes the message words from the original
  // input order, into the groups that get mixed in parallel.
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); //  6  4  2  0
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); //  7  5  3  1
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10  8
  t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3));   // 12 10  8 14
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11  9
  t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3));   // 13 11  9 15
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 2. This round and all following rounds apply a fixed permutation
  // to the message words from the round before.
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 3
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 4
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 5
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 6
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
  m0 = t0;
  m1 = t1;
  m2 = t2;
  m3 = t3;

  // Round 7
  t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
  t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
  t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
  tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
  t1 = _mm_blend_epi16(tt, t1, 0xCC);
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
  diagonalize(&rows[0], &rows[2], &rows[3]);
  t2 = _mm_unpacklo_epi64(m3, m1);
  tt = _mm_blend_epi16(t2, m2, 0xC0);
  t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
  g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
  t3 = _mm_unpackhi_epi32(m1, m3);
  tt = _mm_unpacklo_epi32(m2, t3);
  t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
  g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
  undiagonalize(&rows[0], &rows[2], &rows[3]);
}
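// compress_pre leaves the full 16-word state in rows[0..3]. The callers below
// finish the compression: the first eight output words are rows[0..1] XOR
// rows[2..3], and the extended-output variant also produces the last eight
// words as rows[2..3] XOR the input chaining value.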

void blake3_compress_xof_avx512(const uint32_t cv[8],
                                const uint8_t block[BLAKE3_BLOCK_LEN],
                                uint8_t block_len, uint64_t counter,
                                uint8_t flags, uint8_t out[64]) {
  __m128i rows[4];
  compress_pre(rows, cv, block, block_len, counter, flags);
  storeu_128(xor_128(rows[0], rows[2]), &out[0]);
  storeu_128(xor_128(rows[1], rows[3]), &out[16]);
  storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
  storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
}

void blake3_compress_in_place_avx512(uint32_t cv[8],
                                     const uint8_t block[BLAKE3_BLOCK_LEN],
                                     uint8_t block_len, uint64_t counter,
                                     uint8_t flags) {
  __m128i rows[4];
  compress_pre(rows, cv, block, block_len, counter, flags);
  storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
  storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
}
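// Illustrative sketch only (not part of this file): these two routines are
// driven by the chunk logic in blake3.c. For a single full 64-byte block
// hashed as the root of an unkeyed hash, a hypothetical caller might do:
//
//   uint32_t cv[8];
//   memcpy(cv, IV, sizeof(cv));  // unkeyed hashing starts from the IV
//   blake3_compress_in_place_avx512(cv, block_bytes, BLAKE3_BLOCK_LEN, 0,
//                                   CHUNK_START | CHUNK_END | ROOT);
//   // cv now holds the first 32 bytes of the BLAKE3 output.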

/*
 * ----------------------------------------------------------------------------
 * hash4_avx512
 * ----------------------------------------------------------------------------
 */

INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[15] = rot16_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot12_128(v[4]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[15] = rot8_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot7_128(v[4]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);

  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot16_128(v[15]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[4] = rot12_128(v[4]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot8_128(v[15]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);
  v[4] = rot7_128(v[4]);
}
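// round_fn4 above is the same round as in compress_pre, vectorized across
// four independent inputs: v[i] holds state word i from each of the four
// messages. The first half of the function is the column step; the second
// half is the diagonal step, expressed by rotating which v vectors are paired
// rather than by shuffling lanes within a vector.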

INLINE void transpose_vecs_128(__m128i vecs[4]) {
  // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
  // 22/33. Note that this doesn't split the vector into two lanes, as the
  // AVX2 counterparts do.
  __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
  __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
  __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
  __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);

  // Interleave 64-bit lanes.
  __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
  __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
  __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
  __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);

  vecs[0] = abcd_0;
  vecs[1] = abcd_1;
  vecs[2] = abcd_2;
  vecs[3] = abcd_3;
}
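// transpose_vecs_128 treats its four vectors as a 4x4 matrix of 32-bit words
// and transposes it, so lane j of input vector i becomes lane i of output
// vector j. transpose_msg_vecs4 uses it to turn per-message loads into
// per-word vectors, and blake3_hash4_avx512 uses it again to undo that layout
// before storing the outputs.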

INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
                                size_t block_offset, __m128i out[16]) {
  out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
  out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
  out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
  out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
  out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
  out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
  out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
  out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
  out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
  out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
  out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
  out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
  out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
  out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
  out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
  out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
  for (size_t i = 0; i < 4; ++i) {
    _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
  }
  transpose_vecs_128(&out[0]);
  transpose_vecs_128(&out[4]);
  transpose_vecs_128(&out[8]);
  transpose_vecs_128(&out[12]);
}

INLINE void load_counters4(uint64_t counter, bool increment_counter,
                           __m128i *out_lo, __m128i *out_hi) {
  uint64_t mask = (increment_counter ? ~0 : 0);
  __m256i mask_vec = _mm256_set1_epi64x(mask);
  __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
  deltas = _mm256_and_si256(mask_vec, deltas);
  __m256i counters =
      _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
  *out_lo = _mm256_cvtepi64_epi32(counters);
  *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
}
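// load_counters4 builds the per-input counter values. With increment_counter
// false, the zero mask clears the deltas so every lane sees the same counter;
// otherwise the lanes get counter+0 through counter+3. The 64-bit counters
// are then narrowed into separate vectors of low and high 32-bit halves,
// which become the two counter words of the state (v[12] and v[13]).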

static
void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
                         const uint32_t key[8], uint64_t counter,
                         bool increment_counter, uint8_t flags,
                         uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
  __m128i h_vecs[8] = {
      set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
      set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
  };
  __m128i counter_low_vec, counter_high_vec;
  load_counters4(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
  uint8_t block_flags = flags | flags_start;

  for (size_t block = 0; block < blocks; block++) {
    if (block + 1 == blocks) {
      block_flags |= flags_end;
    }
    __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
    __m128i block_flags_vec = set1_128(block_flags);
    __m128i msg_vecs[16];
    transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);

    __m128i v[16] = {
        h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
        h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
        set1_128(IV[0]), set1_128(IV[1]),  set1_128(IV[2]), set1_128(IV[3]),
        counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    };
    round_fn4(v, msg_vecs, 0);
    round_fn4(v, msg_vecs, 1);
    round_fn4(v, msg_vecs, 2);
    round_fn4(v, msg_vecs, 3);
    round_fn4(v, msg_vecs, 4);
    round_fn4(v, msg_vecs, 5);
    round_fn4(v, msg_vecs, 6);
    h_vecs[0] = xor_128(v[0], v[8]);
    h_vecs[1] = xor_128(v[1], v[9]);
    h_vecs[2] = xor_128(v[2], v[10]);
    h_vecs[3] = xor_128(v[3], v[11]);
    h_vecs[4] = xor_128(v[4], v[12]);
    h_vecs[5] = xor_128(v[5], v[13]);
    h_vecs[6] = xor_128(v[6], v[14]);
    h_vecs[7] = xor_128(v[7], v[15]);

    block_flags = flags;
  }

  transpose_vecs_128(&h_vecs[0]);
  transpose_vecs_128(&h_vecs[4]);
  // The first four vecs now contain the first half of each output, and the
  // second four vecs contain the second half of each output.
  storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
  storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
  storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
  storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
  storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
  storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
  storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
  storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
}
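// blake3_hash4_avx512 above hashes four equal-length inputs in parallel, one
// per lane. Note that the new chaining values are simply v[0..7] XOR
// v[8..15]; BLAKE3 truncates the compression output for chaining, so the
// input CVs are not fed forward again here. The 8-way kernel below follows
// the same structure with 256-bit vectors, and the 16-way path is built from
// the same pieces at 512 bits.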

/*
 * ----------------------------------------------------------------------------
 * hash8_avx512
 * ----------------------------------------------------------------------------
 */

INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_256(v[0], v[4]);
  v[1] = add_256(v[1], v[5]);
  v[2] = add_256(v[2], v[6]);
  v[3] = add_256(v[3], v[7]);
  v[12] = xor_256(v[12], v[0]);
  v[13] = xor_256(v[13], v[1]);
  v[14] = xor_256(v[14], v[2]);
  v[15] = xor_256(v[15], v[3]);
  v[12] = rot16_256(v[12]);
  v[13] = rot16_256(v[13]);
  v[14] = rot16_256(v[14]);
  v[15] = rot16_256(v[15]);
  v[8] = add_256(v[8], v[12]);
  v[9] = add_256(v[9], v[13]);
  v[10] = add_256(v[10], v[14]);
  v[11] = add_256(v[11], v[15]);
  v[4] = xor_256(v[4], v[8]);
  v[5] = xor_256(v[5], v[9]);
  v[6] = xor_256(v[6], v[10]);
  v[7] = xor_256(v[7], v[11]);
  v[4] = rot12_256(v[4]);
  v[5] = rot12_256(v[5]);
  v[6] = rot12_256(v[6]);
  v[7] = rot12_256(v[7]);
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_256(v[0], v[4]);
  v[1] = add_256(v[1], v[5]);
  v[2] = add_256(v[2], v[6]);
  v[3] = add_256(v[3], v[7]);
  v[12] = xor_256(v[12], v[0]);
  v[13] = xor_256(v[13], v[1]);
  v[14] = xor_256(v[14], v[2]);
  v[15] = xor_256(v[15], v[3]);
  v[12] = rot8_256(v[12]);
  v[13] = rot8_256(v[13]);
  v[14] = rot8_256(v[14]);
  v[15] = rot8_256(v[15]);
  v[8] = add_256(v[8], v[12]);
  v[9] = add_256(v[9], v[13]);
  v[10] = add_256(v[10], v[14]);
  v[11] = add_256(v[11], v[15]);
  v[4] = xor_256(v[4], v[8]);
  v[5] = xor_256(v[5], v[9]);
  v[6] = xor_256(v[6], v[10]);
  v[7] = xor_256(v[7], v[11]);
  v[4] = rot7_256(v[4]);
  v[5] = rot7_256(v[5]);
  v[6] = rot7_256(v[6]);
  v[7] = rot7_256(v[7]);

  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_256(v[0], v[5]);
  v[1] = add_256(v[1], v[6]);
  v[2] = add_256(v[2], v[7]);
  v[3] = add_256(v[3], v[4]);
  v[15] = xor_256(v[15], v[0]);
  v[12] = xor_256(v[12], v[1]);
  v[13] = xor_256(v[13], v[2]);
  v[14] = xor_256(v[14], v[3]);
  v[15] = rot16_256(v[15]);
  v[12] = rot16_256(v[12]);
  v[13] = rot16_256(v[13]);
  v[14] = rot16_256(v[14]);
  v[10] = add_256(v[10], v[15]);
  v[11] = add_256(v[11], v[12]);
  v[8] = add_256(v[8], v[13]);
  v[9] = add_256(v[9], v[14]);
  v[5] = xor_256(v[5], v[10]);
  v[6] = xor_256(v[6], v[11]);
  v[7] = xor_256(v[7], v[8]);
  v[4] = xor_256(v[4], v[9]);
  v[5] = rot12_256(v[5]);
  v[6] = rot12_256(v[6]);
  v[7] = rot12_256(v[7]);
  v[4] = rot12_256(v[4]);
  v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_256(v[0], v[5]);
  v[1] = add_256(v[1], v[6]);
  v[2] = add_256(v[2], v[7]);
  v[3] = add_256(v[3], v[4]);
  v[15] = xor_256(v[15], v[0]);
  v[12] = xor_256(v[12], v[1]);
  v[13] = xor_256(v[13], v[2]);
  v[14] = xor_256(v[14], v[3]);
  v[15] = rot8_256(v[15]);
  v[12] = rot8_256(v[12]);
  v[13] = rot8_256(v[13]);
  v[14] = rot8_256(v[14]);
  v[10] = add_256(v[10], v[15]);
  v[11] = add_256(v[11], v[12]);
  v[8] = add_256(v[8], v[13]);
  v[9] = add_256(v[9], v[14]);
  v[5] = xor_256(v[5], v[10]);
  v[6] = xor_256(v[6], v[11]);
  v[7] = xor_256(v[7], v[8]);
  v[4] = xor_256(v[4], v[9]);
  v[5] = rot7_256(v[5]);
  v[6] = rot7_256(v[6]);
  v[7] = rot7_256(v[7]);
  v[4] = rot7_256(v[4]);
}

INLINE void transpose_vecs_256(__m256i vecs[8]) {
  // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
  // is 22/33/66/77.
  __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
  __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
  __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
  __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
  __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
  __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
  __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
  __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);

  // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
  // 11/33.
  __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
  __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
  __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
  __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
  __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
  __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
  __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
  __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);

  // Interleave 128-bit lanes.
  vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
  vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
  vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
  vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
  vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
  vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
  vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
  vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
}
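// In the 128-bit interleave step above, _mm256_permute2x128_si256 selector
// 0x20 combines the low 128-bit halves of its two sources and 0x31 combines
// the high halves, completing the 8x8 transpose of 32-bit words.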

INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
                                size_t block_offset, __m256i out[16]) {
  out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
  out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
  out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
  out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
  out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
  out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
  out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
  out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
  out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
  out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
  out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
  out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
  out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
  out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
  out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
  out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
  for (size_t i = 0; i < 8; ++i) {
    _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
  }
  transpose_vecs_256(&out[0]);
  transpose_vecs_256(&out[8]);
}

INLINE void load_counters8(uint64_t counter, bool increment_counter,
                           __m256i *out_lo, __m256i *out_hi) {
  uint64_t mask = (increment_counter ? ~0 : 0);
  __m512i mask_vec = _mm512_set1_epi64(mask);
  __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
  deltas = _mm512_and_si512(mask_vec, deltas);
  __m512i counters =
      _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
  *out_lo = _mm512_cvtepi64_epi32(counters);
  *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
}

static
void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
                         const uint32_t key[8], uint64_t counter,
                         bool increment_counter, uint8_t flags,
                         uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
  __m256i h_vecs[8] = {
      set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
      set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
  };
  __m256i counter_low_vec, counter_high_vec;
  load_counters8(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
  uint8_t block_flags = flags | flags_start;

  for (size_t block = 0; block < blocks; block++) {
    if (block + 1 == blocks) {
      block_flags |= flags_end;
    }
    __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
    __m256i block_flags_vec = set1_256(block_flags);
    __m256i msg_vecs[16];
    transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);

    __m256i v[16] = {
        h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
        h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
        set1_256(IV[0]), set1_256(IV[1]),  set1_256(IV[2]), set1_256(IV[3]),
        counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
    };
    round_fn8(v, msg_vecs, 0);
    round_fn8(v, msg_vecs, 1);
    round_fn8(v, msg_vecs, 2);
    round_fn8(v, msg_vecs, 3);
    round_fn8(v, msg_vecs, 4);
    round_fn8(v, msg_vecs, 5);
    round_fn8(v, msg_vecs, 6);
    h_vecs[0] = xor_256(v[0], v[8]);
    h_vecs[1] = xor_256(v[1], v[9]);
    h_vecs[2] = xor_256(v[2], v[10]);
    h_vecs[3] = xor_256(v[3], v[11]);
    h_vecs[4] = xor_256(v[4], v[12]);
    h_vecs[5] = xor_256(v[5], v[13]);
    h_vecs[6] = xor_256(v[6], v[14]);
    h_vecs[7] = xor_256(v[7], v[15]);

    block_flags = flags;
  }

  transpose_vecs_256(h_vecs);
  storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
  storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
  storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
  storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
  storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
  storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
  storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
  storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
}

/*
 * ----------------------------------------------------------------------------
 * hash16_avx512
 * ----------------------------------------------------------------------------
 */

INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_512(v[0], v[4]);
  v[1] = add_512(v[1], v[5]);
  v[2] = add_512(v[2], v[6]);
  v[3] = add_512(v[3], v[7]);
  v[12] = xor_512(v[12], v[0]);
  v[13] = xor_512(v[13], v[1]);
  v[14] = xor_512(v[14], v[2]);
  v[15] = xor_512(v[15], v[3]);
  v[12] = rot16_512(v[12]);
  v[13] = rot16_512(v[13]);
  v[14] = rot16_512(v[14]);
  v[15] = rot16_512(v[15]);
  v[8] = add_512(v[8], v[12]);
  v[9] = add_512(v[9], v[13]);
  v[10] = add_512(v[10], v[14]);
  v[11] = add_512(v[11], v[15]);
  v[4] = xor_512(v[4], v[8]);
  v[5] = xor_512(v[5], v[9]);
  v[6] = xor_512(v[6], v[10]);
  v[7] = xor_512(v[7], v[11]);
  v[4] = rot12_512(v[4]);
  v[5] = rot12_512(v[5]);
  v[6] = rot12_512(v[6]);
  v[7] = rot12_512(v[7]);
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_512(v[0], v[4]);
  v[1] = add_512(v[1], v[5]);
  v[2] = add_512(v[2], v[6]);
  v[3] = add_512(v[3], v[7]);
  v[12] = xor_512(v[12], v[0]);
  v[13] = xor_512(v[13], v[1]);
  v[14] = xor_512(v[14], v[2]);
  v[15] = xor_512(v[15], v[3]);
  v[12] = rot8_512(v[12]);
  v[13] = rot8_512(v[13]);
  v[14] = rot8_512(v[14]);
  v[15] = rot8_512(v[15]);
  v[8] = add_512(v[8], v[12]);
  v[9] = add_512(v[9], v[13]);
  v[10] = add_512(v[10], v[14]);
  v[11] = add_512(v[11], v[15]);
  v[4] = xor_512(v[4], v[8]);
  v[5] = xor_512(v[5], v[9]);
  v[6] = xor_512(v[6], v[10]);
  v[7] = xor_512(v[7], v[11]);
  v[4] = rot7_512(v[4]);
  v[5] = rot7_512(v[5]);
  v[6] = rot7_512(v[6]);
  v[7] = rot7_512(v[7]);

  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_512(v[0], v[5]);
  v[1] = add_512(v[1], v[6]);
  v[2] = add_512(v[2], v[7]);
  v[3] = add_512(v[3], v[4]);
  v[15] = xor_512(v[15], v[0]);
  v[12] = xor_512(v[12], v[1]);
  v[13] = xor_512(v[13], v[2]);
  v[14] = xor_512(v[14], v[3]);
  v[15] = rot16_512(v[15]);
  v[12] = rot16_512(v[12]);
  v[13] = rot16_512(v[13]);
  v[14] = rot16_512(v[14]);
  v[10] = add_512(v[10], v[15]);
  v[11] = add_512(v[11], v[12]);
  v[8] = add_512(v[8], v[13]);
  v[9] = add_512(v[9], v[14]);
  v[5] = xor_512(v[5], v[10]);
  v[6] = xor_512(v[6], v[11]);
  v[7] = xor_512(v[7], v[8]);
  v[4] = xor_512(v[4], v[9]);
  v[5] = rot12_512(v[5]);
  v[6] = rot12_512(v[6]);
  v[7] = rot12_512(v[7]);
  v[4] = rot12_512(v[4]);
  v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_512(v[0], v[5]);
  v[1] = add_512(v[1], v[6]);
  v[2] = add_512(v[2], v[7]);
  v[3] = add_512(v[3], v[4]);
  v[15] = xor_512(v[15], v[0]);
  v[12] = xor_512(v[12], v[1]);
  v[13] = xor_512(v[13], v[2]);
  v[14] = xor_512(v[14], v[3]);
  v[15] = rot8_512(v[15]);
  v[12] = rot8_512(v[12]);
  v[13] = rot8_512(v[13]);
  v[14] = rot8_512(v[14]);
  v[10] = add_512(v[10], v[15]);
  v[11] = add_512(v[11], v[12]);
  v[8] = add_512(v[8], v[13]);
  v[9] = add_512(v[9], v[14]);
  v[5] = xor_512(v[5], v[10]);
  v[6] = xor_512(v[6], v[11]);
  v[7] = xor_512(v[7], v[8]);
  v[4] = xor_512(v[4], v[9]);
  v[5] = rot7_512(v[5]);
  v[6] = rot7_512(v[6]);
  v[7] = rot7_512(v[7]);
  v[4] = rot7_512(v[4]);
}

// 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
#define LO_IMM8 0x88

INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
  return _mm512_shuffle_i32x4(a, b, LO_IMM8);
}

// 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
#define HI_IMM8 0xdd

INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
  return _mm512_shuffle_i32x4(a, b, HI_IMM8);
}

INLINE void transpose_vecs_512(__m512i vecs[16]) {
  // Interleave 32-bit lanes. The _0 unpack is lanes
  // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
  // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
  __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
  __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
  __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
  __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
  __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
  __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
  __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
  __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
  __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
  __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
  __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
  __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
  __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
  __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
  __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
  __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);

962*d415bd75Srobert   // Interleave 64-bit lanes. The _0 unpack is lanes
963*d415bd75Srobert   // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
964*d415bd75Srobert   // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
965*d415bd75Srobert   // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
966*d415bd75Srobert   // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
967*d415bd75Srobert   __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
968*d415bd75Srobert   __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
969*d415bd75Srobert   __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
970*d415bd75Srobert   __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
971*d415bd75Srobert   __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
972*d415bd75Srobert   __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
973*d415bd75Srobert   __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
974*d415bd75Srobert   __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
975*d415bd75Srobert   __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
976*d415bd75Srobert   __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
977*d415bd75Srobert   __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
978*d415bd75Srobert   __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
979*d415bd75Srobert   __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
980*d415bd75Srobert   __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
981*d415bd75Srobert   __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
982*d415bd75Srobert   __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
983*d415bd75Srobert 
984*d415bd75Srobert   // Interleave 128-bit lanes. The _0 unpack is
985*d415bd75Srobert   // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
986*d415bd75Srobert   // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
987*d415bd75Srobert   __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
988*d415bd75Srobert   __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
989*d415bd75Srobert   __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
990*d415bd75Srobert   __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
991*d415bd75Srobert   __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
992*d415bd75Srobert   __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
993*d415bd75Srobert   __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
994*d415bd75Srobert   __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
995*d415bd75Srobert   __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
996*d415bd75Srobert   __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
997*d415bd75Srobert   __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
998*d415bd75Srobert   __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
999*d415bd75Srobert   __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
1000*d415bd75Srobert   __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
1001*d415bd75Srobert   __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
1002*d415bd75Srobert   __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1003*d415bd75Srobert 
1004*d415bd75Srobert   // Interleave 128-bit lanes again for the final outputs.
1005*d415bd75Srobert   vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
1006*d415bd75Srobert   vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
1007*d415bd75Srobert   vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
1008*d415bd75Srobert   vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
1009*d415bd75Srobert   vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
1010*d415bd75Srobert   vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
1011*d415bd75Srobert   vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
1012*d415bd75Srobert   vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
1013*d415bd75Srobert   vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
1014*d415bd75Srobert   vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
1015*d415bd75Srobert   vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
1016*d415bd75Srobert   vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
1017*d415bd75Srobert   vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
1018*d415bd75Srobert   vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
1019*d415bd75Srobert   vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
1020*d415bd75Srobert   vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1021*d415bd75Srobert }
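
// Net effect (informal): viewing vecs as a 16x16 matrix of 32-bit words, the
// matrix is transposed in place, so on return vecs[w] collects word w of
// every original input vector, one input per lane. The stages interleave at
// 32-bit, 64-bit, and then twice at 128-bit granularity.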
1022*d415bd75Srobert 
transpose_msg_vecs16(const uint8_t * const * inputs,size_t block_offset,__m512i out[16])1023*d415bd75Srobert INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
1024*d415bd75Srobert                                  size_t block_offset, __m512i out[16]) {
1025*d415bd75Srobert   out[0] = loadu_512(&inputs[0][block_offset]);
1026*d415bd75Srobert   out[1] = loadu_512(&inputs[1][block_offset]);
1027*d415bd75Srobert   out[2] = loadu_512(&inputs[2][block_offset]);
1028*d415bd75Srobert   out[3] = loadu_512(&inputs[3][block_offset]);
1029*d415bd75Srobert   out[4] = loadu_512(&inputs[4][block_offset]);
1030*d415bd75Srobert   out[5] = loadu_512(&inputs[5][block_offset]);
1031*d415bd75Srobert   out[6] = loadu_512(&inputs[6][block_offset]);
1032*d415bd75Srobert   out[7] = loadu_512(&inputs[7][block_offset]);
1033*d415bd75Srobert   out[8] = loadu_512(&inputs[8][block_offset]);
1034*d415bd75Srobert   out[9] = loadu_512(&inputs[9][block_offset]);
1035*d415bd75Srobert   out[10] = loadu_512(&inputs[10][block_offset]);
1036*d415bd75Srobert   out[11] = loadu_512(&inputs[11][block_offset]);
1037*d415bd75Srobert   out[12] = loadu_512(&inputs[12][block_offset]);
1038*d415bd75Srobert   out[13] = loadu_512(&inputs[13][block_offset]);
1039*d415bd75Srobert   out[14] = loadu_512(&inputs[14][block_offset]);
1040*d415bd75Srobert   out[15] = loadu_512(&inputs[15][block_offset]);
1041*d415bd75Srobert   for (size_t i = 0; i < 16; ++i) {
1042*d415bd75Srobert     _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
1043*d415bd75Srobert   }
1044*d415bd75Srobert   transpose_vecs_512(out);
1045*d415bd75Srobert }
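
// On return, out[w] holds message word w of the current 64-byte block from
// all 16 inputs, one input per 32-bit lane. The prefetch warms the data 256
// bytes (four blocks) ahead in each input stream.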
1046*d415bd75Srobert 
load_counters16(uint64_t counter,bool increment_counter,__m512i * out_lo,__m512i * out_hi)1047*d415bd75Srobert INLINE void load_counters16(uint64_t counter, bool increment_counter,
1048*d415bd75Srobert                             __m512i *out_lo, __m512i *out_hi) {
1049*d415bd75Srobert   const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
1050*d415bd75Srobert   const __m512i add0 = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1051*d415bd75Srobert   const __m512i add1 = _mm512_and_si512(mask, add0);
1052*d415bd75Srobert   __m512i l = _mm512_add_epi32(_mm512_set1_epi32((int32_t)counter), add1);
1053*d415bd75Srobert   __mmask16 carry = _mm512_cmp_epu32_mask(l, add1, _MM_CMPINT_LT);
1054*d415bd75Srobert   __m512i h = _mm512_mask_add_epi32(_mm512_set1_epi32((int32_t)(counter >> 32)), carry, _mm512_set1_epi32((int32_t)(counter >> 32)), _mm512_set1_epi32(1));
1055*d415bd75Srobert   *out_lo = l;
1056*d415bd75Srobert   *out_hi = h;
1057*d415bd75Srobert }
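
// out_lo lane i is the low 32 bits of (counter + i) when increment_counter
// is set (or of counter in every lane otherwise); the unsigned compare flags
// the lanes whose 32-bit addition wrapped, and only those lanes of out_hi
// get the incremented high word. For example, counter = 0x1ffffffff yields
// lane 0 = (0xffffffff, 1) and lane 1 = (0x00000000, 2).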
1058*d415bd75Srobert 
1059*d415bd75Srobert static
blake3_hash16_avx512(const uint8_t * const * inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1060*d415bd75Srobert void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
1061*d415bd75Srobert                           const uint32_t key[8], uint64_t counter,
1062*d415bd75Srobert                           bool increment_counter, uint8_t flags,
1063*d415bd75Srobert                           uint8_t flags_start, uint8_t flags_end,
1064*d415bd75Srobert                           uint8_t *out) {
1065*d415bd75Srobert   __m512i h_vecs[8] = {
1066*d415bd75Srobert       set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
1067*d415bd75Srobert       set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
1068*d415bd75Srobert   };
1069*d415bd75Srobert   __m512i counter_low_vec, counter_high_vec;
1070*d415bd75Srobert   load_counters16(counter, increment_counter, &counter_low_vec,
1071*d415bd75Srobert                   &counter_high_vec);
1072*d415bd75Srobert   uint8_t block_flags = flags | flags_start;
1073*d415bd75Srobert 
1074*d415bd75Srobert   for (size_t block = 0; block < blocks; block++) {
1075*d415bd75Srobert     if (block + 1 == blocks) {
1076*d415bd75Srobert       block_flags |= flags_end;
1077*d415bd75Srobert     }
1078*d415bd75Srobert     __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
1079*d415bd75Srobert     __m512i block_flags_vec = set1_512(block_flags);
1080*d415bd75Srobert     __m512i msg_vecs[16];
1081*d415bd75Srobert     transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1082*d415bd75Srobert 
1083*d415bd75Srobert     __m512i v[16] = {
1084*d415bd75Srobert         h_vecs[0],       h_vecs[1],        h_vecs[2],       h_vecs[3],
1085*d415bd75Srobert         h_vecs[4],       h_vecs[5],        h_vecs[6],       h_vecs[7],
1086*d415bd75Srobert         set1_512(IV[0]), set1_512(IV[1]),  set1_512(IV[2]), set1_512(IV[3]),
1087*d415bd75Srobert         counter_low_vec, counter_high_vec, block_len_vec,   block_flags_vec,
1088*d415bd75Srobert     };
1089*d415bd75Srobert     round_fn16(v, msg_vecs, 0);
1090*d415bd75Srobert     round_fn16(v, msg_vecs, 1);
1091*d415bd75Srobert     round_fn16(v, msg_vecs, 2);
1092*d415bd75Srobert     round_fn16(v, msg_vecs, 3);
1093*d415bd75Srobert     round_fn16(v, msg_vecs, 4);
1094*d415bd75Srobert     round_fn16(v, msg_vecs, 5);
1095*d415bd75Srobert     round_fn16(v, msg_vecs, 6);
1096*d415bd75Srobert     h_vecs[0] = xor_512(v[0], v[8]);
1097*d415bd75Srobert     h_vecs[1] = xor_512(v[1], v[9]);
1098*d415bd75Srobert     h_vecs[2] = xor_512(v[2], v[10]);
1099*d415bd75Srobert     h_vecs[3] = xor_512(v[3], v[11]);
1100*d415bd75Srobert     h_vecs[4] = xor_512(v[4], v[12]);
1101*d415bd75Srobert     h_vecs[5] = xor_512(v[5], v[13]);
1102*d415bd75Srobert     h_vecs[6] = xor_512(v[6], v[14]);
1103*d415bd75Srobert     h_vecs[7] = xor_512(v[7], v[15]);
1104*d415bd75Srobert 
1105*d415bd75Srobert     block_flags = flags;
1106*d415bd75Srobert   }
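  // Each pass of the loop above runs the seven BLAKE3 rounds and the XOR
  // feed-forward, so h_vecs carries the per-input chaining value into the
  // next block.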
1107*d415bd75Srobert 
1108*d415bd75Srobert   // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
1109*d415bd75Srobert   // state vectors. Pad the matrix with zeros. After transposition, store the
1110*d415bd75Srobert   // lower half of each vector.
1111*d415bd75Srobert   __m512i padded[16] = {
1112*d415bd75Srobert       h_vecs[0],   h_vecs[1],   h_vecs[2],   h_vecs[3],
1113*d415bd75Srobert       h_vecs[4],   h_vecs[5],   h_vecs[6],   h_vecs[7],
1114*d415bd75Srobert       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1115*d415bd75Srobert       set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1116*d415bd75Srobert   };
1117*d415bd75Srobert   transpose_vecs_512(padded);
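  // After the transpose, the low 256 bits of padded[i] hold the eight
  // chaining-value words of input i, so each masked 256-bit store below
  // writes one 32-byte output; out ends up with 16 contiguous chaining
  // values.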
1118*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
1119*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
1120*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
1121*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
1122*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
1123*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
1124*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
1125*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
1126*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
1127*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
1128*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
1129*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
1130*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
1131*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
1132*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
1133*d415bd75Srobert   _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1134*d415bd75Srobert }
1135*d415bd75Srobert 
1136*d415bd75Srobert /*
1137*d415bd75Srobert  * ----------------------------------------------------------------------------
1138*d415bd75Srobert  * hash_many_avx512
1139*d415bd75Srobert  * ----------------------------------------------------------------------------
1140*d415bd75Srobert  */
1141*d415bd75Srobert 
hash_one_avx512(const uint8_t * input,size_t blocks,const uint32_t key[8],uint64_t counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t out[BLAKE3_OUT_LEN])1142*d415bd75Srobert INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
1143*d415bd75Srobert                             const uint32_t key[8], uint64_t counter,
1144*d415bd75Srobert                             uint8_t flags, uint8_t flags_start,
1145*d415bd75Srobert                             uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
1146*d415bd75Srobert   uint32_t cv[8];
1147*d415bd75Srobert   memcpy(cv, key, BLAKE3_KEY_LEN);
1148*d415bd75Srobert   uint8_t block_flags = flags | flags_start;
1149*d415bd75Srobert   while (blocks > 0) {
1150*d415bd75Srobert     if (blocks == 1) {
1151*d415bd75Srobert       block_flags |= flags_end;
1152*d415bd75Srobert     }
1153*d415bd75Srobert     blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
1154*d415bd75Srobert                                     block_flags);
1155*d415bd75Srobert     input = &input[BLAKE3_BLOCK_LEN];
1156*d415bd75Srobert     blocks -= 1;
1157*d415bd75Srobert     block_flags = flags;
1158*d415bd75Srobert   }
1159*d415bd75Srobert   memcpy(out, cv, BLAKE3_OUT_LEN);
1160*d415bd75Srobert }
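
// Serial fallback used for inputs left over after the 16-, 8-, and 4-wide
// batches below: one input at a time, one 64-byte block per compression,
// with flags_start and flags_end applied to the first and last block.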
1161*d415bd75Srobert 
blake3_hash_many_avx512(const uint8_t * const * inputs,size_t num_inputs,size_t blocks,const uint32_t key[8],uint64_t counter,bool increment_counter,uint8_t flags,uint8_t flags_start,uint8_t flags_end,uint8_t * out)1162*d415bd75Srobert void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
1163*d415bd75Srobert                              size_t blocks, const uint32_t key[8],
1164*d415bd75Srobert                              uint64_t counter, bool increment_counter,
1165*d415bd75Srobert                              uint8_t flags, uint8_t flags_start,
1166*d415bd75Srobert                              uint8_t flags_end, uint8_t *out) {
1167*d415bd75Srobert   while (num_inputs >= 16) {
1168*d415bd75Srobert     blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
1169*d415bd75Srobert                          flags_start, flags_end, out);
1170*d415bd75Srobert     if (increment_counter) {
1171*d415bd75Srobert       counter += 16;
1172*d415bd75Srobert     }
1173*d415bd75Srobert     inputs += 16;
1174*d415bd75Srobert     num_inputs -= 16;
1175*d415bd75Srobert     out = &out[16 * BLAKE3_OUT_LEN];
1176*d415bd75Srobert   }
1177*d415bd75Srobert   while (num_inputs >= 8) {
1178*d415bd75Srobert     blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
1179*d415bd75Srobert                         flags_start, flags_end, out);
1180*d415bd75Srobert     if (increment_counter) {
1181*d415bd75Srobert       counter += 8;
1182*d415bd75Srobert     }
1183*d415bd75Srobert     inputs += 8;
1184*d415bd75Srobert     num_inputs -= 8;
1185*d415bd75Srobert     out = &out[8 * BLAKE3_OUT_LEN];
1186*d415bd75Srobert   }
1187*d415bd75Srobert   while (num_inputs >= 4) {
1188*d415bd75Srobert     blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
1189*d415bd75Srobert                         flags_start, flags_end, out);
1190*d415bd75Srobert     if (increment_counter) {
1191*d415bd75Srobert       counter += 4;
1192*d415bd75Srobert     }
1193*d415bd75Srobert     inputs += 4;
1194*d415bd75Srobert     num_inputs -= 4;
1195*d415bd75Srobert     out = &out[4 * BLAKE3_OUT_LEN];
1196*d415bd75Srobert   }
1197*d415bd75Srobert   while (num_inputs > 0) {
1198*d415bd75Srobert     hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1199*d415bd75Srobert                     flags_end, out);
1200*d415bd75Srobert     if (increment_counter) {
1201*d415bd75Srobert       counter += 1;
1202*d415bd75Srobert     }
1203*d415bd75Srobert     inputs += 1;
1204*d415bd75Srobert     num_inputs -= 1;
1205*d415bd75Srobert     out = &out[BLAKE3_OUT_LEN];
1206*d415bd75Srobert   }
1207*d415bd75Srobert }
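
/*
 * Informal usage sketch (illustrative only; names other than the function
 * itself are assumptions): hashing sixteen whole chunks in one call,
 * assuming the CHUNK_START/CHUNK_END flag constants and BLAKE3_CHUNK_LEN
 * from blake3_impl.h.
 *
 *   const uint8_t *chunks[16];         // each points at BLAKE3_CHUNK_LEN bytes
 *   uint32_t key_words[8];             // the key or IV words
 *   uint8_t cvs[16 * BLAKE3_OUT_LEN];  // one 32-byte chaining value per chunk
 *   blake3_hash_many_avx512(chunks, 16, BLAKE3_CHUNK_LEN / BLAKE3_BLOCK_LEN,
 *                           key_words, first_chunk_counter,
 *                           true,         // consecutive chunk counters
 *                           0,            // base flags
 *                           CHUNK_START, CHUNK_END, cvs);
 */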