1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #include "ntt/ntt-internal.hpp"
5
6 #include <cstring>
7 #include <utility>
8
9 #include "hexl/logging/logging.hpp"
10 #include "hexl/ntt/ntt.hpp"
11 #include "hexl/number-theory/number-theory.hpp"
12 #include "hexl/util/aligned-allocator.hpp"
13 #include "hexl/util/check.hpp"
14 #include "hexl/util/defines.hpp"
15 #include "ntt/fwd-ntt-avx512.hpp"
16 #include "ntt/inv-ntt-avx512.hpp"
17 #include "util/cpu-features.hpp"
18
19 namespace intel {
20 namespace hexl {
21
22 AllocatorStrategyPtr mallocStrategy = AllocatorStrategyPtr(new MallocStrategy);
23
NTT(uint64_t degree,uint64_t q,uint64_t root_of_unity,std::shared_ptr<AllocatorBase> alloc_ptr)24 NTT::NTT(uint64_t degree, uint64_t q, uint64_t root_of_unity,
25 std::shared_ptr<AllocatorBase> alloc_ptr)
26 : m_degree(degree),
27 m_q(q),
28 m_w(root_of_unity),
29 m_alloc(alloc_ptr),
30 m_aligned_alloc(AlignedAllocator<uint64_t, 64>(m_alloc)),
31 m_root_of_unity_powers(m_aligned_alloc),
32 m_precon32_root_of_unity_powers(m_aligned_alloc),
33 m_precon64_root_of_unity_powers(m_aligned_alloc),
34 m_avx512_root_of_unity_powers(m_aligned_alloc),
35 m_avx512_precon32_root_of_unity_powers(m_aligned_alloc),
36 m_avx512_precon52_root_of_unity_powers(m_aligned_alloc),
37 m_avx512_precon64_root_of_unity_powers(m_aligned_alloc),
38 m_precon32_inv_root_of_unity_powers(m_aligned_alloc),
39 m_precon52_inv_root_of_unity_powers(m_aligned_alloc),
40 m_precon64_inv_root_of_unity_powers(m_aligned_alloc),
41 m_inv_root_of_unity_powers(m_aligned_alloc) {
42 HEXL_CHECK(CheckArguments(degree, q), "");
43 HEXL_CHECK(IsPrimitiveRoot(m_w, 2 * degree, q),
44 m_w << " is not a primitive 2*" << degree << "'th root of unity");
45
46 m_degree_bits = Log2(m_degree);
47 m_w_inv = InverseMod(m_w, m_q);
48 ComputeRootOfUnityPowers();
49 }
50
NTT(uint64_t degree,uint64_t q,std::shared_ptr<AllocatorBase> alloc_ptr)51 NTT::NTT(uint64_t degree, uint64_t q, std::shared_ptr<AllocatorBase> alloc_ptr)
52 : NTT(degree, q, MinimalPrimitiveRoot(2 * degree, q), alloc_ptr) {}
53
ComputeRootOfUnityPowers()54 void NTT::ComputeRootOfUnityPowers() {
55 AlignedVector64<uint64_t> root_of_unity_powers(m_degree, 0, m_aligned_alloc);
56 AlignedVector64<uint64_t> inv_root_of_unity_powers(m_degree, 0,
57 m_aligned_alloc);
58
59 // 64-bit preconditioned inverse and root of unity powers
60 root_of_unity_powers[0] = 1;
61 inv_root_of_unity_powers[0] = InverseMod(1, m_q);
62 uint64_t idx = 0;
63 uint64_t prev_idx = idx;
64
65 for (size_t i = 1; i < m_degree; i++) {
66 idx = ReverseBits(i, m_degree_bits);
67 root_of_unity_powers[idx] =
68 MultiplyMod(root_of_unity_powers[prev_idx], m_w, m_q);
69 inv_root_of_unity_powers[idx] = InverseMod(root_of_unity_powers[idx], m_q);
70
71 prev_idx = idx;
72 }
73
74 m_root_of_unity_powers = root_of_unity_powers;
75 m_avx512_root_of_unity_powers = m_root_of_unity_powers;
76
77 // Duplicate each root of unity at indices [N/4, N/2].
78 // These are the roots of unity used in the FwdNTT FwdT2 function
79 // By creating these duplicates, we avoid extra permutations while loading the
80 // roots of unity
81 AlignedVector64<uint64_t> W2_roots;
82 W2_roots.reserve(m_degree / 2);
83 for (size_t i = m_degree / 4; i < m_degree / 2; ++i) {
84 W2_roots.push_back(m_root_of_unity_powers[i]);
85 W2_roots.push_back(m_root_of_unity_powers[i]);
86 }
87 m_avx512_root_of_unity_powers.erase(
88 m_avx512_root_of_unity_powers.begin() + m_degree / 4,
89 m_avx512_root_of_unity_powers.begin() + m_degree / 2);
90 m_avx512_root_of_unity_powers.insert(
91 m_avx512_root_of_unity_powers.begin() + m_degree / 4, W2_roots.begin(),
92 W2_roots.end());
93
94 // Duplicate each root of unity at indices [N/8, N/4].
95 // These are the roots of unity used in the FwdNTT FwdT4 function
96 // By creating these duplicates, we avoid extra permutations while loading the
97 // roots of unity
98 AlignedVector64<uint64_t> W4_roots;
99 W4_roots.reserve(m_degree / 2);
100 for (size_t i = m_degree / 8; i < m_degree / 4; ++i) {
101 W4_roots.push_back(m_root_of_unity_powers[i]);
102 W4_roots.push_back(m_root_of_unity_powers[i]);
103 W4_roots.push_back(m_root_of_unity_powers[i]);
104 W4_roots.push_back(m_root_of_unity_powers[i]);
105 }
106 m_avx512_root_of_unity_powers.erase(
107 m_avx512_root_of_unity_powers.begin() + m_degree / 8,
108 m_avx512_root_of_unity_powers.begin() + m_degree / 4);
109 m_avx512_root_of_unity_powers.insert(
110 m_avx512_root_of_unity_powers.begin() + m_degree / 8, W4_roots.begin(),
111 W4_roots.end());
112
113 auto compute_barrett_vector = [&](const AlignedVector64<uint64_t>& values,
114 uint64_t bit_shift) {
115 AlignedVector64<uint64_t> barrett_vector(m_aligned_alloc);
116 for (uint64_t value : values) {
117 MultiplyFactor mf(value, bit_shift, m_q);
118 barrett_vector.push_back(mf.BarrettFactor());
119 }
120 return barrett_vector;
121 };
122
123 m_precon32_root_of_unity_powers =
124 compute_barrett_vector(root_of_unity_powers, 32);
125 m_precon64_root_of_unity_powers =
126 compute_barrett_vector(root_of_unity_powers, 64);
127
128 // 52-bit preconditioned root of unity powers
129 if (has_avx512ifma) {
130 m_avx512_precon52_root_of_unity_powers =
131 compute_barrett_vector(m_avx512_root_of_unity_powers, 52);
132 }
133
134 if (has_avx512dq) {
135 m_avx512_precon32_root_of_unity_powers =
136 compute_barrett_vector(m_avx512_root_of_unity_powers, 32);
137 m_avx512_precon64_root_of_unity_powers =
138 compute_barrett_vector(m_avx512_root_of_unity_powers, 64);
139 }
140
141 // Inverse root of unity powers
142
143 // Reordering inv_root_of_powers
144 AlignedVector64<uint64_t> temp(m_degree, 0, m_aligned_alloc);
145 temp[0] = inv_root_of_unity_powers[0];
146 idx = 1;
147
148 for (size_t m = (m_degree >> 1); m > 0; m >>= 1) {
149 for (size_t i = 0; i < m; i++) {
150 temp[idx] = inv_root_of_unity_powers[m + i];
151 idx++;
152 }
153 }
154 m_inv_root_of_unity_powers = std::move(temp);
155
156 // 32-bit preconditioned inverse root of unity powers
157 m_precon32_inv_root_of_unity_powers =
158 compute_barrett_vector(m_inv_root_of_unity_powers, 32);
159
160 // 52-bit preconditioned inverse root of unity powers
161 if (has_avx512ifma) {
162 m_precon52_inv_root_of_unity_powers =
163 compute_barrett_vector(m_inv_root_of_unity_powers, 52);
164 }
165
166 // 64-bit preconditioned inverse root of unity powers
167 m_precon64_inv_root_of_unity_powers =
168 compute_barrett_vector(m_inv_root_of_unity_powers, 64);
169 }
170
CheckArguments(uint64_t degree,uint64_t modulus)171 bool NTT::CheckArguments(uint64_t degree, uint64_t modulus) {
172 HEXL_UNUSED(degree);
173 HEXL_UNUSED(modulus);
174 HEXL_CHECK(IsPowerOfTwo(degree),
175 "degree " << degree << " is not a power of 2");
176 HEXL_CHECK(degree <= (1ULL << NTT::MaxDegreeBits()),
177 "degree should be less than 2^" << NTT::MaxDegreeBits() << " got "
178 << degree);
179 HEXL_CHECK(modulus <= (1ULL << NTT::MaxModulusBits()),
180 "modulus should be less than 2^" << NTT::MaxModulusBits()
181 << " got " << modulus);
182 HEXL_CHECK(modulus % (2 * degree) == 1, "modulus mod 2n != 1");
183 HEXL_CHECK(IsPrime(modulus), "modulus is not prime");
184
185 return true;
186 }
187
ComputeForward(uint64_t * result,const uint64_t * operand,uint64_t input_mod_factor,uint64_t output_mod_factor)188 void NTT::ComputeForward(uint64_t* result, const uint64_t* operand,
189 uint64_t input_mod_factor,
190 uint64_t output_mod_factor) {
191 HEXL_CHECK(result != nullptr, "result == nullptr");
192 HEXL_CHECK(operand != nullptr, "operand == nullptr");
193 HEXL_CHECK(
194 input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4,
195 "input_mod_factor must be 1, 2 or 4; got " << input_mod_factor);
196 HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4,
197 "output_mod_factor must be 1 or 4; got " << output_mod_factor);
198 HEXL_CHECK_BOUNDS(
199 operand, m_degree, m_q * input_mod_factor,
200 "value in operand exceeds bound " << m_q * input_mod_factor);
201
202 #ifdef HEXL_HAS_AVX512IFMA
203 if (has_avx512ifma && (m_q < s_max_fwd_ifma_modulus && (m_degree >= 16))) {
204 const uint64_t* root_of_unity_powers = GetAVX512RootOfUnityPowers().data();
205 const uint64_t* precon_root_of_unity_powers =
206 GetAVX512Precon52RootOfUnityPowers().data();
207
208 HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA FwdNTT");
209 ForwardTransformToBitReverseAVX512<s_ifma_shift_bits>(
210 result, operand, m_degree, m_q, root_of_unity_powers,
211 precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
212 return;
213 }
214 #endif
215
216 #ifdef HEXL_HAS_AVX512DQ
217 if (has_avx512dq && m_degree >= 16) {
218 if (m_q < s_max_fwd_32_modulus) {
219 HEXL_VLOG(3, "Calling 32-bit AVX512-DQ FwdNTT");
220 const uint64_t* root_of_unity_powers =
221 GetAVX512RootOfUnityPowers().data();
222 const uint64_t* precon_root_of_unity_powers =
223 GetAVX512Precon32RootOfUnityPowers().data();
224 ForwardTransformToBitReverseAVX512<32>(
225 result, operand, m_degree, m_q, root_of_unity_powers,
226 precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
227 } else {
228 HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdNTT");
229 const uint64_t* root_of_unity_powers =
230 GetAVX512RootOfUnityPowers().data();
231 const uint64_t* precon_root_of_unity_powers =
232 GetAVX512Precon64RootOfUnityPowers().data();
233
234 ForwardTransformToBitReverseAVX512<s_default_shift_bits>(
235 result, operand, m_degree, m_q, root_of_unity_powers,
236 precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
237 }
238 return;
239 }
240 #endif
241
242 HEXL_VLOG(3, "Calling ForwardTransformToBitReverseRadix2");
243 const uint64_t* root_of_unity_powers = GetRootOfUnityPowers().data();
244 const uint64_t* precon_root_of_unity_powers =
245 GetPrecon64RootOfUnityPowers().data();
246
247 ForwardTransformToBitReverseRadix2(
248 result, operand, m_degree, m_q, root_of_unity_powers,
249 precon_root_of_unity_powers, input_mod_factor, output_mod_factor);
250 }
251
ComputeInverse(uint64_t * result,const uint64_t * operand,uint64_t input_mod_factor,uint64_t output_mod_factor)252 void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand,
253 uint64_t input_mod_factor,
254 uint64_t output_mod_factor) {
255 HEXL_CHECK(result != nullptr, "result == nullptr");
256 HEXL_CHECK(operand != nullptr, "operand == nullptr");
257 HEXL_CHECK(input_mod_factor == 1 || input_mod_factor == 2,
258 "input_mod_factor must be 1 or 2; got " << input_mod_factor);
259 HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 2,
260 "output_mod_factor must be 1 or 2; got " << output_mod_factor);
261 HEXL_CHECK_BOUNDS(operand, m_degree, m_q * input_mod_factor,
262 "operand exceeds bound " << m_q * input_mod_factor);
263
264 #ifdef HEXL_HAS_AVX512IFMA
265 if (has_avx512ifma && (m_q < s_max_inv_ifma_modulus) && (m_degree >= 16)) {
266 HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA InvNTT");
267 const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data();
268 const uint64_t* precon_inv_root_of_unity_powers =
269 GetPrecon52InvRootOfUnityPowers().data();
270 InverseTransformFromBitReverseAVX512<s_ifma_shift_bits>(
271 result, operand, m_degree, m_q, inv_root_of_unity_powers,
272 precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
273 return;
274 }
275 #endif
276
277 #ifdef HEXL_HAS_AVX512DQ
278 if (has_avx512dq && m_degree >= 16) {
279 if (m_q < s_max_inv_32_modulus) {
280 HEXL_VLOG(3, "Calling 32-bit AVX512-DQ InvNTT");
281 const uint64_t* inv_root_of_unity_powers =
282 GetInvRootOfUnityPowers().data();
283 const uint64_t* precon_inv_root_of_unity_powers =
284 GetPrecon32InvRootOfUnityPowers().data();
285 InverseTransformFromBitReverseAVX512<32>(
286 result, operand, m_degree, m_q, inv_root_of_unity_powers,
287 precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
288 } else {
289 HEXL_VLOG(3, "Calling 64-bit AVX512 InvNTT");
290 const uint64_t* inv_root_of_unity_powers =
291 GetInvRootOfUnityPowers().data();
292 const uint64_t* precon_inv_root_of_unity_powers =
293 GetPrecon64InvRootOfUnityPowers().data();
294
295 InverseTransformFromBitReverseAVX512<s_default_shift_bits>(
296 result, operand, m_degree, m_q, inv_root_of_unity_powers,
297 precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
298 }
299 return;
300 }
301 #endif
302
303 HEXL_VLOG(3, "Calling 64-bit default InvNTT");
304 const uint64_t* inv_root_of_unity_powers = GetInvRootOfUnityPowers().data();
305 const uint64_t* precon_inv_root_of_unity_powers =
306 GetPrecon64InvRootOfUnityPowers().data();
307 InverseTransformFromBitReverseRadix2(
308 result, operand, m_degree, m_q, inv_root_of_unity_powers,
309 precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor);
310 }
311
312 } // namespace hexl
313 } // namespace intel
314