// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "ntt/inv-ntt-avx512.hpp"

#include <immintrin.h>

#include <cstring>
#include <functional>
#include <vector>

#include "hexl/logging/logging.hpp"
#include "hexl/ntt/ntt.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "ntt/ntt-avx512-util.hpp"
#include "ntt/ntt-internal.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512IFMA
template void InverseTransformFromBitReverseAVX512<NTT::s_ifma_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif

#ifdef HEXL_HAS_AVX512DQ
template void InverseTransformFromBitReverseAVX512<32>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);

template void InverseTransformFromBitReverseAVX512<NTT::s_default_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree,
    uint64_t modulus, const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif
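
// The explicit instantiations above cover the Barrett shift widths handled
// by InvButterfly below: 32, NTT::s_ifma_shift_bits (the AVX512-IFMA path),
// and NTT::s_default_shift_bits. Callers normally reach them through the
// NTT class rather than directly; a direct call would look roughly like the
// following sketch (caller-supplied tables, illustrative values only):
//
//   InverseTransformFromBitReverseAVX512<NTT::s_default_shift_bits>(
//       result, operand, /*degree=*/n, modulus, inv_roots, precon_inv_roots,
//       /*input_mod_factor=*/2, /*output_mod_factor=*/1,
//       /*recursion_depth=*/0, /*recursion_half=*/0);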

#ifdef HEXL_HAS_AVX512DQ

/// @brief The Harvey butterfly: assume X, Y in [0, 2q), and return X', Y' in
/// [0, 2q), such that X' = X + Y (mod q), Y' = W(X - Y) (mod q).
/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form
/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form
/// @param[in] W Root of unity representing 8 64-bit signed integers in SIMD
/// form
/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett
/// reduction
/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit
/// signed integers in SIMD form
/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8
/// 64-bit signed integers in SIMD form
/// @tparam InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise,
/// assumes \p X, \p Y < 2*\p q
/// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf
template <int BitShift, bool InputLessThanMod>
inline void InvButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon,
                         __m512i neg_modulus, __m512i twice_modulus) {
  __m512i Y_minus_2q = _mm512_sub_epi64(*Y, twice_modulus);
  __m512i T = _mm512_sub_epi64(*X, Y_minus_2q);

  if (InputLessThanMod) {
    // No need for modulus reduction, since inputs are in [0, q)
    *X = _mm512_add_epi64(*X, *Y);
  } else {
    *X = _mm512_add_epi64(*X, Y_minus_2q);
    __mmask8 sign_bits = _mm512_movepi64_mask(*X);
    *X = _mm512_mask_add_epi64(*X, sign_bits, *X, twice_modulus);
  }

  if (BitShift == 32) {
    __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, T);
    Q = _mm512_srli_epi64(Q, 32);
    __m512i Q_p = _mm512_hexl_mullo_epi<64>(Q, neg_modulus);
    *Y = _mm512_hexl_mullo_add_lo_epi<64>(Q_p, W, T);
  } else if (BitShift == 52) {
    __m512i Q = _mm512_hexl_mulhi_epi<BitShift>(W_precon, T);
    __m512i Q_p = _mm512_hexl_mullo_epi<BitShift>(Q, neg_modulus);
    *Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(Q_p, W, T);
  } else if (BitShift == 64) {
    // Perform approximate computation of Q, as described in page 7 of
    // https://arxiv.org/pdf/2003.04510.pdf
    __m512i Q = _mm512_hexl_mulhi_approx_epi<BitShift>(W_precon, T);
    __m512i Q_p = _mm512_hexl_mullo_epi<BitShift>(Q, neg_modulus);
    // Compute Y in range [0, 4q)
    *Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(Q_p, W, T);
    // Reduce Y to range [0, 2q)
    *Y = _mm512_hexl_small_mod_epu64<2>(*Y, twice_modulus);
  } else {
    HEXL_CHECK(false, "Invalid BitShift " << BitShift);
  }
}
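
// For reference, a minimal scalar sketch of the same Harvey inverse
// butterfly. The helper below is illustrative only (not part of the HEXL
// API); it assumes X, Y in [0, 2q), q < 2^62, a modulus that fits the chosen
// BitShift as in the checks further down, and
// W_precon = floor(W * 2^BitShift / q). It mirrors the exact-quotient
// branches above, without the BitShift == 64 approximation, and reuses
// MultiplyUInt64Hi from number-theory.hpp for the high product word.
template <int BitShift>
inline void InvButterflyScalarSketch(uint64_t* X, uint64_t* Y, uint64_t W,
                                     uint64_t W_precon, uint64_t modulus,
                                     uint64_t twice_modulus) {
  uint64_t T = *X + twice_modulus - *Y;  // T = X - Y + 2q, in [0, 4q)
  *X += *Y;                              // X' = X + Y
  if (*X >= twice_modulus) {
    *X -= twice_modulus;  // X' in [0, 2q)
  }
  // Barrett quotient Q ~= W * T / q; then Y' = W * T - Q * q, computed
  // mod 2^64, where the intentional wraparound still lands in [0, 2q)
  uint64_t Q = MultiplyUInt64Hi<BitShift>(W_precon, T);
  *Y = W * T - Q * modulus;
}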

template <int BitShift, bool InputLessThanMod>
void InvT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
  const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
  const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
  size_t j1 = 0;

  // 8 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_8
  for (size_t i = m / 8; i > 0; --i) {
    uint64_t* X = operand + j1;
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadInvInterleavedT1(X, &v_X, &v_Y);

    __m512i v_W = _mm512_loadu_si512(v_W_pt++);
    __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);

    InvButterfly<BitShift, InputLessThanMod>(&v_X, &v_Y, v_W, v_W_precon,
                                             v_neg_modulus, v_twice_mod);

    _mm512_storeu_si512(v_X_pt++, v_X);
    _mm512_storeu_si512(v_X_pt, v_Y);

    j1 += 16;
  }
}

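// InvT2 and InvT4 below handle the t = 2 and t = 4 stages analogously to
// InvT1; only the interleaved load/store pattern and the factor by which
// each twiddle is replicated across SIMD lanes differ.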
template <int BitShift>
void InvT2(uint64_t* X, __m512i v_neg_modulus, __m512i v_twice_mod, uint64_t m,
           const uint64_t* W, const uint64_t* W_precon) {
  // 4 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_4
  for (size_t i = m / 4; i > 0; --i) {
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadInvInterleavedT2(X, &v_X, &v_Y);

    __m512i v_W = LoadWOpT2(static_cast<const void*>(W));
    __m512i v_W_precon = LoadWOpT2(static_cast<const void*>(W_precon));

    InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                  v_twice_mod);

    _mm512_storeu_si512(v_X_pt++, v_X);
    _mm512_storeu_si512(v_X_pt, v_Y);
    X += 16;

    W += 4;
    W_precon += 4;
  }
}

template <int BitShift>
void InvT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
  uint64_t* X = operand;

  // 2 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_4
  for (size_t i = m / 2; i > 0; --i) {
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadInvInterleavedT4(X, &v_X, &v_Y);

    __m512i v_W = LoadWOpT4(static_cast<const void*>(W));
    __m512i v_W_precon = LoadWOpT4(static_cast<const void*>(W_precon));

    InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                  v_twice_mod);

    WriteInvInterleavedT4(v_X, v_Y, v_X_pt);
    X += 16;

    W += 2;
    W_precon += 2;
  }
}

template <int BitShift>
void InvT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t t, uint64_t m, const uint64_t* W,
           const uint64_t* W_precon) {
  size_t j1 = 0;

  HEXL_LOOP_UNROLL_4
  for (size_t i = 0; i < m; i++) {
    uint64_t* X = operand + j1;
    uint64_t* Y = X + t;

    __m512i v_W = _mm512_set1_epi64(static_cast<int64_t>(*W++));
    __m512i v_W_precon = _mm512_set1_epi64(static_cast<int64_t>(*W_precon++));

    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
    __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y);

    // assume 8 | t
    for (size_t j = t / 8; j > 0; --j) {
      __m512i v_X = _mm512_loadu_si512(v_X_pt);
      __m512i v_Y = _mm512_loadu_si512(v_Y_pt);

      InvButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                    v_twice_mod);

      _mm512_storeu_si512(v_X_pt++, v_X);
      _mm512_storeu_si512(v_Y_pt++, v_Y);
    }
    j1 += (t << 1);
  }
}
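
// InvT8 is the vectorized form of the textbook inverse-NTT stage; its scalar
// shape, for orientation only (sketch, not the HEXL API), is:
//
//   for (size_t i = 0; i < m; i++) {
//     for (size_t j = j1; j < j1 + t; j++) {
//       InvButterfly(&X[j], &X[j + t], W[i], ...);
//     }
//     j1 += 2 * t;
//   }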

template <int BitShift>
void InverseTransformFromBitReverseAVX512(
    uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
    const uint64_t* inv_root_of_unity_powers,
    const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half) {
  HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
  HEXL_CHECK(n >= 16,
             "InverseTransformFromBitReverseAVX512 doesn't support small "
             "transforms. Need n >= 16, got n = "
                 << n);
  HEXL_CHECK(modulus < NTT::s_max_inv_modulus(BitShift),
             "modulus " << modulus << " too large for BitShift " << BitShift
                        << " => maximum value "
                        << NTT::s_max_inv_modulus(BitShift));
  HEXL_CHECK_BOUNDS(precon_inv_root_of_unity_powers, n, MaximumValue(BitShift),
                    "precon_inv_root_of_unity_powers too large");
  HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
  // Skip input bound checking for recursive steps
  HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0,
                    input_mod_factor * modulus,
                    "operand larger than input_mod_factor * modulus ("
                        << input_mod_factor << " * " << modulus << ")");
  HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2,
             "input_mod_factor must be 1 or 2; got " << input_mod_factor);
  HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
             "output_mod_factor must be 1 or 2; got " << output_mod_factor);

  uint64_t twice_mod = modulus << 1;
  __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
  __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
  __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));

  size_t t = 1;
  size_t m = (n >> 1);
  size_t W_idx = 1 + m * recursion_half;

  static const size_t base_ntt_size = 1024;
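  // Transforms up to base_ntt_size are computed breadth-first, keeping the
  // working set cache-resident; larger transforms recurse depth-first into
  // two half-sized sub-transforms (the threshold is a tuning choice).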

  if (n <= base_ntt_size) {  // Perform breadth-first InvNTT
    if (operand != result) {
      std::memcpy(result, operand, n * sizeof(uint64_t));
    }

    // Extract t=1, t=2, t=4 loops separately
    {
      // t = 1
      const uint64_t* W = &inv_root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
      if ((input_mod_factor == 1) && (recursion_depth == 0)) {
        InvT1<BitShift, true>(result, v_neg_modulus, v_twice_mod, m, W,
                              W_precon);
      } else {
        InvT1<BitShift, false>(result, v_neg_modulus, v_twice_mod, m, W,
                               W_precon);
      }

      t <<= 1;
      m >>= 1;
      uint64_t W_idx_delta =
          m * ((1ULL << (recursion_depth + 1)) - recursion_half);
      W_idx += W_idx_delta;

      // t = 2
      W = &inv_root_of_unity_powers[W_idx];
      W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT2<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);

      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;

      // t = 4
      W = &inv_root_of_unity_powers[W_idx];
      W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT4<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;

      // t >= 8
      for (; m > 1;) {
        W = &inv_root_of_unity_powers[W_idx];
        W_precon = &precon_inv_root_of_unity_powers[W_idx];
        InvT8<BitShift>(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon);
        t <<= 1;
        m >>= 1;
        W_idx_delta >>= 1;
        W_idx += W_idx_delta;
      }
    }
  } else {
    InverseTransformFromBitReverseAVX512<BitShift>(
        result, operand, n / 2, modulus, inv_root_of_unity_powers,
        precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor,
        recursion_depth + 1, 2 * recursion_half);
    InverseTransformFromBitReverseAVX512<BitShift>(
        &result[n / 2], &operand[n / 2], n / 2, modulus,
        inv_root_of_unity_powers, precon_inv_root_of_unity_powers,
        input_mod_factor, output_mod_factor, recursion_depth + 1,
        2 * recursion_half + 1);

    uint64_t W_idx_delta =
        m * ((1ULL << (recursion_depth + 1)) - recursion_half);
    for (; m > 2; m >>= 1) {
      t <<= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;
    }
    if (m == 2) {
      const uint64_t* W = &inv_root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
      InvT8<BitShift>(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon);
      t <<= 1;
      m >>= 1;
      W_idx_delta >>= 1;
      W_idx += W_idx_delta;
    }
  }

  // Final loop through data
  if (recursion_depth == 0) {
    HEXL_VLOG(4, "AVX512 intermediate result "
                     << std::vector<uint64_t>(result, result + n));

    const uint64_t W = inv_root_of_unity_powers[W_idx];
    MultiplyFactor mf_inv_n(InverseMod(n, modulus), BitShift, modulus);
    const uint64_t inv_n = mf_inv_n.Operand();
    const uint64_t inv_n_prime = mf_inv_n.BarrettFactor();

    MultiplyFactor mf_inv_n_w(MultiplyMod(inv_n, W, modulus), BitShift,
                              modulus);
    const uint64_t inv_n_w = mf_inv_n_w.Operand();
    const uint64_t inv_n_w_prime = mf_inv_n_w.BarrettFactor();

    HEXL_VLOG(4, "inv_n_w " << inv_n_w);

    uint64_t* X = result;
    uint64_t* Y = X + (n >> 1);

    __m512i v_inv_n = _mm512_set1_epi64(static_cast<int64_t>(inv_n));
    __m512i v_inv_n_prime =
        _mm512_set1_epi64(static_cast<int64_t>(inv_n_prime));
    __m512i v_inv_n_w = _mm512_set1_epi64(static_cast<int64_t>(inv_n_w));
    __m512i v_inv_n_w_prime =
        _mm512_set1_epi64(static_cast<int64_t>(inv_n_w_prime));

    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
    __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y);

    // Merge final InvNTT loop with modulus reduction baked-in
    HEXL_LOOP_UNROLL_4
    for (size_t j = n / 16; j > 0; --j) {
      __m512i v_X = _mm512_loadu_si512(v_X_pt);
      __m512i v_Y = _mm512_loadu_si512(v_Y_pt);

      // Slightly different from the regular InvButterfly because a different
      // W is used for X and Y
      __m512i Y_minus_2q = _mm512_sub_epi64(v_Y, v_twice_mod);
      __m512i X_plus_Y_mod2q =
          _mm512_hexl_small_add_mod_epi64(v_X, v_Y, v_twice_mod);
      // T = *X + twice_mod - *Y
      __m512i T = _mm512_sub_epi64(v_X, Y_minus_2q);

      if (BitShift == 32) {
        __m512i Q1 = _mm512_hexl_mullo_epi<64>(v_inv_n_prime, X_plus_Y_mod2q);
        Q1 = _mm512_srli_epi64(Q1, 32);
        // X = inv_N * X_plus_Y_mod2q - Q1 * modulus;
        __m512i inv_N_tx = _mm512_hexl_mullo_epi<64>(v_inv_n, X_plus_Y_mod2q);
        v_X = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_tx, Q1, v_neg_modulus);

        __m512i Q2 = _mm512_hexl_mullo_epi<64>(v_inv_n_w_prime, T);
        Q2 = _mm512_srli_epi64(Q2, 32);

        // Y = inv_N_W * T - Q2 * modulus;
        __m512i inv_N_W_T = _mm512_hexl_mullo_epi<64>(v_inv_n_w, T);
        v_Y = _mm512_hexl_mullo_add_lo_epi<64>(inv_N_W_T, Q2, v_neg_modulus);
      } else {
        __m512i Q1 =
            _mm512_hexl_mulhi_epi<BitShift>(v_inv_n_prime, X_plus_Y_mod2q);
        // X = inv_N * X_plus_Y_mod2q - Q1 * modulus;
        __m512i inv_N_tx =
            _mm512_hexl_mullo_epi<BitShift>(v_inv_n, X_plus_Y_mod2q);
        v_X =
            _mm512_hexl_mullo_add_lo_epi<BitShift>(inv_N_tx, Q1, v_neg_modulus);

        __m512i Q2 = _mm512_hexl_mulhi_epi<BitShift>(v_inv_n_w_prime, T);
        // Y = inv_N_W * T - Q2 * modulus;
        __m512i inv_N_W_T = _mm512_hexl_mullo_epi<BitShift>(v_inv_n_w, T);
        v_Y = _mm512_hexl_mullo_add_lo_epi<BitShift>(inv_N_W_T, Q2,
                                                     v_neg_modulus);
      }

      if (output_mod_factor == 1) {
        // Modulus reduction from [0, 2q) to [0, q)
        v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus);
        v_Y = _mm512_hexl_small_mod_epu64(v_Y, v_modulus);
      }

      _mm512_storeu_si512(v_X_pt++, v_X);
      _mm512_storeu_si512(v_Y_pt++, v_Y);
    }

    HEXL_VLOG(5, "AVX512 returning result "
                     << std::vector<uint64_t>(result, result + n));
  }
}
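
// In scalar terms, the merged final stage above computes, for each pair
// (X, Y) with Y offset by n/2:
//
//   X' = inv_n * (X + Y) mod q
//   Y' = (inv_n * W) * (X - Y) mod q
//
// folding the last butterfly's twiddle W and the 1/n normalization into a
// single Barrett multiplication per output.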

#endif  // HEXL_HAS_AVX512DQ

}  // namespace hexl
}  // namespace intel