1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #include <functional>
7 #include <vector>
8 
9 #include "eltwise/eltwise-reduce-mod-avx512.hpp"
10 #include "eltwise/eltwise-reduce-mod-internal.hpp"
11 #include "hexl/eltwise/eltwise-reduce-mod.hpp"
12 #include "hexl/logging/logging.hpp"
13 #include "hexl/number-theory/number-theory.hpp"
14 #include "hexl/util/check.hpp"
15 #include "util/avx512-util.hpp"
16 
17 namespace intel {
18 namespace hexl {
19 
20 #ifdef HEXL_HAS_AVX512DQ
21 template <int BitShift = 64>
EltwiseReduceModAVX512(uint64_t * result,const uint64_t * operand,uint64_t n,uint64_t modulus,uint64_t input_mod_factor,uint64_t output_mod_factor)22 void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand,
23                             uint64_t n, uint64_t modulus,
24                             uint64_t input_mod_factor,
25                             uint64_t output_mod_factor) {
26   HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr");
27   HEXL_CHECK(n != 0, "Require n != 0");
28   HEXL_CHECK(modulus > 1, "Require modulus > 1");
29   HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 ||
30                  input_mod_factor == 4,
31              "input_mod_factor must be modulus or 2 or 4" << input_mod_factor);
32   HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
33              "output_mod_factor must be 1 or 2 " << output_mod_factor);
34   HEXL_CHECK(input_mod_factor != output_mod_factor,
35              "input_mod_factor must not be equal to output_mod_factor ");
36 
37   uint64_t n_tmp = n;
38 
39   // Multi-word Barrett reduction precomputation
40   constexpr int64_t alpha = BitShift - 2;
41   constexpr int64_t beta = -2;
42   const uint64_t ceil_log_mod = Log2(modulus) + 1;  // "n" from Algorithm 2
43   uint64_t prod_right_shift = ceil_log_mod + beta;
44   __m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
45 
46   uint64_t barrett_factor =
47       MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
48                      modulus)
49           .BarrettFactor();
50 
51   uint64_t barrett_factor_52 = MultiplyFactor(1, 52, modulus).BarrettFactor();
52 
53   if (BitShift == 64) {
54     // Single-worded Barrett reduction.
55     barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor();
56   }
57 
58   __m512i v_bf = _mm512_set1_epi64(static_cast<int64_t>(barrett_factor));
59   __m512i v_bf_52 = _mm512_set1_epi64(static_cast<int64_t>(barrett_factor_52));
60 
61   // Deals with n not divisible by 8
62   uint64_t n_mod_8 = n_tmp % 8;
63   if (n_mod_8 != 0) {
64     EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor,
65                            output_mod_factor);
66     operand += n_mod_8;
67     result += n_mod_8;
68     n_tmp -= n_mod_8;
69   }
70 
71   uint64_t twice_mod = modulus << 1;
72   const __m512i* v_operand = reinterpret_cast<const __m512i*>(operand);
73   __m512i* v_result = reinterpret_cast<__m512i*>(result);
74   __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
75   __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));
76 
77   if (input_mod_factor == modulus) {
78     if (output_mod_factor == 2) {
79       for (size_t i = 0; i < n_tmp; i += 8) {
80         __m512i v_op = _mm512_loadu_si512(v_operand);
81         v_op = _mm512_hexl_barrett_reduce64<BitShift, 2>(
82             v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod);
83         HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
84                           "v_op exceeds bound " << modulus);
85         _mm512_storeu_si512(v_result, v_op);
86         ++v_operand;
87         ++v_result;
88       }
89     } else {
90       for (size_t i = 0; i < n_tmp; i += 8) {
91         __m512i v_op = _mm512_loadu_si512(v_operand);
92         v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
93             v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod);
94         HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
95                           "v_op exceeds bound " << modulus);
96         _mm512_storeu_si512(v_result, v_op);
97         ++v_operand;
98         ++v_result;
99       }
100     }
101   }
102 
103   if (input_mod_factor == 2) {
104     for (size_t i = 0; i < n_tmp; i += 8) {
105       __m512i v_op = _mm512_loadu_si512(v_operand);
106       v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus);
107       HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
108                         "v_op exceeds bound " << modulus);
109       _mm512_storeu_si512(v_result, v_op);
110       ++v_operand;
111       ++v_result;
112     }
113   }
114 
115   if (input_mod_factor == 4) {
116     if (output_mod_factor == 1) {
117       for (size_t i = 0; i < n_tmp; i += 8) {
118         __m512i v_op = _mm512_loadu_si512(v_operand);
119         v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod);
120         v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus);
121         HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
122                           "v_op exceeds bound " << modulus);
123         _mm512_storeu_si512(v_result, v_op);
124         ++v_operand;
125         ++v_result;
126       }
127     }
128     if (output_mod_factor == 2) {
129       for (size_t i = 0; i < n_tmp; i += 8) {
130         __m512i v_op = _mm512_loadu_si512(v_operand);
131         v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod);
132         HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod,
133                           "v_op exceeds bound " << twice_mod);
134         _mm512_storeu_si512(v_result, v_op);
135         ++v_operand;
136         ++v_result;
137       }
138     }
139   }
140 }
141 
142 /// @brief Returns Montgomery form of modular product ab mod q, computed via the
143 ///  REDC algorithm, also known as Montgomery reduction.
144 /// @tparam BitShift denotes the operational length, in bits, of the operands
145 /// and result values.
146 /// @tparam r defines the value of R, being R = 2^r. R > modulus.
147 /// @param[in] a input vector. T = ab in the range [0, Rq − 1].
148 /// @param[in] b input vector.
149 /// @param[in] modulus such that gcd(R, modulus) = 1.
150 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
151 /// @param[in] n number of elements in input vector.
152 /// @param[out] result unsigned long int vector in the range [0, q − 1] such
153 /// that S ≡ TR^−1 mod q
154 template <int BitShift, int r>
EltwiseMontReduceModAVX512(uint64_t * result,const uint64_t * a,const uint64_t * b,uint64_t n,uint64_t modulus,uint64_t inv_mod)155 void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
156                                 const uint64_t* b, uint64_t n, uint64_t modulus,
157                                 uint64_t inv_mod) {
158   HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
159   HEXL_CHECK(b != nullptr, "Require operand b != nullptr");
160   HEXL_CHECK(n != 0, "Require n != 0");
161   HEXL_CHECK(modulus > 1, "Require modulus > 1");
162 
163   uint64_t R = (1ULL << r);
164   HEXL_CHECK(std::__gcd(static_cast<int64_t>(modulus), static_cast<int64_t>(R)),
165              1);
166   HEXL_CHECK(R > modulus, "Needs R bigger than q.");
167 
168   // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
169   uint64_t mod_R_mask = R - 1;
170   uint64_t prod_rs;
171   if (BitShift == 64) {
172     prod_rs = (1ULL << 63) - 1;
173   } else {
174     prod_rs = (1ULL << (52 - r));
175   }
176   uint64_t n_tmp = n;
177 
178   // Deals with n not divisible by 8
179   uint64_t n_mod_8 = n_tmp % 8;
180   if (n_mod_8 != 0) {
181     for (size_t i = 0; i < n_mod_8; ++i) {
182       uint64_t T_hi;
183       uint64_t T_lo;
184       MultiplyUInt64(a[i], b[i], &T_hi, &T_lo);
185       result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
186                                              inv_mod);
187     }
188     a += n_mod_8;
189     b += n_mod_8;
190     result += n_mod_8;
191     n_tmp -= n_mod_8;
192   }
193 
194   const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
195   const __m512i* v_b = reinterpret_cast<const __m512i*>(b);
196   __m512i* v_result = reinterpret_cast<__m512i*>(result);
197   __m512i v_modulus = _mm512_set1_epi64(modulus);
198   __m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
199   __m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
200 
201   for (size_t i = 0; i < n_tmp; i += 8) {
202     __m512i v_a_op = _mm512_loadu_si512(v_a);
203     __m512i v_b_op = _mm512_loadu_si512(v_b);
204     __m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b_op);
205     __m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b_op);
206 
207     if (BitShift == 64) {
208       v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
209       __m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
210       v_T_hi = _mm512_add_epi64(v_T_hi, tmp);
211       v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
212     }
213 
214     __m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
215         v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
216     HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
217                       "v_op exceeds bound " << modulus);
218     _mm512_storeu_si512(v_result, v_c);
219     ++v_a;
220     ++v_b;
221     ++v_result;
222   }
223 }
224 
225 /// @brief Returns Montgomery form of a mod q, computed via the REDC algorithm,
226 /// also known as Montgomery reduction.
227 /// @tparam BitShift denotes the operational length, in bits, of the operands
228 /// and result values.
229 /// @tparam r defines the value of R, being R = 2^r. R > modulus.
230 /// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1].
231 /// @param[in] R2_mod_q R^2 mod q.
232 /// @param[in] modulus such that gcd(R, modulus) = 1.
233 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
234 /// @param[in] n number of elements in input vector.
235 /// @param[out] result unsigned long int vector in the range [0, q − 1] such
236 /// that S ≡ TR^−1 mod q
237 template <int BitShift, int r>
EltwiseMontgomeryFormAVX512(uint64_t * result,const uint64_t * a,uint64_t R2_mod_q,uint64_t n,uint64_t modulus,uint64_t inv_mod)238 void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
239                                  uint64_t R2_mod_q, uint64_t n,
240                                  uint64_t modulus, uint64_t inv_mod) {
241   HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
242   HEXL_CHECK(n != 0, "Require n != 0");
243   HEXL_CHECK(modulus > 1, "Require modulus > 1");
244 
245   uint64_t R = (1ULL << r);
246   HEXL_CHECK(std::__gcd(static_cast<int64_t>(modulus), static_cast<int64_t>(R)),
247              1);
248   HEXL_CHECK(R > modulus, "Needs R bigger than q.");
249 
250   // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
251   uint64_t mod_R_mask = R - 1;
252   uint64_t prod_rs;
253   if (BitShift == 64) {
254     prod_rs = (1ULL << 63) - 1;
255   } else {
256     prod_rs = (1ULL << (52 - r));
257   }
258   uint64_t n_tmp = n;
259 
260   // Deals with n not divisible by 8
261   uint64_t n_mod_8 = n_tmp % 8;
262   if (n_mod_8 != 0) {
263     for (size_t i = 0; i < n_mod_8; ++i) {
264       uint64_t T_hi;
265       uint64_t T_lo;
266       MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo);
267       result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
268                                              inv_mod);
269     }
270     a += n_mod_8;
271     result += n_mod_8;
272     n_tmp -= n_mod_8;
273   }
274 
275   const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
276   __m512i* v_result = reinterpret_cast<__m512i*>(result);
277   __m512i v_b = _mm512_set1_epi64(R2_mod_q);
278   __m512i v_modulus = _mm512_set1_epi64(modulus);
279   __m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
280   __m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
281 
282   for (size_t i = 0; i < n_tmp; i += 8) {
283     __m512i v_a_op = _mm512_loadu_si512(v_a);
284     __m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b);
285     __m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b);
286 
287     if (BitShift == 64) {
288       v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
289       __m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
290       v_T_hi = _mm512_add_epi64(v_T_hi, tmp);
291       v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
292     }
293 
294     __m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
295         v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
296     HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
297                       "v_op exceeds bound " << modulus);
298     _mm512_storeu_si512(v_result, v_c);
299     ++v_a;
300     ++v_result;
301   }
302 }
303 
304 #endif
305 
306 }  // namespace hexl
307 }  // namespace intel
308