1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
#include "ntt/fwd-ntt-avx512.hpp"

#include <immintrin.h>

#include <cstring>
#include <functional>
#include <vector>

#include "hexl/logging/logging.hpp"
#include "hexl/ntt/ntt.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "ntt/ntt-avx512-util.hpp"
#include "ntt/ntt-internal.hpp"
#include "util/avx512-util.hpp"
16
17 namespace intel {
18 namespace hexl {
19
#ifdef HEXL_HAS_AVX512IFMA
// Explicit instantiation for the AVX512-IFMA path (52-bit Barrett shift).
template void ForwardTransformToBitReverseAVX512<NTT::s_ifma_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif

#ifdef HEXL_HAS_AVX512DQ
// Explicit instantiation for the 32-bit Barrett-shift variant.
template void ForwardTransformToBitReverseAVX512<32>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);

// Explicit instantiation for the default (64-bit) Barrett-shift variant.
template void ForwardTransformToBitReverseAVX512<NTT::s_default_shift_bits>(
    uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod,
    const uint64_t* root_of_unity_powers,
    const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
    uint64_t output_mod_factor, uint64_t recursion_depth,
    uint64_t recursion_half);
#endif
44
45 #ifdef HEXL_HAS_AVX512DQ
46
47 /// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y'
48 /// in [0, 4q) such that X', Y' = X + WY, X - WY (mod q).
49 /// @param[in,out] X Input representing 8 64-bit signed integers in SIMD form
50 /// @param[in,out] Y Input representing 8 64-bit signed integers in SIMD form
51 /// @param[in] W Root of unity represented as 8 64-bit signed integers in
52 /// SIMD form
53 /// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett
54 /// reduction
55 /// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit
56 /// signed integers in SIMD form
57 /// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit
58 /// signed integers in SIMD form
59 /// @param InputLessThanMod If true, assumes \p X, \p Y < \p q. Otherwise,
60 /// assumes \p X, \p Y < 4*\p q
61 /// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf
62 template <int BitShift, bool InputLessThanMod>
FwdButterfly(__m512i * X,__m512i * Y,__m512i W,__m512i W_precon,__m512i neg_modulus,__m512i twice_modulus)63 void FwdButterfly(__m512i* X, __m512i* Y, __m512i W, __m512i W_precon,
64 __m512i neg_modulus, __m512i twice_modulus) {
65 if (!InputLessThanMod) {
66 *X = _mm512_hexl_small_mod_epu64(*X, twice_modulus);
67 }
68
69 __m512i T;
70 if (BitShift == 32) {
71 __m512i Q = _mm512_hexl_mullo_epi<64>(W_precon, *Y);
72 Q = _mm512_srli_epi64(Q, 32);
73 __m512i W_Y = _mm512_hexl_mullo_epi<64>(W, *Y);
74 T = _mm512_hexl_mullo_add_lo_epi<64>(W_Y, Q, neg_modulus);
75 } else if (BitShift == 52) {
76 __m512i Q = _mm512_hexl_mulhi_epi<BitShift>(W_precon, *Y);
77 __m512i W_Y = _mm512_hexl_mullo_epi<BitShift>(W, *Y);
78 T = _mm512_hexl_mullo_add_lo_epi<BitShift>(W_Y, Q, neg_modulus);
79 } else if (BitShift == 64) {
80 // Perform approximate computation of Q, as described in page 7 of
81 // https://arxiv.org/pdf/2003.04510.pdf
82 __m512i Q = _mm512_hexl_mulhi_approx_epi<BitShift>(W_precon, *Y);
83 __m512i W_Y = _mm512_hexl_mullo_epi<BitShift>(W, *Y);
84 // Compute T in range [0, 4q)
85 T = _mm512_hexl_mullo_add_lo_epi<BitShift>(W_Y, Q, neg_modulus);
86 // Reduce T to range [0, 2q)
87 T = _mm512_hexl_small_mod_epu64<2>(T, twice_modulus);
88 } else {
89 HEXL_CHECK(false, "Invalid BitShift " << BitShift);
90 }
91
92 __m512i twice_mod_minus_T = _mm512_sub_epi64(twice_modulus, T);
93 *Y = _mm512_add_epi64(*X, twice_mod_minus_T);
94 *X = _mm512_add_epi64(*X, T);
95 }
96
97 template <int BitShift>
FwdT1(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)98 void FwdT1(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
99 uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
100 const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
101 const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
102 size_t j1 = 0;
103
104 // 8 | m guaranteed by n >= 16
105 HEXL_LOOP_UNROLL_8
106 for (size_t i = m / 8; i > 0; --i) {
107 uint64_t* X = operand + j1;
108 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
109
110 __m512i v_X;
111 __m512i v_Y;
112 LoadFwdInterleavedT1(X, &v_X, &v_Y);
113 __m512i v_W = _mm512_loadu_si512(v_W_pt++);
114 __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);
115
116 FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
117 v_twice_mod);
118 WriteFwdInterleavedT1(v_X, v_Y, v_X_pt);
119
120 j1 += 16;
121 }
122 }
123
124 template <int BitShift>
FwdT2(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)125 void FwdT2(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
126 uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
127 const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
128 const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
129
130 size_t j1 = 0;
131 // 4 | m guaranteed by n >= 16
132 HEXL_LOOP_UNROLL_4
133 for (size_t i = m / 4; i > 0; --i) {
134 uint64_t* X = operand + j1;
135 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
136
137 __m512i v_X;
138 __m512i v_Y;
139 LoadFwdInterleavedT2(X, &v_X, &v_Y);
140
141 __m512i v_W = _mm512_loadu_si512(v_W_pt++);
142 __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);
143
144 HEXL_CHECK(ExtractValues(v_W)[0] == ExtractValues(v_W)[1],
145 "bad v_W " << ExtractValues(v_W));
146 HEXL_CHECK(ExtractValues(v_W_precon)[0] == ExtractValues(v_W_precon)[1],
147 "bad v_W_precon " << ExtractValues(v_W_precon));
148 FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
149 v_twice_mod);
150
151 _mm512_storeu_si512(v_X_pt++, v_X);
152 _mm512_storeu_si512(v_X_pt, v_Y);
153
154 j1 += 16;
155 }
156 }
157
158 template <int BitShift>
FwdT4(uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t m,const uint64_t * W,const uint64_t * W_precon)159 void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod,
160 uint64_t m, const uint64_t* W, const uint64_t* W_precon) {
161 size_t j1 = 0;
162 const __m512i* v_W_pt = reinterpret_cast<const __m512i*>(W);
163 const __m512i* v_W_precon_pt = reinterpret_cast<const __m512i*>(W_precon);
164
165 // 2 | m guaranteed by n >= 16
166 HEXL_LOOP_UNROLL_4
167 for (size_t i = m / 2; i > 0; --i) {
168 uint64_t* X = operand + j1;
169 __m512i* v_X_pt = reinterpret_cast<__m512i*>(X);
170
171 __m512i v_X;
172 __m512i v_Y;
173 LoadFwdInterleavedT4(X, &v_X, &v_Y);
174
175 __m512i v_W = _mm512_loadu_si512(v_W_pt++);
176 __m512i v_W_precon = _mm512_loadu_si512(v_W_precon_pt++);
177 FwdButterfly<BitShift, false>(&v_X, &v_Y, v_W, v_W_precon, v_neg_modulus,
178 v_twice_mod);
179
180 _mm512_storeu_si512(v_X_pt++, v_X);
181 _mm512_storeu_si512(v_X_pt, v_Y);
182
183 j1 += 16;
184 }
185 }
186
187 // Out-of-place implementation
188 template <int BitShift, bool InputLessThanMod>
FwdT8(uint64_t * result,const uint64_t * operand,__m512i v_neg_modulus,__m512i v_twice_mod,uint64_t t,uint64_t m,const uint64_t * W,const uint64_t * W_precon)189 void FwdT8(uint64_t* result, const uint64_t* operand, __m512i v_neg_modulus,
190 __m512i v_twice_mod, uint64_t t, uint64_t m, const uint64_t* W,
191 const uint64_t* W_precon) {
192 size_t j1 = 0;
193
194 HEXL_LOOP_UNROLL_4
195 for (size_t i = 0; i < m; i++) {
196 // Referencing operand
197 const uint64_t* X_op = operand + j1;
198 const uint64_t* Y_op = X_op + t;
199
200 const __m512i* v_X_op_pt = reinterpret_cast<const __m512i*>(X_op);
201 const __m512i* v_Y_op_pt = reinterpret_cast<const __m512i*>(Y_op);
202
203 // Referencing result
204 uint64_t* X_r = result + j1;
205 uint64_t* Y_r = X_r + t;
206
207 __m512i* v_X_r_pt = reinterpret_cast<__m512i*>(X_r);
208 __m512i* v_Y_r_pt = reinterpret_cast<__m512i*>(Y_r);
209
210 // Weights and weights' preconditions
211 __m512i v_W = _mm512_set1_epi64(static_cast<int64_t>(*W++));
212 __m512i v_W_precon = _mm512_set1_epi64(static_cast<int64_t>(*W_precon++));
213
214 // assume 8 | t
215 for (size_t j = t / 8; j > 0; --j) {
216 __m512i v_X = _mm512_loadu_si512(v_X_op_pt);
217 __m512i v_Y = _mm512_loadu_si512(v_Y_op_pt);
218
219 FwdButterfly<BitShift, InputLessThanMod>(&v_X, &v_Y, v_W, v_W_precon,
220 v_neg_modulus, v_twice_mod);
221
222 _mm512_storeu_si512(v_X_r_pt++, v_X);
223 _mm512_storeu_si512(v_Y_r_pt++, v_Y);
224
225 // Increase operand pointers as well
226 v_X_op_pt++;
227 v_Y_op_pt++;
228 }
229 j1 += (t << 1);
230 }
231 }
232
233 template <int BitShift>
ForwardTransformToBitReverseAVX512(uint64_t * result,const uint64_t * operand,uint64_t n,uint64_t modulus,const uint64_t * root_of_unity_powers,const uint64_t * precon_root_of_unity_powers,uint64_t input_mod_factor,uint64_t output_mod_factor,uint64_t recursion_depth,uint64_t recursion_half)234 void ForwardTransformToBitReverseAVX512(
235 uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus,
236 const uint64_t* root_of_unity_powers,
237 const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor,
238 uint64_t output_mod_factor, uint64_t recursion_depth,
239 uint64_t recursion_half) {
240 HEXL_CHECK(NTT::CheckArguments(n, modulus), "");
241 HEXL_CHECK(modulus < NTT::s_max_fwd_modulus(BitShift),
242 "modulus " << modulus << " too large for BitShift " << BitShift
243 << " => maximum value "
244 << NTT::s_max_fwd_modulus(BitShift));
245 HEXL_CHECK_BOUNDS(precon_root_of_unity_powers, n, MaximumValue(BitShift),
246 "precon_root_of_unity_powers too large");
247 HEXL_CHECK_BOUNDS(operand, n, MaximumValue(BitShift), "operand too large");
248 // Skip input bound checking for recursive steps
249 HEXL_CHECK_BOUNDS(operand, (recursion_depth == 0) ? n : 0,
250 input_mod_factor * modulus,
251 "operand larger than input_mod_factor * modulus ("
252 << input_mod_factor << " * " << modulus << ")");
253 HEXL_CHECK(n >= 16,
254 "Don't support small transforms. Need n >= 16, got n = " << n);
255 HEXL_CHECK(
256 input_mod_factor == 1 || input_mod_factor == 2 || input_mod_factor == 4,
257 "input_mod_factor must be 1, 2, or 4; got " << input_mod_factor);
258 HEXL_CHECK(output_mod_factor == 1 || output_mod_factor == 4,
259 "output_mod_factor must be 1 or 4; got " << output_mod_factor);
260
261 uint64_t twice_mod = modulus << 1;
262
263 __m512i v_modulus = _mm512_set1_epi64(static_cast<int64_t>(modulus));
264 __m512i v_neg_modulus = _mm512_set1_epi64(-static_cast<int64_t>(modulus));
265 __m512i v_twice_mod = _mm512_set1_epi64(static_cast<int64_t>(twice_mod));
266
267 HEXL_VLOG(5, "root_of_unity_powers " << std::vector<uint64_t>(
268 root_of_unity_powers, root_of_unity_powers + n))
269 HEXL_VLOG(5,
270 "precon_root_of_unity_powers " << std::vector<uint64_t>(
271 precon_root_of_unity_powers, precon_root_of_unity_powers + n));
272 HEXL_VLOG(5, "operand " << std::vector<uint64_t>(operand, operand + n));
273
274 static const size_t base_ntt_size = 1024;
275
276 if (n <= base_ntt_size) { // Perform breadth-first NTT
277 size_t t = (n >> 1);
278 size_t m = 1;
279 size_t W_idx = (m << recursion_depth) + (recursion_half * m);
280
281 // Copy for out-of-place in case m is <= base_ntt_size from start
282 if (result != operand) {
283 std::memcpy(result, operand, n * sizeof(uint64_t));
284 }
285
286 // First iteration assumes input in [0,p)
287 if (m < (n >> 3)) {
288 const uint64_t* W = &root_of_unity_powers[W_idx];
289 const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];
290
291 if ((input_mod_factor <= 2) && (recursion_depth == 0)) {
292 FwdT8<BitShift, true>(result, result, v_neg_modulus, v_twice_mod, t, m,
293 W, W_precon);
294 } else {
295 FwdT8<BitShift, false>(result, result, v_neg_modulus, v_twice_mod, t, m,
296 W, W_precon);
297 }
298
299 t >>= 1;
300 m <<= 1;
301 W_idx <<= 1;
302 }
303 for (; m < (n >> 3); m <<= 1) {
304 const uint64_t* W = &root_of_unity_powers[W_idx];
305 const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];
306 FwdT8<BitShift, false>(result, result, v_neg_modulus, v_twice_mod, t, m,
307 W, W_precon);
308 t >>= 1;
309 W_idx <<= 1;
310 }
311
312 // Do T=4, T=2, T=1 separately
313 {
314 // Correction step needed due to extra copies of roots of unity in the
315 // AVX512 vectors loaded for FwdT2 and FwdT4
316 auto compute_new_W_idx = [&](size_t idx) {
317 // Originally, from root of unity vector index to loop:
318 // [0, N/8) => FwdT8
319 // [N/8, N/4) => FwdT4
320 // [N/4, N/2) => FwdT2
321 // [N/2, N) => FwdT1
322 // The new mapping from AVX512 root of unity vector index to loop:
323 // [0, N/8) => FwdT8
324 // [N/8, 5N/8) => FwdT4
325 // [5N/8, 9N/8) => FwdT2
326 // [9N/8, 13N/8) => FwdT1
327 size_t N = n << recursion_depth;
328
329 // FwdT8 range
330 if (idx <= N / 8) {
331 return idx;
332 }
333 // FwdT4 range
334 if (idx <= N / 4) {
335 return (idx - N / 8) * 4 + (N / 8);
336 }
337 // FwdT2 range
338 if (idx <= N / 2) {
339 return (idx - N / 4) * 2 + (5 * N / 8);
340 }
341 // FwdT1 range
342 return idx + (5 * N / 8);
343 };
344
345 size_t new_W_idx = compute_new_W_idx(W_idx);
346 const uint64_t* W = &root_of_unity_powers[new_W_idx];
347 const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx];
348 FwdT4<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
349
350 m <<= 1;
351 W_idx <<= 1;
352 new_W_idx = compute_new_W_idx(W_idx);
353 W = &root_of_unity_powers[new_W_idx];
354 W_precon = &precon_root_of_unity_powers[new_W_idx];
355 FwdT2<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
356
357 m <<= 1;
358 W_idx <<= 1;
359 new_W_idx = compute_new_W_idx(W_idx);
360 W = &root_of_unity_powers[new_W_idx];
361 W_precon = &precon_root_of_unity_powers[new_W_idx];
362 FwdT1<BitShift>(result, v_neg_modulus, v_twice_mod, m, W, W_precon);
363 }
364
365 if (output_mod_factor == 1) {
366 // n power of two at least 8 => n divisible by 8
367 HEXL_CHECK(n % 8 == 0, "n " << n << " not a power of 2");
368 __m512i* v_X_pt = reinterpret_cast<__m512i*>(result);
369 for (size_t i = 0; i < n; i += 8) {
370 __m512i v_X = _mm512_loadu_si512(v_X_pt);
371
372 // Reduce from [0, 4q) to [0, q)
373 v_X = _mm512_hexl_small_mod_epu64(v_X, v_twice_mod);
374 v_X = _mm512_hexl_small_mod_epu64(v_X, v_modulus);
375
376 HEXL_CHECK_BOUNDS(ExtractValues(v_X).data(), 8, modulus,
377 "v_X exceeds bound " << modulus);
378
379 _mm512_storeu_si512(v_X_pt, v_X);
380
381 ++v_X_pt;
382 }
383 }
384 } else {
385 // Perform depth-first NTT via recursive call
386 size_t t = (n >> 1);
387 size_t W_idx = (1ULL << recursion_depth) + recursion_half;
388 const uint64_t* W = &root_of_unity_powers[W_idx];
389 const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];
390
391 FwdT8<BitShift, false>(result, operand, v_neg_modulus, v_twice_mod, t, 1, W,
392 W_precon);
393
394 ForwardTransformToBitReverseAVX512<BitShift>(
395 result, result, n / 2, modulus, root_of_unity_powers,
396 precon_root_of_unity_powers, input_mod_factor, output_mod_factor,
397 recursion_depth + 1, recursion_half * 2);
398
399 ForwardTransformToBitReverseAVX512<BitShift>(
400 &result[n / 2], &result[n / 2], n / 2, modulus, root_of_unity_powers,
401 precon_root_of_unity_powers, input_mod_factor, output_mod_factor,
402 recursion_depth + 1, recursion_half * 2 + 1);
403 }
404 }
405
406 #endif // HEXL_HAS_AVX512DQ
407
408 } // namespace hexl
409 } // namespace intel
410