1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #pragma once
5
6 #include <functional>
7 #include <vector>
8
9 #include "eltwise/eltwise-reduce-mod-avx512.hpp"
10 #include "eltwise/eltwise-reduce-mod-internal.hpp"
11 #include "hexl/eltwise/eltwise-reduce-mod.hpp"
12 #include "hexl/logging/logging.hpp"
13 #include "hexl/number-theory/number-theory.hpp"
14 #include "hexl/util/check.hpp"
15 #include "util/avx512-util.hpp"
16
17 namespace intel {
18 namespace hexl {
19
20 #ifdef HEXL_HAS_AVX512DQ
21 template <int BitShift = 64>
EltwiseReduceModAVX512(uint64_t * result,const uint64_t * operand,uint64_t n,uint64_t modulus,uint64_t input_mod_factor,uint64_t output_mod_factor)22 void EltwiseReduceModAVX512(uint64_t* result, const uint64_t* operand,
23 uint64_t n, uint64_t modulus,
24 uint64_t input_mod_factor,
25 uint64_t output_mod_factor) {
26 HEXL_CHECK(operand != nullptr, "Require operand1 != nullptr");
27 HEXL_CHECK(n != 0, "Require n != 0");
28 HEXL_CHECK(modulus > 1, "Require modulus > 1");
29 HEXL_CHECK(input_mod_factor == modulus || input_mod_factor == 2 ||
30 input_mod_factor == 4,
31 "input_mod_factor must be modulus or 2 or 4" << input_mod_factor);
32 HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
33 "output_mod_factor must be 1 or 2 " << output_mod_factor);
34 HEXL_CHECK(input_mod_factor != output_mod_factor,
35 "input_mod_factor must not be equal to output_mod_factor ");
36
37 uint64_t n_tmp = n;
38
39 // Multi-word Barrett reduction precomputation
40 constexpr int64_t alpha = BitShift - 2;
41 constexpr int64_t beta = -2;
42 const uint64_t ceil_log_mod = Log2(modulus) + 1; // "n" from Algorithm 2
43 uint64_t prod_right_shift = ceil_log_mod + beta;
44 __m512i v_neg_mod = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
45
46 uint64_t barrett_factor =
47 MultiplyFactor(uint64_t(1) << (ceil_log_mod + alpha - BitShift), BitShift,
48 modulus)
49 .BarrettFactor();
50
51 uint64_t barrett_factor_52 = MultiplyFactor(1, 52, modulus).BarrettFactor();
52
53 if (BitShift == 64) {
54 // Single-worded Barrett reduction.
55 barrett_factor = MultiplyFactor(1, 64, modulus).BarrettFactor();
56 }
57
58 __m512i v_bf = _mm512_set1_epi64(static_cast<int64_t>(barrett_factor));
59 __m512i v_bf_52 = _mm512_set1_epi64(static_cast<int64_t>(barrett_factor_52));
60
61 // Deals with n not divisible by 8
62 uint64_t n_mod_8 = n_tmp % 8;
63 if (n_mod_8 != 0) {
64 EltwiseReduceModNative(result, operand, n_mod_8, modulus, input_mod_factor,
65 output_mod_factor);
66 operand += n_mod_8;
67 result += n_mod_8;
68 n_tmp -= n_mod_8;
69 }
70
71 uint64_t twice_mod = modulus << 1;
72 const __m512i* v_operand = reinterpret_cast<const __m512i*>(operand);
73 __m512i* v_result = reinterpret_cast<__m512i*>(result);
74 __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
75 __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));
76
77 if (input_mod_factor == modulus) {
78 if (output_mod_factor == 2) {
79 for (size_t i = 0; i < n_tmp; i += 8) {
80 __m512i v_op = _mm512_loadu_si512(v_operand);
81 v_op = _mm512_hexl_barrett_reduce64<BitShift, 2>(
82 v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod);
83 HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
84 "v_op exceeds bound " << modulus);
85 _mm512_storeu_si512(v_result, v_op);
86 ++v_operand;
87 ++v_result;
88 }
89 } else {
90 for (size_t i = 0; i < n_tmp; i += 8) {
91 __m512i v_op = _mm512_loadu_si512(v_operand);
92 v_op = _mm512_hexl_barrett_reduce64<BitShift, 1>(
93 v_op, v_modulus, v_bf, v_bf_52, prod_right_shift, v_neg_mod);
94 HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
95 "v_op exceeds bound " << modulus);
96 _mm512_storeu_si512(v_result, v_op);
97 ++v_operand;
98 ++v_result;
99 }
100 }
101 }
102
103 if (input_mod_factor == 2) {
104 for (size_t i = 0; i < n_tmp; i += 8) {
105 __m512i v_op = _mm512_loadu_si512(v_operand);
106 v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus);
107 HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
108 "v_op exceeds bound " << modulus);
109 _mm512_storeu_si512(v_result, v_op);
110 ++v_operand;
111 ++v_result;
112 }
113 }
114
115 if (input_mod_factor == 4) {
116 if (output_mod_factor == 1) {
117 for (size_t i = 0; i < n_tmp; i += 8) {
118 __m512i v_op = _mm512_loadu_si512(v_operand);
119 v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod);
120 v_op = _mm512_hexl_small_mod_epu64(v_op, v_modulus);
121 HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, modulus,
122 "v_op exceeds bound " << modulus);
123 _mm512_storeu_si512(v_result, v_op);
124 ++v_operand;
125 ++v_result;
126 }
127 }
128 if (output_mod_factor == 2) {
129 for (size_t i = 0; i < n_tmp; i += 8) {
130 __m512i v_op = _mm512_loadu_si512(v_operand);
131 v_op = _mm512_hexl_small_mod_epu64(v_op, v_twice_mod);
132 HEXL_CHECK_BOUNDS(ExtractValues(v_op).data(), 8, twice_mod,
133 "v_op exceeds bound " << twice_mod);
134 _mm512_storeu_si512(v_result, v_op);
135 ++v_operand;
136 ++v_result;
137 }
138 }
139 }
140 }
141
142 /// @brief Returns Montgomery form of modular product ab mod q, computed via the
143 /// REDC algorithm, also known as Montgomery reduction.
144 /// @tparam BitShift denotes the operational length, in bits, of the operands
145 /// and result values.
146 /// @tparam r defines the value of R, being R = 2^r. R > modulus.
147 /// @param[in] a input vector. T = ab in the range [0, Rq − 1].
148 /// @param[in] b input vector.
149 /// @param[in] modulus such that gcd(R, modulus) = 1.
150 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
151 /// @param[in] n number of elements in input vector.
152 /// @param[out] result unsigned long int vector in the range [0, q − 1] such
153 /// that S ≡ TR^−1 mod q
154 template <int BitShift, int r>
EltwiseMontReduceModAVX512(uint64_t * result,const uint64_t * a,const uint64_t * b,uint64_t n,uint64_t modulus,uint64_t inv_mod)155 void EltwiseMontReduceModAVX512(uint64_t* result, const uint64_t* a,
156 const uint64_t* b, uint64_t n, uint64_t modulus,
157 uint64_t inv_mod) {
158 HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
159 HEXL_CHECK(b != nullptr, "Require operand b != nullptr");
160 HEXL_CHECK(n != 0, "Require n != 0");
161 HEXL_CHECK(modulus > 1, "Require modulus > 1");
162
163 uint64_t R = (1ULL << r);
164 HEXL_CHECK(std::__gcd(static_cast<int64_t>(modulus), static_cast<int64_t>(R)),
165 1);
166 HEXL_CHECK(R > modulus, "Needs R bigger than q.");
167
168 // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
169 uint64_t mod_R_mask = R - 1;
170 uint64_t prod_rs;
171 if (BitShift == 64) {
172 prod_rs = (1ULL << 63) - 1;
173 } else {
174 prod_rs = (1ULL << (52 - r));
175 }
176 uint64_t n_tmp = n;
177
178 // Deals with n not divisible by 8
179 uint64_t n_mod_8 = n_tmp % 8;
180 if (n_mod_8 != 0) {
181 for (size_t i = 0; i < n_mod_8; ++i) {
182 uint64_t T_hi;
183 uint64_t T_lo;
184 MultiplyUInt64(a[i], b[i], &T_hi, &T_lo);
185 result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
186 inv_mod);
187 }
188 a += n_mod_8;
189 b += n_mod_8;
190 result += n_mod_8;
191 n_tmp -= n_mod_8;
192 }
193
194 const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
195 const __m512i* v_b = reinterpret_cast<const __m512i*>(b);
196 __m512i* v_result = reinterpret_cast<__m512i*>(result);
197 __m512i v_modulus = _mm512_set1_epi64(modulus);
198 __m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
199 __m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
200
201 for (size_t i = 0; i < n_tmp; i += 8) {
202 __m512i v_a_op = _mm512_loadu_si512(v_a);
203 __m512i v_b_op = _mm512_loadu_si512(v_b);
204 __m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b_op);
205 __m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b_op);
206
207 if (BitShift == 64) {
208 v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
209 __m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
210 v_T_hi = _mm512_add_epi64(v_T_hi, tmp);
211 v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
212 }
213
214 __m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
215 v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
216 HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
217 "v_op exceeds bound " << modulus);
218 _mm512_storeu_si512(v_result, v_c);
219 ++v_a;
220 ++v_b;
221 ++v_result;
222 }
223 }
224
225 /// @brief Returns Montgomery form of a mod q, computed via the REDC algorithm,
226 /// also known as Montgomery reduction.
227 /// @tparam BitShift denotes the operational length, in bits, of the operands
228 /// and result values.
229 /// @tparam r defines the value of R, being R = 2^r. R > modulus.
230 /// @param[in] a input vector. T = a(R^2 mod q) in the range [0, Rq − 1].
231 /// @param[in] R2_mod_q R^2 mod q.
232 /// @param[in] modulus such that gcd(R, modulus) = 1.
233 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R,
234 /// @param[in] n number of elements in input vector.
235 /// @param[out] result unsigned long int vector in the range [0, q − 1] such
236 /// that S ≡ TR^−1 mod q
237 template <int BitShift, int r>
EltwiseMontgomeryFormAVX512(uint64_t * result,const uint64_t * a,uint64_t R2_mod_q,uint64_t n,uint64_t modulus,uint64_t inv_mod)238 void EltwiseMontgomeryFormAVX512(uint64_t* result, const uint64_t* a,
239 uint64_t R2_mod_q, uint64_t n,
240 uint64_t modulus, uint64_t inv_mod) {
241 HEXL_CHECK(a != nullptr, "Require operand a != nullptr");
242 HEXL_CHECK(n != 0, "Require n != 0");
243 HEXL_CHECK(modulus > 1, "Require modulus > 1");
244
245 uint64_t R = (1ULL << r);
246 HEXL_CHECK(std::__gcd(static_cast<int64_t>(modulus), static_cast<int64_t>(R)),
247 1);
248 HEXL_CHECK(R > modulus, "Needs R bigger than q.");
249
250 // mod_R_mask[63:r] all zeros & mod_R_mask[r-1:0] all ones
251 uint64_t mod_R_mask = R - 1;
252 uint64_t prod_rs;
253 if (BitShift == 64) {
254 prod_rs = (1ULL << 63) - 1;
255 } else {
256 prod_rs = (1ULL << (52 - r));
257 }
258 uint64_t n_tmp = n;
259
260 // Deals with n not divisible by 8
261 uint64_t n_mod_8 = n_tmp % 8;
262 if (n_mod_8 != 0) {
263 for (size_t i = 0; i < n_mod_8; ++i) {
264 uint64_t T_hi;
265 uint64_t T_lo;
266 MultiplyUInt64(a[i], R2_mod_q, &T_hi, &T_lo);
267 result[i] = MontgomeryReduce<BitShift>(T_hi, T_lo, modulus, r, mod_R_mask,
268 inv_mod);
269 }
270 a += n_mod_8;
271 result += n_mod_8;
272 n_tmp -= n_mod_8;
273 }
274
275 const __m512i* v_a = reinterpret_cast<const __m512i*>(a);
276 __m512i* v_result = reinterpret_cast<__m512i*>(result);
277 __m512i v_b = _mm512_set1_epi64(R2_mod_q);
278 __m512i v_modulus = _mm512_set1_epi64(modulus);
279 __m512i v_inv_mod = _mm512_set1_epi64(inv_mod);
280 __m512i v_prod_rs = _mm512_set1_epi64(prod_rs);
281
282 for (size_t i = 0; i < n_tmp; i += 8) {
283 __m512i v_a_op = _mm512_loadu_si512(v_a);
284 __m512i v_T_hi = _mm512_hexl_mulhi_epi<BitShift>(v_a_op, v_b);
285 __m512i v_T_lo = _mm512_hexl_mullo_epi<BitShift>(v_a_op, v_b);
286
287 if (BitShift == 64) {
288 v_T_hi = _mm512_slli_epi64(v_T_hi, 1);
289 __m512i tmp = _mm512_srli_epi64(v_T_lo, 63);
290 v_T_hi = _mm512_add_epi64(v_T_hi, tmp);
291 v_T_lo = _mm512_and_epi64(v_T_lo, v_prod_rs);
292 }
293
294 __m512i v_c = _mm512_hexl_montgomery_reduce<BitShift, r>(
295 v_T_hi, v_T_lo, v_modulus, v_inv_mod, v_prod_rs);
296 HEXL_CHECK_BOUNDS(ExtractValues(v_c).data(), 8, modulus,
297 "v_op exceeds bound " << modulus);
298 _mm512_storeu_si512(v_result, v_c);
299 ++v_a;
300 ++v_result;
301 }
302 }
303
304 #endif
305
306 } // namespace hexl
307 } // namespace intel
308