// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <immintrin.h>

#include <cstdint>
#include <vector>

#include "hexl/logging/logging.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "hexl/util/check.hpp"
#include "hexl/util/defines.hpp"
#include "hexl/util/util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512DQ

/// @brief Returns the unsigned 64-bit integer values in x as a vector
inline std::vector<uint64_t> ExtractValues(__m512i x) {
  __m256i x0 = _mm512_extracti64x4_epi64(x, 0);
  __m256i x1 = _mm512_extracti64x4_epi64(x, 1);

  std::vector<uint64_t> xs{static_cast<uint64_t>(_mm256_extract_epi64(x0, 0)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x0, 1)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x0, 2)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x0, 3)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x1, 0)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x1, 1)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x1, 2)),
                           static_cast<uint64_t>(_mm256_extract_epi64(x1, 3))};

  return xs;
}
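
// A minimal usage sketch (illustrative only, not part of the API), assuming
// an AVX512DQ-capable build:
//
//   __m512i v = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);  // highest lane first
//   std::vector<uint64_t> vals = ExtractValues(v);  // vals == {0, 1, ..., 7}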

/// @brief Returns the signed 64-bit integer values in x as a vector
inline std::vector<int64_t> ExtractIntValues(__m512i x) {
  __m256i x0 = _mm512_extracti64x4_epi64(x, 0);
  __m256i x1 = _mm512_extracti64x4_epi64(x, 1);

  std::vector<int64_t> xs{static_cast<int64_t>(_mm256_extract_epi64(x0, 0)),
                          static_cast<int64_t>(_mm256_extract_epi64(x0, 1)),
                          static_cast<int64_t>(_mm256_extract_epi64(x0, 2)),
                          static_cast<int64_t>(_mm256_extract_epi64(x0, 3)),
                          static_cast<int64_t>(_mm256_extract_epi64(x1, 0)),
                          static_cast<int64_t>(_mm256_extract_epi64(x1, 1)),
                          static_cast<int64_t>(_mm256_extract_epi64(x1, 2)),
                          static_cast<int64_t>(_mm256_extract_epi64(x1, 3))};

  return xs;
}

// Returns the 64-bit floating-point values in x as a vector
inline std::vector<double> ExtractValues(__m512d x) {
  std::vector<double> ret(8, 0);
  double* x_data = reinterpret_cast<double*>(&x);
  for (size_t i = 0; i < 8; ++i) {
    ret[i] = x_data[i];
  }
  return ret;
}

// Returns the lower NumBits bits of each 64-bit element of x
template <int NumBits>
inline __m512i ClearTopBits64(__m512i x) {
  const __m512i low_bits_mask = _mm512_set1_epi64((1ULL << NumBits) - 1);
  return _mm512_and_epi64(x, low_bits_mask);
}
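
// For example (an illustrative sketch), ClearTopBits64<52> keeps only the
// low 52 bits of every lane, so all-ones lanes become 0x000FFFFFFFFFFFFF:
//
//   __m512i ones = _mm512_set1_epi64(static_cast<int64_t>(~0ULL));
//   __m512i low52 = ClearTopBits64<52>(ones);  // 0x000FFFFFFFFFFFFF per lane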

// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x
// and y to form a 2*BitShift-bit intermediate result.
// Returns the high BitShift-bit unsigned integer from the intermediate result
template <int BitShift>
inline __m512i _mm512_hexl_mulhi_epi(__m512i x, __m512i y);

// Dummy implementation to avoid template substitution errors
template <>
inline __m512i _mm512_hexl_mulhi_epi<32>(__m512i x, __m512i y) {
  HEXL_CHECK(false, "Unimplemented");
  HEXL_UNUSED(x);
  HEXL_UNUSED(y);
  return x;
}

template <>
inline __m512i _mm512_hexl_mulhi_epi<64>(__m512i x, __m512i y) {
  // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit
  __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff);
  // Swap the high and low 32-bit halves of each 64-bit element, so that
  // _mm512_mul_epu32 (which multiplies the low halves) sees the high halves
  __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1);
  __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1);
  __m512i z_lo_lo = _mm512_mul_epu32(x, y);        // x_lo * y_lo
  __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi);     // x_lo * y_hi
  __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y);     // x_hi * y_lo
  __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi);  // x_hi * y_hi

  //                   x_hi | x_lo
  // x                 y_hi | y_lo
  // ------------------------------
  //                  [x_lo * y_lo]    // z_lo_lo
  // +           [x_lo * y_hi]         // z_lo_hi
  // +           [x_hi * y_lo]         // z_hi_lo
  // +    [x_hi * y_hi]                // z_hi_hi
  //     ^-----------^ <-- only bits needed
  //  sum_|  hi | mid | lo  |

  // Low bits of z_lo_lo are not needed
  __m512i z_lo_lo_shift = _mm512_srli_epi64(z_lo_lo, 32);

  //                   [x_lo  *  y_lo] // z_lo_lo
  //          + [x_lo  *  y_hi]        // z_lo_hi
  //          ------------------------
  //            |    sum_tmp   |
  //            |sum_mid|sum_lo|
  __m512i sum_tmp = _mm512_add_epi64(z_lo_hi, z_lo_lo_shift);
  __m512i sum_lo = _mm512_and_si512(sum_tmp, lo_mask);
  __m512i sum_mid = _mm512_srli_epi64(sum_tmp, 32);
  //            |       |sum_lo|
  //          + [x_hi   *  y_lo]       // z_hi_lo
  //          ------------------
  //            [   sum_mid2   ]
  __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo);
  __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32);
  __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid);
  return _mm512_add_epi64(sum_hi, sum_mid2_hi);
}
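
// A scalar reference for the schoolbook decomposition above; a hedged,
// exposition-only sketch (the helper name is hypothetical, not HEXL API).
// It mirrors the variable names of the vector code lane by lane.
inline uint64_t MulHi64ScalarSketch(uint64_t x, uint64_t y) {
  uint64_t x_lo = x & 0x00000000ffffffffULL;
  uint64_t x_hi = x >> 32;
  uint64_t y_lo = y & 0x00000000ffffffffULL;
  uint64_t y_hi = y >> 32;
  uint64_t z_lo_lo = x_lo * y_lo;
  uint64_t z_lo_hi = x_lo * y_hi;
  uint64_t z_hi_lo = x_hi * y_lo;
  uint64_t z_hi_hi = x_hi * y_hi;
  // Accumulate the middle terms with explicit carry handling, exactly as the
  // intrinsic sequence above does for 8 lanes at a time
  uint64_t sum_tmp = z_lo_hi + (z_lo_lo >> 32);
  uint64_t sum_lo = sum_tmp & 0x00000000ffffffffULL;
  uint64_t sum_mid = sum_tmp >> 32;
  uint64_t sum_mid2 = z_hi_lo + sum_lo;
  return z_hi_hi + sum_mid + (sum_mid2 >> 32);
}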

#ifdef HEXL_HAS_AVX512IFMA
template <>
inline __m512i _mm512_hexl_mulhi_epi<52>(__m512i x, __m512i y) {
  __m512i zero = _mm512_set1_epi64(0);
  return _mm512_madd52hi_epu64(zero, x, y);
}
#endif

// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x
// and y to form a 2*BitShift-bit intermediate result.
// Returns the high BitShift-bit unsigned integer from the intermediate result,
// with approximation error at most 1
template <int BitShift>
inline __m512i _mm512_hexl_mulhi_approx_epi(__m512i x, __m512i y);

// Dummy implementation to avoid template substitution errors
template <>
inline __m512i _mm512_hexl_mulhi_approx_epi<32>(__m512i x, __m512i y) {
  HEXL_CHECK(false, "Unimplemented");
  HEXL_UNUSED(x);
  HEXL_UNUSED(y);
  return x;
}

template <>
inline __m512i _mm512_hexl_mulhi_approx_epi<64>(__m512i x, __m512i y) {
  // https://stackoverflow.com/questions/28807341/simd-signed-with-unsigned-multiplication-for-64-bit-64-bit-to-128-bit
  __m512i lo_mask = _mm512_set1_epi64(0x00000000ffffffff);
  // Swap the high and low 32-bit halves of each 64-bit element, so that
  // _mm512_mul_epu32 (which multiplies the low halves) sees the high halves
  __m512i x_hi = _mm512_shuffle_epi32(x, (_MM_PERM_ENUM)0xB1);
  __m512i y_hi = _mm512_shuffle_epi32(y, (_MM_PERM_ENUM)0xB1);
  __m512i z_lo_hi = _mm512_mul_epu32(x, y_hi);     // x_lo * y_hi
  __m512i z_hi_lo = _mm512_mul_epu32(x_hi, y);     // x_hi * y_lo
  __m512i z_hi_hi = _mm512_mul_epu32(x_hi, y_hi);  // x_hi * y_hi

  //                   x_hi | x_lo
  // x                 y_hi | y_lo
  // ------------------------------
  //                  [x_lo * y_lo]    // unused, resulting in approximation
  // +           [x_lo * y_hi]         // z_lo_hi
  // +           [x_hi * y_lo]         // z_hi_lo
  // +    [x_hi * y_hi]                // z_hi_hi
  //     ^-----------^ <-- only bits needed
  //  sum_|  hi | mid | lo  |

  __m512i sum_lo = _mm512_and_si512(z_lo_hi, lo_mask);
  __m512i sum_mid = _mm512_srli_epi64(z_lo_hi, 32);
  //            |       |sum_lo|
  //          + [x_hi   *  y_lo]       // z_hi_lo
  //          ------------------
  //            [   sum_mid2   ]
  __m512i sum_mid2 = _mm512_add_epi64(z_hi_lo, sum_lo);
  __m512i sum_mid2_hi = _mm512_srli_epi64(sum_mid2, 32);
  __m512i sum_hi = _mm512_add_epi64(z_hi_hi, sum_mid);
  return _mm512_add_epi64(sum_hi, sum_mid2_hi);
}
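
// Why the error is at most 1: the omitted z_lo_lo = x_lo * y_lo term is less
// than 2^64, so dropping it from the 128-bit sum lowers floor(x * y / 2^64)
// by at most one, while saving one multiply, one shift, and one add relative
// to the exact _mm512_hexl_mulhi_epi<64> above.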

#ifdef HEXL_HAS_AVX512IFMA
template <>
inline __m512i _mm512_hexl_mulhi_approx_epi<52>(__m512i x, __m512i y) {
  __m512i zero = _mm512_set1_epi64(0);
  return _mm512_madd52hi_epu64(zero, x, y);
}
#endif

// Multiply packed unsigned BitShift-bit integers in each 64-bit element of x
// and y to form a 2*BitShift-bit intermediate result.
// Returns the low BitShift-bit unsigned integer from the intermediate result
template <int BitShift>
inline __m512i _mm512_hexl_mullo_epi(__m512i x, __m512i y);

// Dummy implementation to avoid template substitution errors
template <>
inline __m512i _mm512_hexl_mullo_epi<32>(__m512i x, __m512i y) {
  HEXL_CHECK(false, "Unimplemented");
  HEXL_UNUSED(x);
  HEXL_UNUSED(y);
  return x;
}

template <>
inline __m512i _mm512_hexl_mullo_epi<64>(__m512i x, __m512i y) {
  return _mm512_mullo_epi64(x, y);
}

#ifdef HEXL_HAS_AVX512IFMA
template <>
inline __m512i _mm512_hexl_mullo_epi<52>(__m512i x, __m512i y) {
  __m512i zero = _mm512_set1_epi64(0);
  return _mm512_madd52lo_epu64(zero, x, y);
}
#endif

// Multiply packed unsigned BitShift-bit integers in each 64-bit element of y
// and z to form a 2*BitShift-bit intermediate result. The low BitShift bits of
// the result are added to x, and the low BitShift bits of that sum are
// returned.
template <int BitShift>
inline __m512i _mm512_hexl_mullo_add_lo_epi(__m512i x, __m512i y, __m512i z);

#ifdef HEXL_HAS_AVX512IFMA
template <>
inline __m512i _mm512_hexl_mullo_add_lo_epi<52>(__m512i x, __m512i y,
                                                __m512i z) {
  __m512i result = _mm512_madd52lo_epu64(x, y, z);

  // Clear the high 12 bits of the result
  result = ClearTopBits64<52>(result);
  return result;
}
#endif

// Dummy implementation to avoid template substitution errors
template <>
inline __m512i _mm512_hexl_mullo_add_lo_epi<32>(__m512i x, __m512i y,
                                                __m512i z) {
  HEXL_CHECK(false, "Unimplemented");
  HEXL_UNUSED(x);
  HEXL_UNUSED(y);
  HEXL_UNUSED(z);
  return x;
}

template <>
inline __m512i _mm512_hexl_mullo_add_lo_epi<64>(__m512i x, __m512i y,
                                                __m512i z) {
  __m512i prod = _mm512_mullo_epi64(y, z);
  return _mm512_add_epi64(x, prod);
}

// Returns x mod q in each 64-bit SIMD lane
// Assumes x < InputModFactor * q in all lanes
template <int InputModFactor = 2>
inline __m512i _mm512_hexl_small_mod_epu64(__m512i x, __m512i q,
                                           __m512i* q_times_2 = nullptr,
                                           __m512i* q_times_4 = nullptr) {
  HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
                 InputModFactor == 4 || InputModFactor == 8,
             "InputModFactor must be 1, 2, 4, or 8");
  if (InputModFactor == 1) {
    return x;
  }
  if (InputModFactor == 2) {
    return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
  }
  if (InputModFactor == 4) {
    HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr");
    x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2));
    return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
  }
  if (InputModFactor == 8) {
    HEXL_CHECK(q_times_2 != nullptr, "q_times_2 must not be nullptr");
    HEXL_CHECK(q_times_4 != nullptr, "q_times_4 must not be nullptr");
    x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_4));
    x = _mm512_min_epu64(x, _mm512_sub_epi64(x, *q_times_2));
    return _mm512_min_epu64(x, _mm512_sub_epi64(x, q));
  }
  HEXL_CHECK(false, "Invalid InputModFactor");
  return x;  // Return dummy value
}
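
// A scalar analogue of the conditional-subtraction "min" trick above, shown
// for InputModFactor == 2 (hypothetical helper, exposition only). When x < q,
// x - q wraps around to a value larger than x, so the unsigned min selects
// the operand that is already reduced.
inline uint64_t SmallModScalarSketch(uint64_t x, uint64_t q) {
  // Assumes x < 2 * q
  uint64_t x_minus_q = x - q;              // wraps to a large value when x < q
  return (x_minus_q < x) ? x_minus_q : x;  // i.e., std::min(x, x - q)
}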

// Returns (x + y) mod q; assumes 0 < x, y < q
inline __m512i _mm512_hexl_small_add_mod_epi64(__m512i x, __m512i y,
                                               __m512i q) {
  HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0],
                    "x exceeds bound " << ExtractValues(q)[0]);
  HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0],
                    "y exceeds bound " << ExtractValues(q)[0]);
  return _mm512_hexl_small_mod_epu64(_mm512_add_epi64(x, y), q);

  // Alternate implementation:
  // x += y - q;
  // if (x < 0) x += q;
  // return x;
  // __m512i v_diff = _mm512_sub_epi64(y, q);
  // x = _mm512_add_epi64(x, v_diff);
  // __mmask8 sign_bits = _mm512_movepi64_mask(x);
  // return _mm512_mask_add_epi64(x, sign_bits, x, q);
}

// Returns (x - y) mod q; assumes 0 < x, y < q
inline __m512i _mm512_hexl_small_sub_mod_epi64(__m512i x, __m512i y,
                                               __m512i q) {
  HEXL_CHECK_BOUNDS(ExtractValues(x).data(), 8, ExtractValues(q)[0],
                    "x exceeds bound " << ExtractValues(q)[0]);
  HEXL_CHECK_BOUNDS(ExtractValues(y).data(), 8, ExtractValues(q)[0],
                    "y exceeds bound " << ExtractValues(q)[0]);

  // diff = x - y;
  // return (diff < 0) ? (diff + q) : diff
  __m512i v_diff = _mm512_sub_epi64(x, y);
  __mmask8 sign_bits = _mm512_movepi64_mask(v_diff);
  return _mm512_mask_add_epi64(v_diff, sign_bits, v_diff, q);
}
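
// A scalar analogue of the sign-bit trick above (hypothetical helper; assumes
// q < 2^63 so that x - y fits in a signed 64-bit integer):
inline uint64_t SmallSubModScalarSketch(uint64_t x, uint64_t y, uint64_t q) {
  int64_t diff = static_cast<int64_t>(x - y);  // negative exactly when x < y
  return static_cast<uint64_t>(
      (diff < 0) ? diff + static_cast<int64_t>(q) : diff);
}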

// Returns the 8-bit mask of the lane-wise comparisons a[i] CMP b[i]
inline __mmask8 _mm512_hexl_cmp_epu64_mask(__m512i a, __m512i b, CMPINT cmp) {
  switch (cmp) {
    case CMPINT::EQ:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::EQ));
    case CMPINT::LT:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::LT));
    case CMPINT::LE:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::LE));
    case CMPINT::FALSE:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::FALSE));
    case CMPINT::NE:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NE));
    case CMPINT::NLT:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NLT));
    case CMPINT::NLE:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::NLE));
    case CMPINT::TRUE:
      return _mm512_cmp_epu64_mask(a, b, static_cast<int>(CMPINT::TRUE));
  }
  __mmask8 dummy = 0;  // Avoid end of non-void function warning
  return dummy;
}

// Returns c[i] = a[i] CMP b[i] ? match_value : 0
inline __m512i _mm512_hexl_cmp_epi64(__m512i a, __m512i b, CMPINT cmp,
                                     uint64_t match_value) {
  __mmask8 mask = _mm512_hexl_cmp_epu64_mask(a, b, cmp);
  return _mm512_maskz_broadcastq_epi64(
      mask, _mm_set1_epi64x(static_cast<int64_t>(match_value)));
}

// Returns c[i] = a[i] >= b[i] ? match_value : 0
inline __m512i _mm512_hexl_cmpge_epu64(__m512i a, __m512i b,
                                       uint64_t match_value) {
  return _mm512_hexl_cmp_epi64(a, b, CMPINT::NLT, match_value);
}

// Returns c[i] = a[i] < b[i] ? match_value : 0
inline __m512i _mm512_hexl_cmplt_epu64(__m512i a, __m512i b,
                                       uint64_t match_value) {
  return _mm512_hexl_cmp_epi64(a, b, CMPINT::LT, match_value);
}

// Returns c[i] = a[i] <= b[i] ? match_value : 0
inline __m512i _mm512_hexl_cmple_epu64(__m512i a, __m512i b,
                                       uint64_t match_value) {
  return _mm512_hexl_cmp_epi64(a, b, CMPINT::LE, match_value);
}

// Returns the Montgomery form of ab mod q, computed via the REDC algorithm,
// also known as Montgomery reduction.
// Template: r with R = 2^r
// Inputs: q such that gcd(R, q) = 1, R > q.
//         v_inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
//         T = ab in the range [0, Rq − 1].
// T_hi and T_lo for BitShift = 64 should be given in 63 bits.
// Output: Integer S in the range [0, q − 1] such that S ≡ TR^−1 mod q
template <int BitShift, int r>
inline __m512i _mm512_hexl_montgomery_reduce(__m512i T_hi, __m512i T_lo,
                                             __m512i q, __m512i v_inv_mod,
                                             __m512i v_rs_or_msk) {
  HEXL_CHECK(BitShift == 52 || BitShift == 64,
             "Invalid bitshift " << BitShift << "; need 52 or 64");

#ifdef HEXL_HAS_AVX512IFMA
  if (BitShift == 52) {
    // Operation:
    // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
    __m512i m = ClearTopBits64<r>(T_lo);
    m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
    m = ClearTopBits64<r>(m);

    // Operation: t ← (T + mN) / R = (T + m*q) >> r
    // High part
    __m512i t_hi = _mm512_madd52hi_epu64(T_hi, m, q);
    // Low part
    __m512i t = _mm512_madd52lo_epu64(T_lo, m, q);
    t = _mm512_srli_epi64(t, r);
    // Join parts
    t = _mm512_madd52lo_epu64(t, t_hi, v_rs_or_msk);

    // If _mm512_shrdi_epi64 existed for 52-bit lanes, we could save 1 cycle:
    // t = _mm512_shrdi_epi64(t_hi, t, r)

    // Operation: t ≥ q ? return (t - q) : return t
    return _mm512_hexl_small_mod_epu64<2>(t, q);
  }
#endif

  HEXL_CHECK(BitShift == 64, "Invalid bitshift " << BitShift << "; need 64");

  // Operation:
  // m ← ((T mod R)N′) mod R | m ← ((T & mod_R_mask)*v_inv_mod) & mod_R_mask
  __m512i m = ClearTopBits64<r>(T_lo);
  m = _mm512_hexl_mullo_epi<BitShift>(m, v_inv_mod);
  m = ClearTopBits64<r>(m);

  __m512i mq_hi = _mm512_hexl_mulhi_epi<BitShift>(m, q);
  __m512i mq_lo = _mm512_hexl_mullo_epi<BitShift>(m, q);

  // Renormalize the 128-bit product m*q into 63-bit limbs to match T_hi, T_lo
  mq_hi = _mm512_slli_epi64(mq_hi, 1);
  __m512i tmp = _mm512_srli_epi64(mq_lo, 63);
  mq_hi = _mm512_add_epi64(mq_hi, tmp);
  mq_lo = _mm512_and_epi64(mq_lo, v_rs_or_msk);

  __m512i t_hi = _mm512_add_epi64(T_hi, mq_hi);
  t_hi = _mm512_slli_epi64(t_hi, 63 - r);
  __m512i t = _mm512_add_epi64(T_lo, mq_lo);
  t = _mm512_srli_epi64(t, r);

  // Join parts
  t = _mm512_add_epi64(t_hi, t);

  return _mm512_hexl_small_mod_epu64<2>(t, q);
}
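
// A scalar REDC reference for the BitShift == 64 path above, in the textbook
// form with r == 64 (hypothetical helper; assumes compiler support for
// unsigned __int128, which the vector code sidesteps via 63-bit limbs):
inline uint64_t MontgomeryReduceScalarSketch(uint64_t T_hi, uint64_t T_lo,
                                             uint64_t q, uint64_t v_inv_mod) {
  // m = ((T mod R) * q') mod R, with R = 2^64 and q * q' == -1 mod R
  uint64_t m = T_lo * v_inv_mod;
  unsigned __int128 mq = static_cast<unsigned __int128>(m) * q;
  // t = (T + m * q) / R. The low word of T + m * q is 0 mod R by
  // construction, so only its carry into the high word survives.
  uint64_t sum_lo = T_lo + static_cast<uint64_t>(mq);
  uint64_t carry = (sum_lo < T_lo) ? 1 : 0;
  uint64_t t = T_hi + static_cast<uint64_t>(mq >> 64) + carry;
  // T < R * q and m * q < R * q imply t < 2 * q
  return (t >= q) ? t - q : t;
}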

// Returns x mod q, computed via Barrett reduction
// @param q_barr floor(2^BitShift / q)
template <int BitShift = 64, int OutputModFactor = 1>
inline __m512i _mm512_hexl_barrett_reduce64(__m512i x, __m512i q,
                                            __m512i q_barr_64,
                                            __m512i q_barr_52,
                                            uint64_t prod_right_shift,
                                            __m512i v_neg_mod) {
  HEXL_UNUSED(q_barr_52);
  HEXL_UNUSED(prod_right_shift);
  HEXL_UNUSED(v_neg_mod);
  HEXL_CHECK(BitShift == 52 || BitShift == 64,
             "Invalid bitshift " << BitShift << "; need 52 or 64");

#ifdef HEXL_HAS_AVX512IFMA
  if (BitShift == 52) {
    __m512i two_pow_fiftytwo = _mm512_set1_epi64(1ULL << 52);
    __mmask8 mask =
        _mm512_hexl_cmp_epu64_mask(x, two_pow_fiftytwo, CMPINT::NLT);
    if (mask != 0) {
      // Values at least 2^52
      __m512i x_hi = _mm512_srli_epi64(x, static_cast<unsigned int>(52ULL));
      __m512i x_lo = ClearTopBits64<52>(x);

      // c1 = floor(U / 2^{n + beta})
      __m512i c1_lo =
          _mm512_srli_epi64(x_lo, static_cast<unsigned int>(prod_right_shift));
      __m512i c1_hi = _mm512_slli_epi64(
          x_hi, static_cast<unsigned int>(52ULL - prod_right_shift));
      __m512i c1 = _mm512_or_epi64(c1_lo, c1_hi);

      // alpha - beta == 52, so we only need the high 52 bits
      __m512i q_hat = _mm512_hexl_mulhi_epi<52>(c1, q_barr_64);

      // Z = prod_lo - (p * q_hat)_lo
      x = _mm512_hexl_mullo_add_lo_epi<52>(x_lo, q_hat, v_neg_mod);
    } else {
      __m512i rnd1_hi = _mm512_hexl_mulhi_epi<52>(x, q_barr_52);
      __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<52>(rnd1_hi, q);
      x = _mm512_sub_epi64(x, tmp1_times_mod);
    }
  }
#endif
  if (BitShift == 64) {
    __m512i rnd1_hi = _mm512_hexl_mulhi_epi<64>(x, q_barr_64);
    __m512i tmp1_times_mod = _mm512_hexl_mullo_epi<64>(rnd1_hi, q);
    x = _mm512_sub_epi64(x, tmp1_times_mod);
  }

  // Correction
  if (OutputModFactor == 1) {
    x = _mm512_hexl_small_mod_epu64<2>(x, q);
  }
  return x;
}
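
// A scalar analogue of the BitShift == 64 path above (hypothetical helper;
// q_barr = floor(2^64 / q), with unsigned __int128 standing in for
// _mm512_hexl_mulhi_epi<64>):
inline uint64_t BarrettReduceScalarSketch(uint64_t x, uint64_t q,
                                          uint64_t q_barr) {
  // q_hat underestimates floor(x / q) by at most 1, so x - q_hat * q < 2 * q
  // and one conditional subtraction completes the reduction
  uint64_t q_hat = static_cast<uint64_t>(
      (static_cast<unsigned __int128>(x) * q_barr) >> 64);
  uint64_t z = x - q_hat * q;
  return (z >= q) ? z - q : z;
}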

// Concatenate packed 64-bit integers in x and y, producing an intermediate
// 128-bit result per lane. Shift the result right by bit_shift bits, and
// return the lower 64 bits. Here bit_shift is a run-time argument rather than
// a compile-time template parameter, so we can't use _mm512_shrdi_epi64.
inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y,
                                       unsigned int bit_shift) {
  __m512i c_lo = _mm512_srli_epi64(x, bit_shift);
  __m512i c_hi = _mm512_slli_epi64(y, 64 - bit_shift);
  // c_lo and c_hi have disjoint bits, so the add acts as a bitwise OR
  return _mm512_add_epi64(c_lo, c_hi);
}
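
// For example, with x = 0xFFFFFFFFFFFFFFFF, y = 0x1, and bit_shift = 4 in
// every lane, the per-lane concatenation is 0x1FFFFFFFFFFFFFFFF and each
// output lane is 0x1FFFFFFFFFFFFFFF.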

// Concatenate packed 64-bit integers in x and y, producing an intermediate
// 128-bit result per lane. Shift the result right by BitShift bits, and
// return the lower 64 bits.
template <int BitShift>
inline __m512i _mm512_hexl_shrdi_epi64(__m512i x, __m512i y) {
#ifdef HEXL_HAS_AVX512VBMI2
  return _mm512_shrdi_epi64(x, y, BitShift);
#else
  return _mm512_hexl_shrdi_epi64(x, y, BitShift);
#endif
}

#endif  // HEXL_HAS_AVX512DQ

}  // namespace hexl
}  // namespace intel