1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 
4 #include "ntt/ntt-internal.hpp"
5 
6 #include <cstring>
7 #include <utility>
8 
9 #include "hexl/logging/logging.hpp"
10 #include "hexl/ntt/ntt.hpp"
11 #include "hexl/number-theory/number-theory.hpp"
12 #include "hexl/util/aligned-allocator.hpp"
13 #include "hexl/util/check.hpp"
14 #include "hexl/util/defines.hpp"
15 #include "ntt/fwd-ntt-avx512.hpp"
16 #include "ntt/inv-ntt-avx512.hpp"
17 #include "util/cpu-features.hpp"
18 
19 namespace intel {
20 namespace hexl {
21 
22 AllocatorStrategyPtr mallocStrategy = AllocatorStrategyPtr(new MallocStrategy);
23 
NTT(uint64_t degree,uint64_t q,uint64_t root_of_unity,std::shared_ptr<AllocatorBase> alloc_ptr)24 NTT::NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
25          std::shared_ptr<AllocatorBase> alloc_ptr)
26     : m_degree(degree),
27       m_q(q),
28       m_w(root_of_unity),
29       m_alloc(alloc_ptr),
30       m_aligned_alloc(AlignedAllocator<uint64_t, 64>(m_alloc)),
31       m_root_of_unity_powers(m_aligned_alloc),
32       m_precon32_root_of_unity_powers(m_aligned_alloc),
33       m_precon64_root_of_unity_powers(m_aligned_alloc),
34       m_avx512_root_of_unity_powers(m_aligned_alloc),
35       m_avx512_precon32_root_of_unity_powers(m_aligned_alloc),
36       m_avx512_precon52_root_of_unity_powers(m_aligned_alloc),
37       m_avx512_precon64_root_of_unity_powers(m_aligned_alloc),
38       m_precon32_inv_root_of_unity_powers(m_aligned_alloc),
39       m_precon52_inv_root_of_unity_powers(m_aligned_alloc),
40       m_precon64_inv_root_of_unity_powers(m_aligned_alloc),
41       m_inv_root_of_unity_powers(m_aligned_alloc) {
42   HEXL_CHECK(CheckArguments(degree, q), "");
43   HEXL_CHECK(IsPrimitiveRoot(m_w, 2 * degree, q),
44              m_w << " is not a primitive 2*" << degree << "'th root of unity");
45 
46   m_degree_bits = Log2(m_degree);
47   m_w_inv = InverseMod(m_w, m_q);
48   ComputeRootOfUnityPowers();
49 }
50 
NTT(uint64_t degree,uint64_t q,std::shared_ptr<AllocatorBase> alloc_ptr)51 NTT::NTT(uint64_t degree, uint64_t q, std::shared_ptr<AllocatorBase> alloc_ptr)
52     : NTT(degree, q, MinimalPrimitiveRoot(2 * degree, q), alloc_ptr) {}
53 
ComputeRootOfUnityPowers()54 void NTT::ComputeRootOfUnityPowers() {
55   AlignedVector64<uint64_t> root_of_unity_powers(m_degree, 0, m_aligned_alloc);
56   AlignedVector64<uint64_t> inv_root_of_unity_powers(m_degree, 0,
57                                                      m_aligned_alloc);
58 
59   // 64-bit preconditioned inverse and root of unity powers
60   root_of_unity_powers[0] = 1;
61   inv_root_of_unity_powers[0] = InverseMod(1, m_q);
62   uint64_t idx = 0;
63   uint64_t prev_idx = idx;
64 
65   for (size_t i = 1; i < m_degree; i++) {
66     idx = ReverseBits(i, m_degree_bits);
67     root_of_unity_powers[idx] =
68         MultiplyMod(root_of_unity_powers[prev_idx], m_w, m_q);
69     inv_root_of_unity_powers[idx] = InverseMod(root_of_unity_powers[idx], m_q);
70 
71     prev_idx = idx;
72   }
73 
74   m_root_of_unity_powers = root_of_unity_powers;
75   m_avx512_root_of_unity_powers = m_root_of_unity_powers;
76 
77   // Duplicate each root of unity at indices [N/4, N/2].
78   // These are the roots of unity used in the FwdNTT FwdT2 function
79   // By creating these duplicates, we avoid extra permutations while loading the
80   // roots of unity
81   AlignedVector64<uint64_t> W2_roots;
82   W2_roots.reserve(m_degree / 2);
83   for (size_t i = m_degree / 4; i < m_degree / 2; ++i) {
84     W2_roots.push_back(m_root_of_unity_powers[i]);
85     W2_roots.push_back(m_root_of_unity_powers[i]);
86   }
87   m_avx512_root_of_unity_powers.erase(
88       m_avx512_root_of_unity_powers.begin() + m_degree / 4,
89       m_avx512_root_of_unity_powers.begin() + m_degree / 2);
90   m_avx512_root_of_unity_powers.insert(
91       m_avx512_root_of_unity_powers.begin() + m_degree / 4, W2_roots.begin(),
92       W2_roots.end());
93 
94   // Duplicate each root of unity at indices [N/8, N/4].
95   // These are the roots of unity used in the FwdNTT FwdT4 function
96   // By creating these duplicates, we avoid extra permutations while loading the
97   // roots of unity
98   AlignedVector64<uint64_t> W4_roots;
99   W4_roots.reserve(m_degree / 2);
100   for (size_t i = m_degree / 8; i < m_degree / 4; ++i) {
101     W4_roots.push_back(m_root_of_unity_powers[i]);
102     W4_roots.push_back(m_root_of_unity_powers[i]);
103     W4_roots.push_back(m_root_of_unity_powers[i]);
104     W4_roots.push_back(m_root_of_unity_powers[i]);
105   }
106   m_avx512_root_of_unity_powers.erase(
107       m_avx512_root_of_unity_powers.begin() + m_degree / 8,
108       m_avx512_root_of_unity_powers.begin() + m_degree / 4);
109   m_avx512_root_of_unity_powers.insert(
110       m_avx512_root_of_unity_powers.begin() + m_degree / 8, W4_roots.begin(),
111       W4_roots.end());
112 
113   auto compute_barrett_vector = [&](const AlignedVector64<uint64_t>& values,
114                                     uint64_t bit_shift) {
115     AlignedVector64<uint64_t> barrett_vector(m_aligned_alloc);
116     for (uint64_t value : values) {
117       MultiplyFactor mf(value, bit_shift, m_q);
118       barrett_vector.push_back(mf.BarrettFactor());
119     }
120     return barrett_vector;
121   };
122 
123   m_precon32_root_of_unity_powers =
124       compute_barrett_vector(root_of_unity_powers, 32);
125   m_precon64_root_of_unity_powers =
126       compute_barrett_vector(root_of_unity_powers, 64);
127 
128   // 52-bit preconditioned root of unity powers
129   if (has_avx512ifma) {
130     m_avx512_precon52_root_of_unity_powers =
131         compute_barrett_vector(m_avx512_root_of_unity_powers, 52);
132   }
133 
134   if (has_avx512dq) {
135     m_avx512_precon32_root_of_unity_powers =
136         compute_barrett_vector(m_avx512_root_of_unity_powers, 32);
137     m_avx512_precon64_root_of_unity_powers =
138         compute_barrett_vector(m_avx512_root_of_unity_powers, 64);
139   }
140 
141   // Inverse root of unity powers
142 
143   // Reordering inv_root_of_powers
144   AlignedVector64<uint64_t> temp(m_degree, 0, m_aligned_alloc);
145   temp[0] = inv_root_of_unity_powers[0];
146   idx = 1;
147 
148   for (size_t m = (m_degree >> 1); m > 0; m >>= 1) {
149     for (size_t i = 0; i < m; i++) {
150       temp[idx] = inv_root_of_unity_powers[m + i];
151       idx++;
152     }
153   }
154   m_inv_root_of_unity_powers = std::move(temp);
155 
156   // 32-bit preconditioned inverse root of unity powers
157   m_precon32_inv_root_of_unity_powers =
158       compute_barrett_vector(m_inv_root_of_unity_powers, 32);
159 
160   // 52-bit preconditioned inverse root of unity powers
161   if (has_avx512ifma) {
162     m_precon52_inv_root_of_unity_powers =
163         compute_barrett_vector(m_inv_root_of_unity_powers, 52);
164   }
165 
166   // 64-bit preconditioned inverse root of unity powers
167   m_precon64_inv_root_of_unity_powers =
168       compute_barrett_vector(m_inv_root_of_unity_powers, 64);
169 }
170 
CheckArguments(uint64_t degree,uint64_t modulus)171 bool NTT::CheckArguments(uint64_t degree, uint64_t modulus) {
172   HEXL_UNUSED(degree);
173   HEXL_UNUSED(modulus);
174   HEXL_CHECK(IsPowerOfTwo(degree),
175              "degree " << degree << " is not a power of 2");
176   HEXL_CHECK(degree <= (1ULL << NTT::MaxDegreeBits()),
177              "degree should be less than 2^" << NTT::MaxDegreeBits() << " got "
178                                              << degree);
179   HEXL_CHECK(modulus <= (1ULL << NTT::MaxModulusBits()),
180              "modulus should be less than 2^" << NTT::MaxModulusBits()
181                                               << " got " << modulus);
182   HEXL_CHECK(modulus % (2 * degree) == 1, "modulus mod 2n != 1");
183   HEXL_CHECK(IsPrime(modulus), "modulus is not prime");
184 
185   return true;
186 }
187 
ComputeForward(uint64_t * result,const uint64_t * operand,uint64_t input_mod_factor,uint64_t output_mod_factor)188 void NTT::ComputeForward(uint64_t* result, const uint64_t* operand,
189                          uint64_t input_mod_factor,
190                          uint64_t output_mod_factor) {
191   HEXL_CHECK(result != nullptr, "result == nullptr");
192   HEXL_CHECK(operand != nullptr, "operand == nullptr");
193   HEXL_CHECK(
194       input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4,
195       "input_mod_factor must be 1, 2 or 4; got " << input_mod_factor);
196   HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4,
197              "output_mod_factor must be 1 or 4; got " << output_mod_factor);
198   HEXL_CHECK_BOUNDS(
199       operand, m_degree, m_q * input_mod_factor,
200       "value in operand exceeds bound " << m_q * input_mod_factor);
201 
202 #ifdef HEXL_HAS_AVX512IFMA
203   if (has_avx512ifma && (m_q < s_max_fwd_ifma_modulus && (m_degree >= 16))) {
204     const uint64_t* root_of_unity_powers = GetAVX512RootOfUnityPowers().data();
205     const uint64_t* precon_root_of_unity_powers =
206         GetAVX512Precon52RootOfUnityPowers().data();
207 
208     HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA FwdNTT");
209     ForwardTransformToBitReverseAVX512<s_ifma_shift_bits>(
210         result, operand, m_degree, m_q, root_of_unity_powers,
211         precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
212     return;
213   }
214 #endif
215 
216 #ifdef HEXL_HAS_AVX512DQ
217   if (has_avx512dq && m_degree >= 16) {
218     if (m_q < s_max_fwd_32_modulus) {
219       HEXL_VLOG(3, "Calling 32-bit AVX512-DQ FwdNTT");
220       const uint64_t* root_of_unity_powers =
221           GetAVX512RootOfUnityPowers().data();
222       const uint64_t* precon_root_of_unity_powers =
223           GetAVX512Precon32RootOfUnityPowers().data();
224       ForwardTransformToBitReverseAVX512<32>(
225           result, operand, m_degree, m_q, root_of_unity_powers,
226           precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
227     } else {
228       HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdNTT");
229       const uint64_t* root_of_unity_powers =
230           GetAVX512RootOfUnityPowers().data();
231       const uint64_t* precon_root_of_unity_powers =
232           GetAVX512Precon64RootOfUnityPowers().data();
233 
234       ForwardTransformToBitReverseAVX512<s_default_shift_bits>(
235           result, operand, m_degree, m_q, root_of_unity_powers,
236           precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
237     }
238     return;
239   }
240 #endif
241 
242   HEXL_VLOG(3, "Calling ForwardTransformToBitReverseRadix2");
243   const uint64_t* root_of_unity_powers = GetRootOfUnityPowers().data();
244   const uint64_t* precon_root_of_unity_powers =
245       GetPrecon64RootOfUnityPowers().data();
246 
247   ForwardTransformToBitReverseRadix2(
248       result, operand, m_degree, m_q, root_of_unity_powers,
249       precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
250 }
251 
ComputeInverse(uint64_t * result,const uint64_t * operand,uint64_t input_mod_factor,uint64_t output_mod_factor)252 void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand,
253                          uint64_t input_mod_factor,
254                          uint64_t output_mod_factor) {
255   HEXL_CHECK(result != nullptr, "result == nullptr");
256   HEXL_CHECK(operand != nullptr, "operand == nullptr");
257   HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2,
258              "input_mod_factor must be 1 or 2; got " << input_mod_factor);
259   HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
260              "output_mod_factor must be 1 or 2; got " << output_mod_factor);
261   HEXL_CHECK_BOUNDS(operand, m_degree, m_q * input_mod_factor,
262                     "operand exceeds bound " << m_q * input_mod_factor);
263 
264 #ifdef HEXL_HAS_AVX512IFMA
265   if (has_avx512ifma && (m_q < s_max_inv_ifma_modulus) && (m_degree >= 16)) {
266     HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA InvNTT");
267     const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data();
268     const uint64_t* precon_inv_root_of_unity_powers =
269         GetPrecon52InvRootOfUnityPowers().data();
270     InverseTransformFromBitReverseAVX512<s_ifma_shift_bits>(
271         result, operand, m_degree, m_q, inv_root_of_unity_powers,
272         precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
273     return;
274   }
275 #endif
276 
277 #ifdef HEXL_HAS_AVX512DQ
278   if (has_avx512dq && m_degree >= 16) {
279     if (m_q < s_max_inv_32_modulus) {
280       HEXL_VLOG(3, "Calling 32-bit AVX512-DQ InvNTT");
281       const uint64_t* inv_root_of_unity_powers =
282           GetInvRootOfUnityPowers().data();
283       const uint64_t* precon_inv_root_of_unity_powers =
284           GetPrecon32InvRootOfUnityPowers().data();
285       InverseTransformFromBitReverseAVX512<32>(
286           result, operand, m_degree, m_q, inv_root_of_unity_powers,
287           precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
288     } else {
289       HEXL_VLOG(3, "Calling 64-bit AVX512 InvNTT");
290       const uint64_t* inv_root_of_unity_powers =
291           GetInvRootOfUnityPowers().data();
292       const uint64_t* precon_inv_root_of_unity_powers =
293           GetPrecon64InvRootOfUnityPowers().data();
294 
295       InverseTransformFromBitReverseAVX512<s_default_shift_bits>(
296           result, operand, m_degree, m_q, inv_root_of_unity_powers,
297           precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
298     }
299     return;
300   }
301 #endif
302 
303   HEXL_VLOG(3, "Calling 64-bit default InvNTT");
304   const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data();
305   const uint64_t* precon_inv_root_of_unity_powers =
306       GetPrecon64InvRootOfUnityPowers().data();
307   InverseTransformFromBitReverseRadix2(
308       result, operand, m_degree, m_q, inv_root_of_unity_powers,
309       precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
310 }
311 
312 }  // namespace hexl
313 }  // namespace intel
314