1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #pragma once
5
6 #include <immintrin.h>
7
8 #include <vector>
9
10 #include "hexl/logging/logging.hpp"
11 #include "hexl/number-theory/number-theory.hpp"
12 #include "hexl/util/check.hpp"
13 #include "hexl/util/defines.hpp"
14 #include "hexl/util/util.hpp"
15
16 namespace intel {
17 namespace hexl {
18
19 #ifdef HEXL_HAS_AVX512DQ
20
21 /// @brief Returns the unsigned 64-bit integer values in x as a vector
ExtractValues(__m512i x)22 inline std::vector<uint64_t> ExtractValues(__m512i x) {
23 __m256i x0 = _mm512_extracti64x4_epi64(x, 0);
24 __m256i x1 = _mm512_extracti64x4_epi64(x, 1);
25
26 std::vector<uint64_t> xs{static_cast<uint64_t>(_mm256_extract_epi64(x0, 0)),
27 static_cast<uint64_t>(_mm256_extract_epi64(x0, 1)),
28 static_cast<uint64_t>(_mm256_extract_epi64(x0, 2)),
29 static_cast<uint64_t>(_mm256_extract_epi64(x0, 3)),
30 static_cast<uint64_t>(_mm256_extract_epi64(x1, 0)),
31 static_cast<uint64_t>(_mm256_extract_epi64(x1, 1)),
32 static_cast<uint64_t>(_mm256_extract_epi64(x1, 2)),
33 static_cast<uint64_t>(_mm256_extract_epi64(x1, 3))};
34
35 return xs;
36 }
37
38 /// @brief Returns the signed 64-bit integer values in x as a vector
ExtractIntValues(__m512i x)39 inline std::vector<int64_t> ExtractIntValues(__m512i x) {
40 __m256i x0 = _mm512_extracti64x4_epi64(x, 0);
41 __m256i x1 = _mm512_extracti64x4_epi64(x, 1);
42
43 std::vector<int64_t> xs{static_cast<int64_t>(_mm256_extract_epi64(x0, 0)),
44 static_cast<int64_t>(_mm256_extract_epi64(x0, 1)),
45 static_cast<int64_t>(_mm256_extract_epi64(x0, 2)),
46 static_cast<int64_t>(_mm256_extract_epi64(x0, 3)),
47 static_cast<int64_t>(_mm256_extract_epi64(x1, 0)),
48 static_cast<int64_t>(_mm256_extract_epi64(x1, 1)),
49 static_cast<int64_t>(_mm256_extract_epi64(x1, 2)),
50 static_cast<int64_t>(_mm256_extract_epi64(x1, 3))};
51
52 return xs;
53 }
54
55 // Returns the 64-bit floating-point values in x as a vector
ExtractValues(__m512d x)56 inline std::vector<double> ExtractValues(__m512d x) {
57 std::vector<double> ret(8, 0);
58 double* x_data = reinterpret_cast<double*>(&x);
59 for (size_t i = 0; i < 8; ++i) {
60 ret[i] = x_data[i];
61 }
62 return ret;
63 }
64
65 // Returns lower NumBits bits from a 64-bit value
66 template <int NumBits>
ClearTopBits64(__m512i x)67 inline __m512i ClearTopBits64(__m512i x) {
68 const __m512i low52b_mask = _mm512_set1_epi64((1ULL << NumBits) - 1);
69 return _mm512_and_epi64(x, low52b_mask);
70 }
71
// Multiplies the unsigned BitShift-bit integers held in each 64-bit lane of x
// and y, forming a 2*BitShift-bit intermediate product, and returns the high
// BitShift bits of that product.
// Only explicit specializations are defined in this file: BitShift == 64,
// BitShift == 52 (when HEXL_HAS_AVX512IFMA is defined), and a placeholder for
// BitShift == 32.
template <int BitShift>
inline __m512i _mm512_hexl_mulhi_epi(__m512i x, __m512i y);
77
78 // Dummy implementation to avoid template substitution errors
79 template <>
_mm512_hexl_mulhi_epi(__m512i x,__m512i y)80 inline __m512i _mm512_hexl_mulhi_epi<32>(__m512i x, __m512i y) {
81 HEXL_CHECK(false, "Unimplemented");
82 HEXL_UNUSED(x);
83 HEXL_UNUSED(y);
84 return x;
85 }
86
87 template <>
_mm512_hexl_mulhi_epi(__m512i x,__m512i y)88 inline __m512i _mm512_hexl_mulhi_epi<64>(__m512i x, __m512i y) {
89 // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit
90 __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff);
91 // Shuffle high bits with low bits in each 64-bit integer =>
92 // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ...
93 __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1);
94 // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ...
95 __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1);
96 __m512i z_lo_lo = _mm512_mul_epu32(x, y); // x_lo * y_lo
97 __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi
98 __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo
99 __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi
100
101 // x_hi | x_lo
102 // x y_hi | y_lo
103 // ------------------------------
104 // [x_lo * y_lo] // z_lo_lo
105 // + [z_lo * y_hi] // z_lo_hi
106 // + [x_hi * y_lo] // z_hi_lo
107 // + [x_hi * y_hi] // z_hi_hi
108 // ^-----------^ <-- only bits needed
109 // sum_| hi | mid | lo |
110
111 // Low bits of z_lo_lo are not needed
112 __m512i z_lo_lo_shift = _mm512_srli_epi64(z_lo_lo, 32);
113
114 // [x_lo * y_lo] // z_lo_lo
115 // + [z_lo * y_hi] // z_lo_hi
116 // ------------------------
117 // | sum_tmp |
118 // |sum_mid|sum_lo|
119 __m512i sum_tmp = _mm512_add_epi64(z_lo_hi, z_lo_lo_shift);
120 __m512i sum_lo = _mm512_and_si512(sum_tmp, lo_mask);
121 __m512i sum_mid = _mm512_srli_epi64(sum_tmp, 32);
122 // | |sum_lo|
123 // + [x_hi * y_lo] // z_hi_lo
124 // ------------------
125 // [ sum_mid2 ]
126 __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo);
127 __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32);
128 __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid);
129 return _mm512_add_epi64(sum_hi, sum_mid2_hi);
130 }
131
132 #ifdef HEXL_HAS_AVX512IFMA
133 template <>
_mm512_hexl_mulhi_epi(__m512i x,__m512i y)134 inline __m512i _mm512_hexl_mulhi_epi<52>(__m512i x, __m512i y) {
135 __m512i zero = _mm512_set1_epi64(0);
136 return _mm512_madd52hi_epu64(zero, x, y);
137 }
138 #endif
139
// Multiplies the unsigned BitShift-bit integers held in each 64-bit lane of x
// and y, forming a 2*BitShift-bit intermediate product, and returns the high
// BitShift bits of that product, with an approximation error of at most 1.
// In this file the BitShift == 64 specialization skips the x_lo * y_lo
// partial product, so its result may be 1 below the exact value. The
// BitShift == 52 specialization (IFMA) is exact. BitShift == 32 is only a
// placeholder.
template <int BitShift>
inline __m512i _mm512_hexl_mulhi_approx_epi(__m512i x, __m512i y);
146
147 // Dummy implementation to avoid template substitution errors
148 template <>
_mm512_hexl_mulhi_approx_epi(__m512i x,__m512i y)149 inline __m512i _mm512_hexl_mulhi_approx_epi<32>(__m512i x, __m512i y) {
150 HEXL_CHECK(false, "Unimplemented");
151 HEXL_UNUSED(x);
152 HEXL_UNUSED(y);
153 return x;
154 }
155
156 template <>
_mm512_hexl_mulhi_approx_epi(__m512i x,__m512i y)157 inline __m512i _mm512_hexl_mulhi_approx_epi<64>(__m512i x, __m512i y) {
158 // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit
159 __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff);
160 // Shuffle high bits with low bits in each 64-bit integer =>
161 // x0_lo, x0_hi, x1_lo, x1_hi, x2_lo, x2_hi, ...
162 __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1);
163 // y0_lo, y0_hi, y1_lo, y1_hi, y2_lo, y2_hi, ...
164 __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1);
165 __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi); // x_lo * y_hi
166 __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y); // x_hi * y_lo
167 __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi); // x_hi * y_hi
168
169 // x_hi | x_lo
170 // x y_hi | y_lo
171 // ------------------------------
172 // [x_lo * y_lo] // unused, resulting in approximation
173 // + [z_lo * y_hi] // z_lo_hi
174 // + [x_hi * y_lo] // z_hi_lo
175 // + [x_hi * y_hi] // z_hi_hi
176 // ^-----------^ <-- only bits needed
177 // sum_| hi | mid | lo |
178
179 __m512i sum_lo = _mm512_and_si512(z_lo_hi, lo_mask);
180 __m512i sum_mid = _mm512_srli_epi64(z_lo_hi, 32);
181 // | |sum_lo|
182 // + [x_hi * y_lo] // z_hi_lo
183 // ------------------
184 // [ sum_mid2 ]
185 __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo);
186 __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32);
187 __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid);
188 return _mm512_add_epi64(sum_hi, sum_mid2_hi);
189 }
190
191 #ifdef HEXL_HAS_AVX512IFMA
192 template <>
_mm512_hexl_mulhi_approx_epi(__m512i x,__m512i y)193 inline __m512i _mm512_hexl_mulhi_approx_epi<52>(__m512i x, __m512i y) {
194 __m512i zero = _mm512_set1_epi64(0);
195 return _mm512_madd52hi_epu64(zero, x, y);
196 }
197 #endif
198
// Multiplies the unsigned BitShift-bit integers held in each 64-bit lane of x
// and y, forming a 2*BitShift-bit intermediate product, and returns the low
// BitShift bits of that product.
// Specializations in this file: 64 (via _mm512_mullo_epi64), 52 (IFMA, when
// HEXL_HAS_AVX512IFMA is defined), and a placeholder for 32.
template <int BitShift>
inline __m512i _mm512_hexl_mullo_epi(__m512i x, __m512i y);
204
205 // Dummy implementation to avoid template substitution errors
206 template <>
_mm512_hexl_mullo_epi(__m512i x,__m512i y)207 inline __m512i _mm512_hexl_mullo_epi<32>(__m512i x, __m512i y) {
208 HEXL_CHECK(false, "Unimplemented");
209 HEXL_UNUSED(x);
210 HEXL_UNUSED(y);
211 return x;
212 }
213
214 template <>
_mm512_hexl_mullo_epi(__m512i x,__m512i y)215 inline __m512i _mm512_hexl_mullo_epi<64>(__m512i x, __m512i y) {
216 return _mm512_mullo_epi64(x, y);
217 }
218
219 #ifdef HEXL_HAS_AVX512IFMA
220 template <>
_mm512_hexl_mullo_epi(__m512i x,__m512i y)221 inline __m512i _mm512_hexl_mullo_epi<52>(__m512i x, __m512i y) {
222 __m512i zero = _mm512_set1_epi64(0);
223 return _mm512_madd52lo_epu64(zero, x, y);
224 }
225 #endif
226
// Multiplies the unsigned BitShift-bit integers held in each 64-bit lane of y
// and z into a 2*BitShift-bit intermediate product. The low BitShift bits of
// that product are added to x, and the low BitShift bits of the sum are
// returned.
// Specializations in this file: 52 (IFMA, when HEXL_HAS_AVX512IFMA is
// defined), 64 (wrapping 64-bit arithmetic), and a placeholder for 32.
template <int BitShift>
inline __m512i _mm512_hexl_mullo_add_lo_epi(__m512i x, __m512i y, __m512i z);
233
234 #ifdef HEXL_HAS_AVX512IFMA
235 template <>
_mm512_hexl_mullo_add_lo_epi(__m512i x,__m512i y,__m512i z)236 inline __m512i _mm512_hexl_mullo_add_lo_epi<52>(__m512i x, __m512i y,
237 __m512i z) {
238 __m512i result = _mm512_madd52lo_epu64(x, y, z);
239
240 // Clear high 12 bits from result
241 result = ClearTopBits64<52>(result);
242 return result;
243 }
244 #endif
245
246 // Dummy implementation to avoid template substitution errors
247 template <>
_mm512_hexl_mullo_add_lo_epi(__m512i x,__m512i y,__m512i z)248 inline __m512i _mm512_hexl_mullo_add_lo_epi<32>(__m512i x, __m512i y,
249 __m512i z) {
250 HEXL_CHECK(false, "Unimplemented");
251 HEXL_UNUSED(x);
252 HEXL_UNUSED(y);
253 HEXL_UNUSED(z);
254 return x;
255 }
256
257 template <>
_mm512_hexl_mullo_add_lo_epi(__m512i x,__m512i y,__m512i z)258 inline __m512i _mm512_hexl_mullo_add_lo_epi<64>(__m512i x, __m512i y,
259 __m512i z) {
260 __m512i prod = _mm512_mullo_epi64(y, z);
261 return _mm512_add_epi64(x, prod);
262 }
263
264 // Returns x mod q across each 64-bit integer SIMD lanes
265 // Assumes x < InputModFactor * q in all lanes
266 template <int InputModFactor = 2>
_mm512_hexl_small_mod_epu64(__m512i x,__m512i q,__m512i * q_times_2=nullptr,__m512i * q_times_4=nullptr)267 inline __m512i _mm512_hexl_small_mod_epu64(__m512i x, __m512i q,
268 __m512i* q_times_2 = nullptr,
269 __m512i* q_times_4 = nullptr) {
270 HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
271 InputModFactor == 4 || InputModFactor == 8,
272 "InputModFactor must be 1, 2, 4, or 8");
273 if (InputModFactor == 1) {
274 return x;
275 }
276 if (InputModFactor == 2) {
277 return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
278 }
279 if (InputModFactor == 4) {
280 HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr");
281 x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2));
282 return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
283 }
284 if (InputModFactor == 8) {
285 HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr");
286 HEXL_CHECK(q_times_4 != nullptr, "q_times_4 must not be nullptr");
287 x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_4));
288 x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2));
289 return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
290 }
291 HEXL_CHECK(false, "Invalid InputModFactor");
292 return x; // Return dummy value
293 }
294
295 // Returns (x + y) mod q; assumes 0 < x, y < q
_mm512_hexl_small_add_mod_epi64(__m512i x,__m512i y,__m512i q)296 inline __m512i _mm512_hexl_small_add_mod_epi64(__m512i x, __m512i y,
297 __m512i q) {
298 HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0],
299 "x exceeds bound " << ExtractValues(q)[0]);
300 HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0],
301 "y exceeds bound " << ExtractValues(q)[0]);
302 return _mm512_hexl_small_mod_epu64(_mm512_add_epi64(x, y), q);
303
304 // Alternate implementation:
305 // x += y - q;
306 // if (x < 0) x+= q
307 // return x
308 // __m512i v_diff = _mm512_sub_epi64(y, q);
309 // x = _mm512_add_epi64(x, v_diff);
310 // __mmask8 sign_bits = _mm512_movepi64_mask(x);
311 // return _mm512_mask_add_epi64(x, sign_bits, x, q);
312 }
313
314 // Returns (x - y) mod q; assumes 0 < x, y < q
315
_mm512_hexl_small_sub_mod_epi64(__m512i x,__m512i y,__m512i q)316 inline __m512i _mm512_hexl_small_sub_mod_epi64(__m512i x, __m512i y,
317 __m512i q) {
318 HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0],
319 "x exceeds bound " << ExtractValues(q)[0]);
320 HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0],
321 "y exceeds bound " << ExtractValues(q)[0]);
322
323 // diff = x - y;
324 // return (diff < 0) ? (diff + q) : diff
325 __m512i v_diff = _mm512_sub_epi64(x, y);
326 __mmask8 sign_bits = _mm512_movepi64_mask(v_diff);
327 return _mm512_mask_add_epi64(v_diff, sign_bits, v_diff, q);
328 }
329
// Compares a and b lane-wise as unsigned 64-bit integers with predicate cmp
// and returns the 8-bit lane mask. _mm512_cmp_epu64_mask needs the predicate
// as a compile-time immediate, so every CMPINT value gets its own call.
inline __mmask8 _mm512_hexl_cmp_epu64_mask(__m512i a, __m512i b, CMPINT cmp) {
  __mmask8 lane_mask = 0;  // Stays 0 for values outside the CMPINT cases.
  switch (cmp) {
    case CMPINT::EQ:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::EQ));
      break;
    case CMPINT::LT:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::LT));
      break;
    case CMPINT::LE:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::LE));
      break;
    case CMPINT::FALSE:
      lane_mask =
          _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::FALSE));
      break;
    case CMPINT::NE:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NE));
      break;
    case CMPINT::NLT:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NLT));
      break;
    case CMPINT::NLE:
      lane_mask = _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NLE));
      break;
    case CMPINT::TRUE:
      lane_mask =
          _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::TRUE));
      break;
  }
  return lane_mask;
}
352
353 // Returns c[i] = a[i] CMP b[i] ? match_value : 0
_mm512_hexl_cmp_epi64(__m512i a,__m512i b,CMPINT cmp,uint64_t match_value)354 inline __m512i _mm512_hexl_cmp_epi64(__m512i a, __m512i b, CMPINT cmp,
355 uint64_t match_value) {
356 __mmask8 mask = _mm512_hexl_cmp_epu64_mask(a, b, cmp);
357 return _mm512_maskz_broadcastq_epi64(
358 mask, _mm_set1_epi64x(static_cast<int64_t>(match_value)));
359 }
360
361 // Returns c[i] = a[i] >= b[i] ? match_value : 0
_mm512_hexl_cmpge_epu64(__m512i a,__m512i b,uint64_t match_value)362 inline __m512i _mm512_hexl_cmpge_epu64(__m512i a, __m512i b,
363 uint64_t match_value) {
364 return _mm512_hexl_cmp_epi64(a, b, CMPINT::NLT, match_value);
365 }
366
367 // Returns c[i] = a[i] < b[i] ? match_value : 0
_mm512_hexl_cmplt_epu64(__m512i a,__m512i b,uint64_t match_value)368 inline __m512i _mm512_hexl_cmplt_epu64(__m512i a, __m512i b,
369 uint64_t match_value) {
370 return _mm512_hexl_cmp_epi64(a, b, CMPINT::LT, match_value);
371 }
372
373 // Returns c[i] = a[i] <= b[i] ? match_value : 0
_mm512_hexl_cmple_epu64(__m512i a,__m512i b,uint64_t match_value)374 inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b,
375 uint64_t match_value) {
376 return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value);
377 }
378
379 // Returns Montgomery form of ab mod q, computed via the REDC algorithm,
380 // also known as Montgomery reduction.
381 // Template: r with R = 2^r
382 // Inputs: q such that gcd(R, q) = 1. R > q.
383 // v_inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
384 // T = ab in the range [0, Rq − 1].
385 // T_hi and T_lo for BitShift = 64 should be given in 63 bits.
386 // Output: Integer S in the range [0, q − 1] such that S ≡ TR^−1 mod q
387 template <int BitShift, int r>
_mm512_hexl_montgomery_reduce(__m512i T_hi,__m512i T_lo,__m512i q,__m512i v_inv_mod,__m512i v_rs_or_msk)388 inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
389 __m512i q, __m512i v_inv_mod,
390 __m512i v_rs_or_msk) {
391 HEXL_CHECK(BitShift == 52 || BitShift == 64,
392 "Invalid bitshift " << BitShift << "; need 52 or 64");
393
394 #ifdef HEXL_HAS_AVX512IFMA
395 if (BitShift == 52) {
396 // Operation:
397 // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
398 __m512i m = ClearTopBits64<r>(T_lo);
399 m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
400 m = ClearTopBits64<r>(m);
401
402 // Operation: t ← (T + mN) / R = (T + m*q) >> r
403 // Hi part
404 __m512i t_hi = _mm512_madd52hi_epu64(T_hi, m, q);
405 // Low part
406 __m512i t = _mm512_madd52lo_epu64(T_lo, m, q);
407 t = _mm512_srli_epi64(t, r);
408 // Join parts
409 t = _mm512_madd52lo_epu64(t, t_hi, v_rs_or_msk);
410
411 // If this function exists for 52 bits we could save 1 cycle
412 // t = _mm512_shrdi_epi64 (t_hi, t, r)
413
414 // Operation: t ≥ q? return (t - q) : return t
415 return _mm512_hexl_small_mod_epu64<2>(t, q);
416 }
417 #endif
418
419 HEXL_CHECK(BitShift == 64, "Invalid bitshift " << BitShift << "; need 64");
420
421 // Operation:
422 // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
423 __m512i m = ClearTopBits64<r>(T_lo);
424 m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
425 m = ClearTopBits64<r>(m);
426
427 __m512i mq_hi = _mm512_hexl_mulhi_epi<BitShift>(m, q);
428 __m512i mq_lo = _mm512_hexl_mullo_epi<BitShift>(m, q);
429
430 // to 63 bits
431 mq_hi = _mm512_slli_epi64(mq_hi, 1);
432 __m512i tmp = _mm512_srli_epi64(mq_lo, 63);
433 mq_hi = _mm512_add_epi64(mq_hi, tmp);
434 mq_lo = _mm512_and_epi64(mq_lo, v_rs_or_msk);
435
436 __m512i t_hi = _mm512_add_epi64(T_hi, mq_hi);
437 t_hi = _mm512_slli_epi64(t_hi, 63 - r);
438 __m512i t = _mm512_add_epi64(T_lo, mq_lo);
439 t = _mm512_srli_epi64(t, r);
440
441 // Join parts
442 t = _mm512_add_epi64(t_hi, t);
443
444 return _mm512_hexl_small_mod_epu64<2>(t, q);
445 }
446
447 // Returns x mod q, computed via Barrett reduction
448 // @param q_barr floor(2^BitShift / q)
449 template <int BitShift = 64, int OutputModFactor = 1>
_mm512_hexl_barrett_reduce64(__m512i x,__m512i q,__m512i q_barr_64,__m512i q_barr_52,uint64_t prod_right_shift,__m512i v_neg_mod)450 inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
451 __m512i q_barr_64,
452 __m512i q_barr_52,
453 uint64_t prod_right_shift,
454 __m512i v_neg_mod) {
455 HEXL_UNUSED(q_barr_52);
456 HEXL_UNUSED(prod_right_shift);
457 HEXL_UNUSED(v_neg_mod);
458 HEXL_CHECK(BitShift == 52 || BitShift == 64,
459 "Invalid bitshift " << BitShift << "; need 52 or 64");
460
461 #ifdef HEXL_HAS_AVX512IFMA
462 if (BitShift == 52) {
463 __m512i two_pow_fiftytwo = _mm512_set1_epi64(2251799813685248);
464 __mmask8 mask =
465 _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT);
466 if (mask != 0) {
467 // values above 2^52
468 __m512i x_hi = _mm512_srli_epi64(x, static_cast<unsigned int>(52ULL));
469 __m512i x_lo = ClearTopBits64<52>(x);
470
471 // c1 = floor(U / 2^{n + beta})
472 __m512i c1_lo =
473 _mm512_srli_epi64(x_lo, static_cast<unsigned int>(prod_right_shift));
474 __m512i c1_hi = _mm512_slli_epi64(
475 x_hi, static_cast<unsigned int>(52ULL - (prod_right_shift)));
476 __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi);
477
478 // alpha - beta == 52, so we only need high 52 bits
479 __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64);
480
481 // Z = prod_lo - (p * q_hat)_lo
482 x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
483 } else {
484 __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52);
485 __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q);
486 x = _mm512_sub_epi64(x, tmp1_times_mod);
487 }
488 }
489 #endif
490 if (BitShift == 64) {
491 __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64);
492 __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q);
493 x = _mm512_sub_epi64(x, tmp1_times_mod);
494 }
495
496 // Correction
497 if (OutputModFactor == 1) {
498 x = _mm512_hexl_small_mod_epu64<2>(x, q);
499 }
500 return x;
501 }
502
503 // Concatenate packed 64-bit integers in x and y, producing an intermediate
504 // 128-bit result. Shift the result right by bit_shift bits, and return the
505 // lower 64 bits. The bit_shift is a run-time argument, rather than a
506 // compile-time template parameter, so we can't use _mm512_shrdi_epi64
_mm512_hexl_shrdi_epi64(__m512i x,__m512i y,unsigned int bit_shift)507 inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y,
508 unsigned int bit_shift) {
509 __m512i c_lo = _mm512_srli_epi64(x, bit_shift);
510 __m512i c_hi = _mm512_slli_epi64(y, 64 - bit_shift);
511 return _mm512_add_epi64(c_lo, c_hi);
512 }
513
514 // Concatenate packed 64-bit integers in x and y, producing an intermediate
515 // 128-bit result. Shift the result right by BitShift bits, and return the lower
516 // 64 bits.
517 template <int BitShift>
_mm512_hexl_shrdi_epi64(__m512i x,__m512i y)518 inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y) {
519 #ifdef HEXL_HAS_AVX512VBMI2
520 return _mm512_shrdi_epi64(x, y, BitShift);
521 #else
522 return _mm512_hexl_shrdi_epi64(x, y, BitShift);
523 #endif
524 }
525
526 #endif // HEXL_HAS_AVX512DQ
527
528 } // namespace hexl
529 } // namespace intel
530