// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "ntt/fwd-ntt-avx512.hpp"

#include <cstring>
#include <functional>
#include <vector>

#include "hexl/logging/logging.hpp"
#include "hexl/ntt/ntt.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "ntt/ntt-avx512-util.hpp"
#include "ntt/ntt-internal.hpp"
#include "util/avx512-util.hpp"

namespace intel {
namespace hexl {

#ifdef HEXL_HAS_AVX512IFMA
template void ForwardTransformToBitReverseAVX512<NTT::s_ifma_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif

#ifdef HEXL_HAS_AVX512DQ
template void ForwardTransformToBitReverseAVX512<32>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);

template void ForwardTransformToBitReverseAVX512<NTT::s_default_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif

#ifdef HEXL_HAS_AVX512DQ

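// Scalar sketch of the per-lane computation in FwdButterfly below
// (illustrative; assumes W_precon = floor(W * 2^BitShift / q), the usual
// Barrett precomputation):
//   Q  = (W_precon * Y) >> BitShift   // approximate quotient of W*Y / q
//   T  = W * Y - Q * q                // = W*Y mod q, in [0, 2q)
//   X' = X + T                        // in [0, 4q)
//   Y' = X - T + 2q                   // in [0, 4q)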
/// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y'
/// in [0, 4q) such that X', Y' = X + WY, X - WY (mod q).
/// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form
/// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form
/// @param[in] W Root of unity represented as 8 64-bit signed integers in
/// SIMD form
/// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett
/// reduction
/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit
/// signed integers in SIMD form
/// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit
/// signed integers in SIMD form
/// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise,
/// assumes \p X, \p Y < 4*\p q
/// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf
template <int BitShift, bool InputLessThanMod>
void FwdButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon,
                  __m512i neg_modulus, __m512i twice_modulus) {
  if (!InputLessThanMod) {
    *X = _mm512_hexl_small_mod_epu64(*X, twice_modulus);
  }

  __m512i T;
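  // Compute T = W*Y mod q in [0, 2q) via Barrett multiplication: Q is a
  // (possibly approximate) quotient floor(W_precon * Y / 2^BitShift), and
  // T = W*Y - Q*q is formed from the low product bits.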
  if (BitShift == 32) {
    __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, *Y);
    Q = _mm512_srli_epi64(Q, 32);
    __m512i W_Y = _mm512_hexl_mullo_epi<64>(W, *Y);
    T = _mm512_hexl_mullo_add_lo_epi<64>(W_Y, Q, neg_modulus);
  } else if (BitShift == 52) {
    __m512i Q = _mm512_hexl_mulhi_epi<BitShift>(W_precon, *Y);
    __m512i W_Y = _mm512_hexl_mullo_epi<BitShift>(W, *Y);
    T = _mm512_hexl_mullo_add_lo_epi<BitShift>(W_Y, Q, neg_modulus);
  } else if (BitShift == 64) {
    // Perform approximate computation of Q, as described on page 7 of
    // https://arxiv.org/pdf/2003.04510.pdf
    __m512i Q = _mm512_hexl_mulhi_approx_epi<BitShift>(W_precon, *Y);
    __m512i W_Y = _mm512_hexl_mullo_epi<BitShift>(W, *Y);
    // Compute T in range [0, 4q)
    T = _mm512_hexl_mullo_add_lo_epi<BitShift>(W_Y, Q, neg_modulus);
    // Reduce T to range [0, 2q)
    T = _mm512_hexl_small_mod_epu64<2>(T, twice_modulus);
  } else {
    HEXL_CHECK(false, "Invalid BitShift " << BitShift);
  }

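  // With X in [0, 2q) and T in [0, 2q), the outputs
  //   X' = X + T and Y' = X + (2q - T)
  // both lie in [0, 4q), as stated in the function contract above.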
  __m512i twice_mod_minus_T = _mm512_sub_epi64(twice_modulus, T);
  *Y = _mm512_add_epi64(*X, twice_mod_minus_T);
  *X = _mm512_add_epi64(*X, T);
}

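// Radix-2 stage with t = 1: butterflies pair adjacent elements, so full
// AVX512 vectors of 8 distinct roots are loaded per iteration and the
// operands are de-interleaved/re-interleaved by the Load/Write helpers.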
template <int BitShift>
void FwdT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
  const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
  const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
  size_t j1 = 0;

  // 8 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_8
  for (size_t i = m / 8; i > 0; --i) {
    uint64_t* X = operand + j1;
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadFwdInterleavedT1(X, &v_X, &v_Y);
    __m512i v_W = _mm512_loadu_si512(v_W_pt++);
    __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);

    FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                  v_twice_mod);
    WriteFwdInterleavedT1(v_X, v_Y, v_X_pt);

    j1 += 16;
  }
}

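// Stage with t = 2: each root drives two adjacent butterflies, so each loaded
// root vector holds 4 distinct roots, each duplicated twice (checked via
// HEXL_CHECK below); 16 elements are processed per iteration.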
template <int BitShift>
void FwdT2(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
  const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
  const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);

  size_t j1 = 0;
  // 4 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_4
  for (size_t i = m / 4; i > 0; --i) {
    uint64_t* X = operand + j1;
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadFwdInterleavedT2(X, &v_X, &v_Y);

    __m512i v_W = _mm512_loadu_si512(v_W_pt++);
    __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);

    HEXL_CHECK(ExtractValues(v_W)[0] == ExtractValues(v_W)[1],
               "bad v_W " << ExtractValues(v_W));
    HEXL_CHECK(ExtractValues(v_W_precon)[0] == ExtractValues(v_W_precon)[1],
               "bad v_W_precon " << ExtractValues(v_W_precon));
    FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                  v_twice_mod);

    _mm512_storeu_si512(v_X_pt++, v_X);
    _mm512_storeu_si512(v_X_pt, v_Y);

    j1 += 16;
  }
}

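// Stage with t = 4: each root drives four adjacent butterflies, so each loaded
// root vector holds 2 distinct roots, each duplicated four times; 16 elements
// are processed per iteration.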
template <int BitShift>
void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
           uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
  size_t j1 = 0;
  const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
  const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);

  // 2 | m guaranteed by n >= 16
  HEXL_LOOP_UNROLL_4
  for (size_t i = m / 2; i > 0; --i) {
    uint64_t* X = operand + j1;
    __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);

    __m512i v_X;
    __m512i v_Y;
    LoadFwdInterleavedT4(X, &v_X, &v_Y);

    __m512i v_W = _mm512_loadu_si512(v_W_pt++);
    __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);
    FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
                                  v_twice_mod);

    _mm512_storeu_si512(v_X_pt++, v_X);
    _mm512_storeu_si512(v_X_pt, v_Y);

    j1 += 16;
  }
}

// Out-of-place implementation
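// FwdT8 covers every stage with t >= 8: a single root W is broadcast to all
// lanes and applied to a length-t run of elements and the run t slots away.
// Reading from operand and writing to result also lets it serve as the
// out-of-place first stage of the depth-first path below.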
template <int BitShift, bool InputLessThanMod>
void FwdT8(uint64_t* result, const uint64_t* operand, __m512i v_neg_modulus,
           __m512i v_twice_mod, uint64_t t, uint64_t m, const uint64_t* W,
           const uint64_t* W_precon) {
  size_t j1 = 0;

  HEXL_LOOP_UNROLL_4
  for (size_t i = 0; i < m; i++) {
    // Referencing operand
    const uint64_t* X_op = operand + j1;
    const uint64_t* Y_op = X_op + t;

    const __m512i* v_X_op_pt = reinterpret_cast<const __m512i*>(X_op);
    const __m512i* v_Y_op_pt = reinterpret_cast<const __m512i*>(Y_op);

    // Referencing result
    uint64_t* X_r = result + j1;
    uint64_t* Y_r = X_r + t;

    __m512i* v_X_r_pt = reinterpret_cast<__m512i*>(X_r);
    __m512i* v_Y_r_pt = reinterpret_cast<__m512i*>(Y_r);

    // Weights and preconditioned weights
    __m512i v_W = _mm512_set1_epi64(static_cast<int64_t>(*W++));
    __m512i v_W_precon = _mm512_set1_epi64(static_cast<int64_t>(*W_precon++));

    // assume 8 | t
    for (size_t j = t / 8; j > 0; --j) {
      __m512i v_X = _mm512_loadu_si512(v_X_op_pt);
      __m512i v_Y = _mm512_loadu_si512(v_Y_op_pt);

      FwdButterfly<BitShift, InputLessThanMod>(&v_X, &v_Y, v_W, v_W_precon,
                                               v_neg_modulus, v_twice_mod);

      _mm512_storeu_si512(v_X_r_pt++, v_X);
      _mm512_storeu_si512(v_Y_r_pt++, v_Y);

      // Advance the operand pointers as well
      v_X_op_pt++;
      v_Y_op_pt++;
    }
    j1 += (t << 1);
  }
}


template <int BitShift>
void ForwardTransformToBitReverseAVX512(
    uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half) {
  HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
  HEXL_CHECK(modulus < NTT::s_max_fwd_modulus(BitShift),
             "modulus " << modulus << " too large for BitShift " << BitShift
                        << " => maximum value "
                        << NTT::s_max_fwd_modulus(BitShift));
  HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift),
                    "precon_root_of_unity_powers too large");
  HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
  // Skip input bound checking for recursive steps
  HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0,
                    input_mod_factor * modulus,
                    "operand larger than input_mod_factor * modulus ("
                        << input_mod_factor << " * " << modulus << ")");
  HEXL_CHECK(n >= 16,
             "Don't support small transforms. Need n >= 16, got n = " << n);
  HEXL_CHECK(
      input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4,
      "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor);
  HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4,
             "output_mod_factor must be 1 or 4; got " << output_mod_factor);

  uint64_t twice_mod = modulus << 1;

  __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
  __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
  __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));

  HEXL_VLOG(5, "root_of_unity_powers " << std::vector<uint64_t>(
                   root_of_unity_powers, root_of_unity_powers + n));
  HEXL_VLOG(5,
            "precon_root_of_unity_powers " << std::vector<uint64_t>(
                precon_root_of_unity_powers, precon_root_of_unity_powers + n));
  HEXL_VLOG(5, "operand " << std::vector<uint64_t>(operand, operand + n));

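  // Transforms with n <= base_ntt_size are computed breadth-first; larger
  // transforms recurse depth-first into two half-size sub-transforms so each
  // sub-transform's working set stays small (the threshold 1024 is presumably
  // a cache-oriented tuning choice).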
  static const size_t base_ntt_size = 1024;

  if (n <= base_ntt_size) {  // Perform breadth-first NTT
    size_t t = (n >> 1);
    size_t m = 1;
    size_t W_idx = (m << recursion_depth) + (recursion_half * m);

    // Copy for out-of-place operation, in case n was <= base_ntt_size from
    // the start (i.e. no FwdT8 pass has written to result yet)
    if (result != operand) {
      std::memcpy(result, operand, n * sizeof(uint64_t));
    }

    // First iteration: at the top level with input_mod_factor <= 2, inputs
    // are already < 2*p, so FwdButterfly can skip its pre-reduction
    if (m < (n >> 3)) {
      const uint64_t* W = &root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];

      if ((input_mod_factor <= 2) && (recursion_depth == 0)) {
        FwdT8<BitShift, true>(result, result, v_neg_modulus, v_twice_mod, t, m,
                              W, W_precon);
      } else {
        FwdT8<BitShift, false>(result, result, v_neg_modulus, v_twice_mod, t, m,
                               W, W_precon);
      }

      t >>= 1;
      m <<= 1;
      W_idx <<= 1;
    }
    for (; m < (n >> 3); m <<= 1) {
      const uint64_t* W = &root_of_unity_powers[W_idx];
      const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];
      FwdT8<BitShift, false>(result, result, v_neg_modulus, v_twice_mod, t, m,
                             W, W_precon);
      t >>= 1;
      W_idx <<= 1;
    }

    // Do T=4, T=2, T=1 separately
    {
      // Correction step needed due to extra copies of the roots of unity in
      // the AVX512 vectors loaded for FwdT2 and FwdT4
      auto compute_new_W_idx = [&](size_t idx) {
        // Original mapping from root-of-unity vector index to loop:
        // [0, N/8) => FwdT8
        // [N/8, N/4) => FwdT4
        // [N/4, N/2) => FwdT2
        // [N/2, N) => FwdT1
        // New mapping from AVX512 root-of-unity vector index to loop:
        // [0, N/8) => FwdT8
        // [N/8, 5N/8) => FwdT4
        // [5N/8, 9N/8) => FwdT2
        // [9N/8, 13N/8) => FwdT1
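        // Worked example (illustrative), with N = 64: idx 10 (a FwdT4 root)
        // maps to (10 - 8) * 4 + 8 = 16; idx 20 (FwdT2) maps to
        // (20 - 16) * 2 + 40 = 48; idx 40 (FwdT1) maps to 40 + 40 = 80.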
        size_t N = n << recursion_depth;

        // FwdT8 range
        if (idx <= N / 8) {
          return idx;
        }
        // FwdT4 range
        if (idx <= N / 4) {
          return (idx - N / 8) * 4 + (N / 8);
        }
        // FwdT2 range
        if (idx <= N / 2) {
          return (idx - N / 4) * 2 + (5 * N / 8);
        }
        // FwdT1 range
        return idx + (5 * N / 8);
      };

      size_t new_W_idx = compute_new_W_idx(W_idx);
      const uint64_t* W = &root_of_unity_powers[new_W_idx];
      const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx];
      FwdT4<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);

      m <<= 1;
      W_idx <<= 1;
      new_W_idx = compute_new_W_idx(W_idx);
      W = &root_of_unity_powers[new_W_idx];
      W_precon = &precon_root_of_unity_powers[new_W_idx];
      FwdT2<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);

      m <<= 1;
      W_idx <<= 1;
      new_W_idx = compute_new_W_idx(W_idx);
      W = &root_of_unity_powers[new_W_idx];
      W_precon = &precon_root_of_unity_powers[new_W_idx];
      FwdT1<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
    }

    if (output_mod_factor == 1) {
      // n is a power of two >= 16, so n is divisible by 8
      HEXL_CHECK(n % 8 == 0, "n " << n << " not divisible by 8");
      __m512i* v_X_pt = reinterpret_cast<__m512i*>(result);
      for (size_t i = 0; i < n; i += 8) {
        __m512i v_X = _mm512_loadu_si512(v_X_pt);

        // Reduce from [0, 4q) to [0, q)
        v_X = _mm512_hexl_small_mod_epu64(v_X, v_twice_mod);
        v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus);

        HEXL_CHECK_BOUNDS(ExtractValues(v_X).data(), 8, modulus,
                          "v_X exceeds bound " << modulus);

        _mm512_storeu_si512(v_X_pt, v_X);

        ++v_X_pt;
      }
    }
  } else {
    // Perform depth-first NTT via recursive calls
    size_t t = (n >> 1);
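    // The first stage of this sub-transform uses the single bit-reversed root
    // at index 2^recursion_depth + recursion_half, matching the breadth-first
    // branch's W_idx = (m << recursion_depth) + recursion_half * m with m = 1.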
    size_t W_idx = (1ULL << recursion_depth) + recursion_half;
    const uint64_t* W = &root_of_unity_powers[W_idx];
    const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];

    FwdT8<BitShift, false>(result, operand, v_neg_modulus, v_twice_mod, t, 1, W,
                           W_precon);

    ForwardTransformToBitReverseAVX512<BitShift>(
        result, result, n / 2, modulus, root_of_unity_powers,
        precon_root_of_unity_powers, input_mod_factor, output_mod_factor,
        recursion_depth + 1, recursion_half * 2);

    ForwardTransformToBitReverseAVX512<BitShift>(
        &result[n / 2], &result[n / 2], n / 2, modulus, root_of_unity_powers,
        precon_root_of_unity_powers, input_mod_factor, output_mod_factor,
        recursion_depth + 1, recursion_half * 2 + 1);
  }
}

#endif  // HEXL_HAS_AVX512DQ

}  // namespace hexl
}  // namespace intel