1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #include "ntt/inv-ntt-avx512.hpp"
5
6 #include <immintrin.h>
7
8 #include <cstring>
9 #include <functional>
10 #include <vector>
11
12 #include "hexl/logging/logging.hpp"
13 #include "hexl/ntt/ntt.hpp"
14 #include "hexl/number-theory/number-theory.hpp"
15 #include "ntt/ntt-avx512-util.hpp"
16 #include "ntt/ntt-internal.hpp"
17 #include "util/avx512-util.hpp"
18
19 namespace intel {
20 namespace hexl {
21
#ifdef HEXL_HAS_AVX512IFMA
// Explicit instantiation for the AVX512-IFMA code path
// (BitShift == NTT::s_ifma_shift_bits, i.e. 52-bit products).
template void InverseTransformFromBitReverseAVX512<NTT::s_ifma_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif

#ifdef HEXL_HAS_AVX512DQ
// Explicit instantiations for the AVX512-DQ code paths: the 32-bit-shift
// variant and the default-shift (64-bit) variant.
template void InverseTransformFromBitReverseAVX512<32>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);

template void InverseTransformFromBitReverseAVX512<NTT::s_default_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif
46
47 #ifdef HEXL_HAS_AVX512DQ
48
49 /// @brief The Harvey butterfly: assume X, Y in [0, 2q), and return X', Y' in
50 /// [0, 2q). such that X', Y' = X + Y (mod q), W(X - Y) (mod q).
51 /// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form
52 /// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form
53 /// @param[in] W Root of unity representing 8 64-bit signed integers in SIMD
54 /// form
55 /// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett
56 /// reduction
57 /// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit
58 /// signed integers in SIMD form
59 /// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit
60 /// signed integers in SIMD form
61 /// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise,
62 /// assumes \p X, \p Y < 2*\p q
63 /// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf
template <int BitShift, bool InputLessThanMod>
inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon,
                         __m512i neg_modulus, __m512i twice_modulus) {
  // Y - 2q; reused both for the X update and for forming T below.
  __m512i Y_minus_2q = _mm512_sub_epi64(*Y, twice_modulus);
  // T = X - Y + 2q, non-negative given the precondition X, Y in [0, 2q).
  __m512i T = _mm512_sub_epi64(*X, Y_minus_2q);

  if (InputLessThanMod) {
    // No need for modulus reduction, since inputs are in [0, q)
    *X = _mm512_add_epi64(*X, *Y);
  } else {
    // X' = X + Y - 2q; where the result went negative (sign bit set),
    // add 2q back so X' lands in [0, 2q).
    *X = _mm512_add_epi64(*X, Y_minus_2q);
    __mmask8 sign_bits = _mm512_movepi64_mask(*X);
    *X = _mm512_mask_add_epi64(*X, sign_bits, *X, twice_modulus);
  }

  if (BitShift == 32) {
    // 32-bit Barrett: Q ~= (W_precon * T) >> 32, then
    // Y' = W*T + Q*(-q) using low-64-bit products only.
    __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, T);
    Q = _mm512_srli_epi64(Q, 32);
    __m512i Q_p = _mm512_hexl_mullo_epi<64>(Q, neg_modulus);
    *Y = _mm512_hexl_mullo_add_lo_epi<64>(Q_p, W, T);
  } else if (BitShift == 52) {
    // IFMA path: 52-bit high product yields the Barrett quotient directly.
    __m512i Q = _mm512_hexl_mulhi_epi<BitShift>(W_precon, T);
    __m512i Q_p = _mm512_hexl_mullo_epi<BitShift>(Q, neg_modulus);
    *Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(Q_p, W, T);
  } else if (BitShift == 64) {
    // Perform approximate computation of Q, as described in page 7 of
    // https://arxiv.org/pdf/2003.04510.pdf
    __m512i Q = _mm512_hexl_mulhi_approx_epi<BitShift>(W_precon, T);
    __m512i Q_p = _mm512_hexl_mullo_epi<BitShift>(Q, neg_modulus);
    // Compute Y in range [0, 4q)
    *Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(Q_p, W, T);
    // Reduce Y to range [0, 2q)
    *Y = _mm512_hexl_small_mod_epu64<2>(*Y, twice_modulus);
  } else {
    HEXL_CHECK(false, "Invalid BitShift " << BitShift);
  }
}
101
102 template <int BitShift, bool InputLessThanMod>
InvT1(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)103 void InvT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
104 uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
105 const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
106 const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
107 size_t j1 = 0;
108
109 // 8 | m guaranteed by n >= 16
110 HEXL_LOOP_UNROLL_8
111 for (size_t i = m / 8; i > 0; --i) {
112 uint64_t* X = operand + j1;
113 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
114
115 __m512i v_X;
116 __m512i v_Y;
117 LoadInvInterleavedT1(X, &v_X, &v_Y);
118
119 __m512i v_W = _mm512_loadu_si512(v_W_pt++);
120 __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);
121
122 InvButterfly<BitShift, InputLessThanMod>(&v_X, &v_Y, v_W, v_W_precon,
123 v_neg_modulus, v_twice_mod);
124
125 _mm512_storeu_si512(v_X_pt++, v_X);
126 _mm512_storeu_si512(v_X_pt, v_Y);
127
128 j1 += 16;
129 }
130 }
131
132 template <int BitShift>
InvT2(uint64_t * X,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)133 void InvT2(uint64_t* X, __m512i v_neg_modulus, __m512i v_twice_mod, uint64_t m,
134 const uint64_t* W, const uint64_t* W_precon) {
135 // 4 | m guaranteed by n >= 16
136 HEXL_LOOP_UNROLL_4
137 for (size_t i = m / 4; i > 0; --i) {
138 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
139
140 __m512i v_X;
141 __m512i v_Y;
142 LoadInvInterleavedT2(X, &v_X, &v_Y);
143
144 __m512i v_W = LoadWOpT2(static_cast<const void*>(W));
145 __m512i v_W_precon = LoadWOpT2(static_cast<const void*>(W_precon));
146
147 InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
148 v_twice_mod);
149
150 _mm512_storeu_si512(v_X_pt++, v_X);
151 _mm512_storeu_si512(v_X_pt, v_Y);
152 X += 16;
153
154 W += 4;
155 W_precon += 4;
156 }
157 }
158
159 template <int BitShift>
InvT4(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)160 void InvT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
161 uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
162 uint64_t* X = operand;
163
164 // 2 | m guaranteed by n >= 16
165 HEXL_LOOP_UNROLL_4
166 for (size_t i = m / 2; i > 0; --i) {
167 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
168
169 __m512i v_X;
170 __m512i v_Y;
171 LoadInvInterleavedT4(X, &v_X, &v_Y);
172
173 __m512i v_W = LoadWOpT4(static_cast<const void*>(W));
174 __m512i v_W_precon = LoadWOpT4(static_cast<const void*>(W_precon));
175
176 InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
177 v_twice_mod);
178
179 WriteInvInterleavedT4(v_X, v_Y, v_X_pt);
180 X += 16;
181
182 W += 2;
183 W_precon += 2;
184 }
185 }
186
187 template <int BitShift>
InvT8(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t t,uint64_t m,const uint64_t * W,const uint64_t * W_precon)188 void InvT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
189 uint64_t t, uint64_t m, const uint64_t* W,
190 const uint64_t* W_precon) {
191 size_t j1 = 0;
192
193 HEXL_LOOP_UNROLL_4
194 for (size_t i = 0; i < m; i++) {
195 uint64_t* X = operand + j1;
196 uint64_t* Y = X + t;
197
198 __m512i v_W = _mm512_set1_epi64(static_cast<int64_t>(*W++));
199 __m512i v_W_precon = _mm512_set1_epi64(static_cast<int64_t>(*W_precon++));
200
201 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
202 __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y);
203
204 // assume 8 | t
205 for (size_t j = t / 8; j > 0; --j) {
206 __m512i v_X = _mm512_loadu_si512(v_X_pt);
207 __m512i v_Y = _mm512_loadu_si512(v_Y_pt);
208
209 InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
210 v_twice_mod);
211
212 _mm512_storeu_si512(v_X_pt++, v_X);
213 _mm512_storeu_si512(v_Y_pt++, v_Y);
214 }
215 j1 += (t << 1);
216 }
217 }
218
/// @brief AVX512 inverse NTT. Input is read in bit-reversed order; the
/// final fused loop below scales by n^{-1} (mod q), so output is in
/// standard order. For n > base_ntt_size the transform recurses on the two
/// halves before finishing the remaining stages at this level.
template <int BitShift>
void InverseTransformFromBitReverseAVX512(
    uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
    const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half) {
  HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
  HEXL_CHECK(n >= 16,
             "InverseTransformFromBitReverseAVX512 doesn't support small "
             "transforms. Need n >= 16, got n = "
                 << n);
  HEXL_CHECK(modulus < NTT::s_max_inv_modulus(BitShift),
             "modulus " << modulus << " too large for BitShift " << BitShift
                        << " => maximum value "
                        << NTT::s_max_inv_modulus(BitShift));
  HEXL_CHECK_BOUNDS(precon_inv_root_of_unity_powers, n, MaximumValue(BitShift),
                    "precon_inv_root_of_unity_powers too large");
  HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
  // Skip input bound checking for recursive steps
  HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0,
                    input_mod_factor * modulus,
                    "operand larger than input_mod_factor * modulus ("
                        << input_mod_factor << " * " << modulus << ")");
  HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2,
             "input_mod_factor must be 1 or 2; got " << input_mod_factor);
  HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
             "output_mod_factor must be 1 or 2; got " << output_mod_factor);

  uint64_t twice_mod = modulus << 1;
  __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
  __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
  __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));

  size_t t = 1;          // butterfly stride for the current stage
  size_t m = (n >> 1);   // number of butterflies in the current stage
  // Starting root index for this recursion branch. NOTE(review): assumes
  // the inverse root table is laid out so that branch `recursion_half` at
  // this depth starts at offset m * recursion_half past index 0 -- confirm
  // against the root-table generation in ntt-internal.
  size_t W_idx = 1 + m * recursion_half;

  // Threshold below which the breadth-first schedule is used; larger sizes
  // are split recursively (presumably for cache locality -- TODO confirm).
  static const size_t base_ntt_size = 1024;

  if (n <= base_ntt_size) {  // Perform breadth-first InvNTT
    // Operate in place on result from here on.
    if (operand != result) {
      std::memcpy(result, operand, n * sizeof(uint64_t));
    }

    // Extract t=1, t=2, t=4 loops separately
    {
      // t = 1
      const uint64_t* W = &inv_root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
      // Only a top-level call with inputs already < q may skip the extra
      // reduction inside the butterfly (InputLessThanMod == true).
      if ((input_mod_factor == 1) && (recursion_depth == 0)) {
        InvT1<BitShift, true>(result, v_neg_modulus, v_twice_mod, m, W,
                              W_precon);
      } else {
        InvT1<BitShift, false>(result, v_neg_modulus, v_twice_mod, m, W,
                               W_precon);
      }

      t <<= 1;
      m >>= 1;
      // Root-index step between stages; scaled by recursion depth/half
      // because recursive branches consume interleaved slices of the table.
      uint64_t W_idx_delta =
          m * ((1ULL << (recursion_depth + 1)) - recursion_half);
      W_idx += W_idx_delta;

      // t = 2
      W = &inv_root_of_unity_powers[W_idx];
      W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT2<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);

      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;

      // t = 4
      W = &inv_root_of_unity_powers[W_idx];
      W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT4<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;

      // t >= 8; stops at m == 1, which is fused with the final scaling
      // loop below (top-level call only).
      for (; m > 1;) {
        W = &inv_root_of_unity_powers[W_idx];
        W_precon = &precon_inv_root_of_unity_powers[W_idx];
        InvT8<BitShift>(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon);
        t <<= 1;
        m >>= 1;
        W_idx_delta >>= 1;
        W_idx += W_idx_delta;
      }
    }
  } else {
    // Depth-first: transform each half independently, with each half
    // addressing its own slice of the root table via 2*recursion_half and
    // 2*recursion_half + 1.
    InverseTransformFromBitReverseAVX512<BitShift>(
        result, operand, n / 2, modulus, inv_root_of_unity_powers,
        precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor,
        recursion_depth + 1, 2 * recursion_half);
    InverseTransformFromBitReverseAVX512<BitShift>(
        &result[n / 2], &operand[n / 2], n / 2, modulus,
        inv_root_of_unity_powers, precon_inv_root_of_unity_powers,
        input_mod_factor, output_mod_factor, recursion_depth + 1,
        2 * recursion_half + 1);

    // Fast-forward (t, W_idx) past the stages the recursive calls already
    // performed, then run the single remaining m == 2 stage here.
    uint64_t W_idx_delta =
        m * ((1ULL << (recursion_depth + 1)) - recursion_half);
    for (; m > 2; m >>= 1) {
      t <<= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;
    }
    if (m == 2) {
      const uint64_t* W = &inv_root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT8<BitShift>(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon);
      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;
    }
  }

  // Final loop through data
  // Only the top-level invocation runs this: the last (m == 1) butterfly
  // stage fused with multiplication by n^{-1} (mod q) and the optional
  // reduction to [0, q).
  if (recursion_depth == 0) {
    HEXL_VLOG(4, "AVX512 intermediate result "
                     << std::vector<uint64_t>(result, result + n));

    const uint64_t W = inv_root_of_unity_powers[W_idx];
    MultiplyFactor mf_inv_n(InverseMod(n, modulus), BitShift, modulus);
    const uint64_t inv_n = mf_inv_n.Operand();
    const uint64_t inv_n_prime = mf_inv_n.BarrettFactor();

    // inv_n_w = n^{-1} * W (mod q): the Y outputs absorb both the final
    // root and the scaling in a single multiplication.
    MultiplyFactor mf_inv_n_w(MultiplyMod(inv_n, W, modulus), BitShift,
                              modulus);
    const uint64_t inv_n_w = mf_inv_n_w.Operand();
    const uint64_t inv_n_w_prime = mf_inv_n_w.BarrettFactor();

    HEXL_VLOG(4, "inv_n_w " << inv_n_w);

    uint64_t* X = result;
    uint64_t* Y = X + (n >> 1);

    __m512i v_inv_n = _mm512_set1_epi64(static_cast<int64_t>(inv_n));
    __m512i v_inv_n_prime =
        _mm512_set1_epi64(static_cast<int64_t>(inv_n_prime));
    __m512i v_inv_n_w = _mm512_set1_epi64(static_cast<int64_t>(inv_n_w));
    __m512i v_inv_n_w_prime =
        _mm512_set1_epi64(static_cast<int64_t>(inv_n_w_prime));

    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
    __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y);

    // Merge final InvNTT loop with modulus reduction baked-in
    HEXL_LOOP_UNROLL_4
    for (size_t j = n / 16; j > 0; --j) {
      __m512i v_X = _mm512_loadu_si512(v_X_pt);
      __m512i v_Y = _mm512_loadu_si512(v_Y_pt);

      // Slightly different from regular InvButterfly because different W is
      // used for X and Y
      __m512i Y_minus_2q = _mm512_sub_epi64(v_Y, v_twice_mod);
      __m512i X_plus_Y_mod2q =
          _mm512_hexl_small_add_mod_epi64(v_X, v_Y, v_twice_mod);
      // T = *X + twice_mod - *Y
      __m512i T = _mm512_sub_epi64(v_X, Y_minus_2q);

      if (BitShift == 32) {
        __m512i Q1 = _mm512_hexl_mullo_epi<64>(v_inv_n_prime, X_plus_Y_mod2q);
        Q1 = _mm512_srli_epi64(Q1, 32);
        // X = inv_N * X_plus_Y_mod2q - Q1 * modulus;
        __m512i inv_N_tx = _mm512_hexl_mullo_epi<64>(v_inv_n, X_plus_Y_mod2q);
        v_X = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_tx, Q1, v_neg_modulus);

        __m512i Q2 = _mm512_hexl_mullo_epi<64>(v_inv_n_w_prime, T);
        Q2 = _mm512_srli_epi64(Q2, 32);

        // Y = inv_N_W * T - Q2 * modulus;
        __m512i inv_N_W_T = _mm512_hexl_mullo_epi<64>(v_inv_n_w, T);
        v_Y = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_W_T, Q2, v_neg_modulus);
      } else {
        __m512i Q1 =
            _mm512_hexl_mulhi_epi<BitShift>(v_inv_n_prime, X_plus_Y_mod2q);
        // X = inv_N * X_plus_Y_mod2q - Q1 * modulus;
        __m512i inv_N_tx =
            _mm512_hexl_mullo_epi<BitShift>(v_inv_n, X_plus_Y_mod2q);
        v_X =
            _mm512_hexl_mullo_add_lo_epi<BitShift>(inv_N_tx, Q1, v_neg_modulus);

        __m512i Q2 = _mm512_hexl_mulhi_epi<BitShift>(v_inv_n_w_prime, T);
        // Y = inv_N_W * T - Q2 * modulus;
        __m512i inv_N_W_T = _mm512_hexl_mullo_epi<BitShift>(v_inv_n_w, T);
        v_Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(inv_N_W_T, Q2,
                                                     v_neg_modulus);
      }

      if (output_mod_factor == 1) {
        // Modulus reduction from [0, 2q), to [0, q)
        v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus);
        v_Y = _mm512_hexl_small_mod_epu64(v_Y, v_modulus);
      }

      _mm512_storeu_si512(v_X_pt++, v_X);
      _mm512_storeu_si512(v_Y_pt++, v_Y);
    }

    HEXL_VLOG(5, "AVX512 returning result "
                     << std::vector<uint64_t>(result, result + n));
  }
}
429
430 #endif // HEXL_HAS_AVX512DQ
431
432 } // namespace hexl
433 } // namespace intel
434