1 // @file transfrm.cpp This file contains the linear transform interface
2 // functionality.
3 // @author TPOC: contact@palisade-crypto.org
4 //
5 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
6 // All rights reserved.
7 // Redistribution and use in source and binary forms, with or without
8 // modification, are permitted provided that the following conditions are met:
9 // 1. Redistributions of source code must retain the above copyright notice,
10 // this list of conditions and the following disclaimer.
11 // 2. Redistributions in binary form must reproduce the above copyright notice,
12 // this list of conditions and the following disclaimer in the documentation
13 // and/or other materials provided with the distribution. THIS SOFTWARE IS
14 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
15 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
16 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
17 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
18 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 
25 #include "math/transfrm.h"
26 #include "utils/defines.h"
27 
28 #ifdef WITH_INTEL_HEXL
29 #include "hexl/hexl.hpp"
30 #endif
31 
32 namespace lbcrypto {
33 
// Out-of-class definitions of the static precomputation caches used by
// ChineseRemainderTransformFTT. Every map below is keyed by the modulus q.

// Entry i holds (2^i)^{-1} mod q, i.e. the inverse of a transform length
// n = 2^i; filled by PreCompute.
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformFTT<VecType>::m_cycloOrderInverseTableByModulus;

// PrepModMulConst constants matching m_cycloOrderInverseTableByModulus;
// consumed only on the NativeInteger code paths (ModMulFastConstEq).
template <typename VecType>
std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT<
    VecType>::m_cycloOrderInversePreconTableByModulus;

// Powers of the root of unity, stored at bit-reversed indices (see
// PreCompute); used by the forward transforms.
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformFTT<VecType>::m_rootOfUnityReverseTableByModulus;

// Powers of the inverse root of unity, stored at bit-reversed indices; used
// by the inverse transforms.
template <typename VecType>
std::map<typename VecType::Integer, VecType> ChineseRemainderTransformFTT<
    VecType>::m_rootOfUnityInverseReverseTableByModulus;

// PrepModMulConst constants for the forward root table (NativeInteger only).
template <typename VecType>
std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT<
    VecType>::m_rootOfUnityPreconReverseTableByModulus;

// PrepModMulConst constants for the inverse root table (NativeInteger only).
template <typename VecType>
std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT<
    VecType>::m_rootOfUnityInversePreconReverseTableByModulus;

// Cache of Intel HEXL NTT objects keyed by (transform length N, modulus);
// every lookup/insert is done while holding m_mtxIntelNTT.
#ifdef WITH_INTEL_HEXL
template <typename VecType>
// N, modulus
std::unordered_map<std::pair<uint64_t, uint64_t>, intel::hexl::NTT, HashPair>
    ChineseRemainderTransformFTT<VecType>::m_IntelNtt;
template <typename VecType>
std::mutex ChineseRemainderTransformFTT<VecType>::m_mtxIntelNTT;
#endif
66 
// Out-of-class definitions of the static caches used by
// ChineseRemainderTransformArb and BluesteinFFT. How these maps are filled
// is not visible in this part of the file; the descriptions below are
// inferred from the member names only — NOTE(review): confirm against the
// code that populates them.

// presumably: cyclotomic polynomial coefficients, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyMap;

// presumably: NTT of the reversed cyclotomic polynomial, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyReverseNTTMap;

// presumably: NTT of the cyclotomic polynomial, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyNTTMap;

// Bluestein root-of-unity table keyed by a (modulus, root) pair
template <typename VecType>
std::map<ModulusRoot<typename VecType::Integer>, VecType>
    BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot;

// Bluestein inverse root-of-unity table keyed by a (modulus, root) pair
template <typename VecType>
std::map<ModulusRoot<typename VecType::Integer>, VecType>
    BluesteinFFT<VecType>::m_rootOfUnityInverseTableByModulusRoot;

// Bluestein powers table keyed by a (modulus, root) pair
template <typename VecType>
std::map<ModulusRoot<typename VecType::Integer>, VecType>
    BluesteinFFT<VecType>::m_powersTableByModulusRoot;

// Bluestein table keyed by a pair of (modulus, root) pairs
template <typename VecType>
std::map<ModulusRootPair<typename VecType::Integer>, VecType>
    BluesteinFFT<VecType>::m_RBTableByModulusRootPair;

// Default auxiliary NTT (modulus, root) to use for a given modulus
template <typename VecType>
std::map<typename VecType::Integer, ModulusRoot<typename VecType::Integer>>
    BluesteinFFT<VecType>::m_defaultNTTModulusRoot;

// Root-of-unity table used for polynomial division, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, VecType>
    ChineseRemainderTransformArb<VecType>::m_rootOfUnityDivisionTableByModulus;

// Inverse root-of-unity table used for polynomial division, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, VecType> ChineseRemainderTransformArb<
    VecType>::m_rootOfUnityDivisionInverseTableByModulus;

// Auxiliary NTT modulus used for division, keyed by the original modulus
template <typename VecType>
std::map<typename VecType::Integer, typename VecType::Integer>
    ChineseRemainderTransformArb<VecType>::m_DivisionNTTModulus;

// Auxiliary NTT root of unity used for division, keyed by modulus
template <typename VecType>
std::map<typename VecType::Integer, typename VecType::Integer>
    ChineseRemainderTransformArb<VecType>::m_DivisionNTTRootOfUnity;

// NTT dimension used for division, keyed by an integer (usint)
template <typename VecType>
std::map<usint, usint> ChineseRemainderTransformArb<VecType>::m_nttDivisionDim;
117 
118 template <typename VecType>
119 void NumberTheoreticTransform<VecType>::ForwardTransformIterative(
120     const VecType &element, const VecType &rootOfUnityTable, VecType *result) {
121   usint n = element.GetLength();
122   if (result->GetLength() != n) {
123     PALISADE_THROW(
124         math_error,
125         "size of input element and size of output element not of same size");
126   }
127 
128   auto modulus = element.GetModulus();
129   IntType mu = modulus.ComputeMu();
130   result->SetModulus(modulus);
131 
132   usint msb = GetMSB64(n - 1);
133   for (size_t i = 0; i < n; i++) {
134     (*result)[i] = element[ReverseBits(i, msb)];
135   }
136 
137   IntType omega, omegaFactor, oddVal, evenVal;
138   usint logm, i, j, indexEven, indexOdd;
139 
140   usint logn = GetMSB64(n - 1);
141   for (logm = 1; logm <= logn; logm++) {
142     // calculate the i indexes into the root table one time per loop
143     vector<usint> indexes(1 << (logm - 1));
144     for (i = 0; i < (usint)(1 << (logm - 1)); i++) {
145       indexes[i] = (i << (logn - logm));
146     }
147 
148     for (j = 0; j < n; j = j + (1 << logm)) {
149       for (i = 0; i < (usint)(1 << (logm - 1)); i++) {
150         omega = rootOfUnityTable[indexes[i]];
151         indexEven = j + i;
152         indexOdd = indexEven + (1 << (logm - 1));
153         oddVal = (*result)[indexOdd];
154 
155         omegaFactor = omega.ModMul(oddVal, modulus, mu);
156         evenVal = (*result)[indexEven];
157         oddVal = evenVal;
158         oddVal += omegaFactor;
159         if (oddVal >= modulus) {
160           oddVal -= modulus;
161         }
162 
163         if (evenVal < omegaFactor) {
164           evenVal += modulus;
165         }
166         evenVal -= omegaFactor;
167 
168         (*result)[indexEven] = oddVal;
169         (*result)[indexOdd] = evenVal;
170       }
171     }
172   }
173   return;
174 }
175 
176 template <typename VecType>
177 void NumberTheoreticTransform<VecType>::InverseTransformIterative(
178     const VecType &element, const VecType &rootOfUnityInverseTable,
179     VecType *result) {
180   usint n = element.GetLength();
181 
182   IntType modulus = element.GetModulus();
183   IntType mu = modulus.ComputeMu();
184 
185   NumberTheoreticTransform<VecType>::ForwardTransformIterative(
186       element, rootOfUnityInverseTable, result);
187   IntType cycloOrderInv(IntType(n).ModInverse(modulus));
188   for (usint i = 0; i < n; i++) {
189     (*result)[i].ModMulEq(cycloOrderInv, modulus, mu);
190   }
191   return;
192 }
193 
194 template <typename VecType>
195 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace(
196     const VecType &rootOfUnityTable, VecType *element) {
197   usint n = element->GetLength();
198   IntType modulus = element->GetModulus();
199   IntType mu = modulus.ComputeMu();
200 
201   usint i, m, j1, j2, indexOmega, indexLo, indexHi;
202   IntType omega, omegaFactor, loVal, hiVal, zero(0);
203 
204   usint t = (n >> 1);
205   usint logt1 = GetMSB64(t);
206   for (m = 1; m < n; m <<= 1) {
207     for (i = 0; i < m; ++i) {
208       j1 = i << logt1;
209       j2 = j1 + t;
210       indexOmega = m + i;
211       omega = rootOfUnityTable[indexOmega];
212       for (indexLo = j1; indexLo < j2; ++indexLo) {
213         indexHi = indexLo + t;
214         loVal = (*element)[indexLo];
215         omegaFactor = (*element)[indexHi];
216         omegaFactor.ModMulFastEq(omega, modulus, mu);
217 
218         hiVal = loVal + omegaFactor;
219         if (hiVal >= modulus) {
220           hiVal -= modulus;
221         }
222 
223         if (loVal < omegaFactor) {
224           loVal += modulus;
225         }
226         loVal -= omegaFactor;
227 
228         (*element)[indexLo] = hiVal;
229         (*element)[indexHi] = loVal;
230       }
231     }
232     t >>= 1;
233     logt1--;
234   }
235   return;
236 }
237 
238 template <typename VecType>
239 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse(
240     const VecType &element, const VecType &rootOfUnityTable, VecType *result) {
241   usint n = element.GetLength();
242   if (result->GetLength() != n) {
243     PALISADE_THROW(
244         math_error,
245         "size of input element and size of output element not of same size");
246   }
247 
248   IntType modulus = element.GetModulus();
249   IntType mu = modulus.ComputeMu();
250   result->SetModulus(modulus);
251 
252   usint i, m, j1, j2, indexOmega, indexLo, indexHi;
253   IntType omega, omegaFactor, loVal, hiVal, zero(0);
254 
255   for (i = 0; i < n; ++i) {
256     (*result)[i] = element[i];
257   }
258 
259   usint t = (n >> 1);
260   usint logt1 = GetMSB64(t);
261   for (m = 1; m < n; m <<= 1) {
262     for (i = 0; i < m; ++i) {
263       j1 = i << logt1;
264       j2 = j1 + t;
265       indexOmega = m + i;
266       omega = rootOfUnityTable[indexOmega];
267       for (indexLo = j1; indexLo < j2; ++indexLo) {
268         indexHi = indexLo + t;
269         loVal = (*result)[indexLo];
270         omegaFactor = (*result)[indexHi];
271         if (omegaFactor != zero) {
272           omegaFactor.ModMulFastEq(omega, modulus, mu);
273 
274           hiVal = loVal + omegaFactor;
275           if (hiVal >= modulus) {
276             hiVal -= modulus;
277           }
278 
279           if (loVal < omegaFactor) {
280             loVal += modulus;
281           }
282           loVal -= omegaFactor;
283 
284           (*result)[indexLo] = hiVal;
285           (*result)[indexHi] = loVal;
286         } else {
287           (*result)[indexHi] = loVal;
288         }
289       }
290     }
291     t >>= 1;
292     logt1--;
293   }
294   return;
295 }
296 
297 template <typename VecType>
298 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace(
299     const VecType &rootOfUnityTable, const NativeVector &preconRootOfUnityTable,
300     VecType *element) {
301   usint n = element->GetLength();
302   IntType modulus = element->GetModulus();
303 
304   uint32_t indexOmega, indexHi;
305   NativeInteger preconOmega;
306   IntType omega, omegaFactor, loVal, hiVal, zero(0);
307 
308   usint t = (n >> 1);
309   usint logt1 = GetMSB64(t);
310   for (uint32_t m = 1; m < n; m <<= 1, t >>= 1, --logt1) {
311     uint32_t j1, j2;
312     for (uint32_t i = 0; i < m; ++i) {
313       j1 = i << logt1;
314       j2 = j1 + t;
315       indexOmega = m + i;
316       omega = rootOfUnityTable[indexOmega];
317       preconOmega = preconRootOfUnityTable[indexOmega];
318       for (uint32_t indexLo = j1; indexLo < j2; ++indexLo) {
319         indexHi = indexLo + t;
320         loVal = (*element)[indexLo];
321         omegaFactor = (*element)[indexHi];
322         omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
323 
324         hiVal = loVal + omegaFactor;
325         if (hiVal >= modulus) {
326           hiVal -= modulus;
327         }
328 
329         if (loVal < omegaFactor) {
330           loVal += modulus;
331         }
332         loVal -= omegaFactor;
333 
334         (*element)[indexLo] = hiVal;
335         (*element)[indexHi] = loVal;
336       }
337     }
338   }
339   return;
340 }
341 
342 template <typename VecType>
343 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse(
344     const VecType &element, const VecType &rootOfUnityTable,
345     const NativeVector &preconRootOfUnityTable, VecType *result) {
346   usint n = element.GetLength();
347 
348   if (result->GetLength() != n) {
349     PALISADE_THROW(
350         math_error,
351         "size of input element and size of output element not of same size");
352   }
353 
354   IntType modulus = element.GetModulus();
355 
356   result->SetModulus(modulus);
357 
358   for (uint32_t i = 0; i < n; ++i) {
359     (*result)[i] = element[i];
360   }
361 
362   uint32_t indexOmega, indexHi;
363   NativeInteger preconOmega;
364   IntType omega, omegaFactor, loVal, hiVal, zero(0);
365 
366   usint t = (n >> 1);
367   usint logt1 = GetMSB64(t);
368   for (uint32_t m = 1; m < n; m <<= 1, t >>= 1, --logt1) {
369     uint32_t j1, j2;
370     for (uint32_t i = 0; i < m; ++i) {
371       j1 = i << logt1;
372       j2 = j1 + t;
373       indexOmega = m + i;
374       omega = rootOfUnityTable[indexOmega];
375       preconOmega = preconRootOfUnityTable[indexOmega];
376       for (uint32_t indexLo = j1; indexLo < j2; ++indexLo) {
377         indexHi = indexLo + t;
378         loVal = (*result)[indexLo];
379         omegaFactor = (*result)[indexHi];
380         if (omegaFactor != zero) {
381           omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);
382 
383           hiVal = loVal + omegaFactor;
384           if (hiVal >= modulus) {
385             hiVal -= modulus;
386           }
387 
388           if (loVal < omegaFactor) {
389             loVal += modulus;
390           }
391           loVal -= omegaFactor;
392 
393           (*result)[indexLo] = hiVal;
394           (*result)[indexHi] = loVal;
395         } else {
396           (*result)[indexHi] = loVal;
397         }
398       }
399     }
400   }
401   return;
402 }
403 
404 template <typename VecType>
405 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
406     const VecType &rootOfUnityInverseTable, const IntType &cycloOrderInv,
407     VecType *element) {
408   usint n = element->GetLength();
409   IntType modulus = element->GetModulus();
410   IntType mu = modulus.ComputeMu();
411 
412   IntType loVal, hiVal, omega, omegaFactor;
413   usint i, m, j1, j2, indexOmega, indexLo, indexHi;
414 
415   usint t = 1;
416   usint logt1 = 1;
417   for (m = (n >> 1); m >= 1; m >>= 1) {
418     for (i = 0; i < m; ++i) {
419       j1 = i << logt1;
420       j2 = j1 + t;
421       indexOmega = m + i;
422       omega = rootOfUnityInverseTable[indexOmega];
423 
424       for (indexLo = j1; indexLo < j2; ++indexLo) {
425         indexHi = indexLo + t;
426 
427         hiVal = (*element)[indexHi];
428         loVal = (*element)[indexLo];
429 
430         omegaFactor = loVal;
431         if (omegaFactor < hiVal) {
432           omegaFactor += modulus;
433         }
434 
435         omegaFactor -= hiVal;
436 
437         loVal += hiVal;
438         if (loVal >= modulus) {
439           loVal -= modulus;
440         }
441 
442         omegaFactor.ModMulFastEq(omega, modulus, mu);
443 
444         (*element)[indexLo] = loVal;
445         (*element)[indexHi] = omegaFactor;
446       }
447     }
448     t <<= 1;
449     logt1++;
450   }
451 
452   for (i = 0; i < n; i++) {
453     (*element)[i].ModMulFastEq(cycloOrderInv, modulus, mu);
454   }
455   return;
456 }
457 
458 template <typename VecType>
459 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverse(
460     const VecType &element, const VecType &rootOfUnityInverseTable,
461     const IntType &cycloOrderInv, VecType *result) {
462   usint n = element.GetLength();
463 
464   if (result->GetLength() != n) {
465     PALISADE_THROW(
466         math_error,
467         "size of input element and size of output element not of same size");
468   }
469 
470   result->SetModulus(element.GetModulus());
471 
472   for (usint i = 0; i < n; i++) {
473     (*result)[i] = element[i];
474   }
475   InverseTransformFromBitReverseInPlace(rootOfUnityInverseTable, cycloOrderInv,
476                                         result);
477 }
478 
// In-place inverse NTT using Gentleman-Sande butterflies
// (lo, hi) -> (lo + hi, (lo - hi) * w) mod q, taking bit-reversed-order input
// and producing natural-order output. Products use ModMulFastConstEq with
// precomputed constants instead of a Barrett mu.
//
// @param rootOfUnityInverseTable       inverse root powers, bit-reversed order
// @param preconRootOfUnityInverseTable precomputed constants matching the
//                                      table above, entry for entry
// @param cycloOrderInv                 final scaling factor applied to every
//                                      coefficient
// @param preconCycloOrderInv           precomputed constant for cycloOrderInv
// @param element                       vector transformed in place (length n)
template <typename VecType>
void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
    const VecType &rootOfUnityInverseTable,
    const NativeVector &preconRootOfUnityInverseTable,
    const IntType &cycloOrderInv, const NativeInteger &preconCycloOrderInv,
    VecType *element) {
  usint n = element->GetLength();

  IntType modulus = element->GetModulus();

  IntType loVal, hiVal, omega, omegaFactor;
  NativeInteger preconOmega;
  usint i, m, j1, j2, indexOmega, indexLo, indexHi;

  // Span t starts at 1 and doubles each pass while the group count m halves;
  // (i << logt1) is the start of group i, each group spanning 2t elements.
  usint t = 1;
  usint logt1 = 1;
  for (m = (n >> 1); m >= 1; m >>= 1) {
    for (i = 0; i < m; ++i) {
      j1 = i << logt1;
      j2 = j1 + t;
      indexOmega = m + i;
      omega = rootOfUnityInverseTable[indexOmega];
      preconOmega = preconRootOfUnityInverseTable[indexOmega];

      for (indexLo = j1; indexLo < j2; ++indexLo) {
        indexHi = indexLo + t;

        hiVal = (*element)[indexHi];
        loVal = (*element)[indexLo];

        // omegaFactor = (lo - hi) mod q, adding q first to stay non-negative
        omegaFactor = loVal;
        if (omegaFactor < hiVal) {
          omegaFactor += modulus;
        }

        omegaFactor -= hiVal;

        // loVal = (lo + hi) mod q
        loVal += hiVal;
        if (loVal >= modulus) {
          loVal -= modulus;
        }

        omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega);

        (*element)[indexLo] = loVal;
        (*element)[indexHi] = omegaFactor;
      }
    }
    t <<= 1;
    logt1++;
  }

  // Final scaling of every coefficient by cycloOrderInv.
  for (i = 0; i < n; i++) {
    (*element)[i].ModMulFastConstEq(cycloOrderInv, modulus,
                                    preconCycloOrderInv);
  }
}
536 
537 template <typename VecType>
538 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverse(
539     const VecType &element, const VecType &rootOfUnityInverseTable,
540     const NativeVector &preconRootOfUnityInverseTable,
541     const IntType &cycloOrderInv, const NativeInteger &preconCycloOrderInv,
542     VecType *result) {
543   usint n = element.GetLength();
544   if (result->GetLength() != n) {
545     PALISADE_THROW(
546         math_error,
547         "size of input element and size of output element not of same size");
548   }
549 
550   result->SetModulus(element.GetModulus());
551 
552   for (usint i = 0; i < n; i++) {
553     (*result)[i] = element[i];
554   }
555   InverseTransformFromBitReverseInPlace(
556       rootOfUnityInverseTable, preconRootOfUnityInverseTable, cycloOrderInv,
557       preconCycloOrderInv, result);
558 
559   return;
560 }
561 
// In-place forward transform of `element` (length CycloOrder / 2) into
// bit-reversed order, using the tables cached for element's modulus.
//
// The tables are (re)built through PreCompute when none exist for this
// modulus or they were built for a different length. Dispatch:
//  - IntType == NativeInteger with HEXL enabled and VecType == NativeVector64:
//    Intel HEXL NTT (cached per (length, modulus));
//  - IntType == NativeInteger otherwise: precomputed-constant NTT;
//  - any other IntType: Barrett-reduction NTT.
//
// If rootOfUnity is 0 or 1 the element is left unchanged.
// @throws math_error if CycloOrder is not a power of two or the element
//         length is not CycloOrder / 2.
template <typename VecType>
void ChineseRemainderTransformFTT<VecType>::ForwardTransformToBitReverseInPlace(
    const IntType &rootOfUnity, const usint CycloOrder, VecType *element) {
  if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) {
    return;
  }

  if (!IsPowerOfTwo(CycloOrder)) {
    PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two");
  }

  usint CycloOrderHf = (CycloOrder >> 1);
  if (element->GetLength() != CycloOrderHf) {
    PALISADE_THROW(math_error,
                   "element size must be equal to CyclotomicOrder / 2");
  }

  IntType modulus = element->GetModulus();

  // Tables are cached per modulus only; a cached table of the wrong length
  // forces a rebuild, which also invalidates any cached HEXL NTT object.
  bool reCompute = false;
  PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON
  auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
  if (mapSearch == m_rootOfUnityReverseTableByModulus.end() ||
      mapSearch->second.GetLength() != CycloOrderHf) {
    PreCompute(rootOfUnity, CycloOrder, modulus);
    reCompute = true;
  }

  if (typeid(IntType) == typeid(NativeInteger)) {
#ifdef WITH_INTEL_HEXL
    if (std::is_same<VecType, NativeVector64>::value) {
      std::pair<uint64_t, uint64_t> key{element->GetLength(),
                                        modulus.ConvertToInt()};
      intel::hexl::NTT *p_ntt;
      // Only the cache lookup/insert is done under the mutex; the transform
      // itself runs after the lock is released.
      std::unique_lock<std::mutex> lock(m_mtxIntelNTT);
      auto ntt_it = m_IntelNtt.find(key);
      if (reCompute || ntt_it == m_IntelNtt.end()) {
        intel::hexl::NTT ntt(element->GetLength(), modulus.ConvertToInt(),
                             rootOfUnity.ConvertToInt());
        m_IntelNtt[key] = std::move(ntt);
        ntt_it = m_IntelNtt.find(key);
      }
      p_ntt = &ntt_it->second;
      lock.unlock();

      // Assumes the vector's coefficients are laid out as contiguous
      // uint64_t words — NOTE(review): relies on NativeVector64 layout.
      auto *data = reinterpret_cast<uint64_t *>(&element->at(0));
      p_ntt->ComputeForward(data, data, 1, 1);
      element->SetModulus(modulus);
    } else {
      NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace(
          m_rootOfUnityReverseTableByModulus[modulus],
          m_rootOfUnityPreconReverseTableByModulus[modulus], element);
    }
#else
    NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace(
        m_rootOfUnityReverseTableByModulus[modulus],
        m_rootOfUnityPreconReverseTableByModulus[modulus], element);
#endif
  } else {
    NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace(
        m_rootOfUnityReverseTableByModulus[modulus], element);
  }
}
625 
626 template <typename VecType>
627 void ChineseRemainderTransformFTT<VecType>::ForwardTransformToBitReverse(
628     const VecType &element, const IntType &rootOfUnity, const usint CycloOrder,
629     VecType *result) {
630   if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) {
631     *result = element;
632     return;
633   }
634 
635   if (!IsPowerOfTwo(CycloOrder)) {
636     PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two");
637   }
638 
639   usint CycloOrderHf = (CycloOrder >> 1);
640   if (result->GetLength() != CycloOrderHf) {
641     PALISADE_THROW(math_error,
642                    "result size must be equal to CyclotomicOrder / 2");
643   }
644 
645   IntType modulus = element.GetModulus();
646 
647   bool reCompute = false;
648   PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON
649   auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
650   if (mapSearch == m_rootOfUnityReverseTableByModulus.end() ||
651       mapSearch->second.GetLength() != CycloOrderHf) {
652     PreCompute(rootOfUnity, CycloOrder, modulus);
653     reCompute = true;
654   }
655 
656   if (typeid(IntType) == typeid(NativeInteger)) {
657 #ifdef WITH_INTEL_HEXL
658     if (std::is_same<VecType, NativeVector64>::value) {
659       std::pair<uint64_t, uint64_t> key{element.GetLength(),
660                                         modulus.ConvertToInt()};
661       intel::hexl::NTT *p_ntt;
662       std::unique_lock<std::mutex> lock(m_mtxIntelNTT);
663       auto ntt_it = m_IntelNtt.find(key);
664       if (reCompute || ntt_it == m_IntelNtt.end()) {
665         intel::hexl::NTT ntt(element.GetLength(), modulus.ConvertToInt(),
666                              rootOfUnity.ConvertToInt());
667         m_IntelNtt[key] = std::move(ntt);
668         ntt_it = m_IntelNtt.find(key);
669       }
670       p_ntt = &ntt_it->second;
671       lock.unlock();
672 
673       const uint64_t *input =
674           reinterpret_cast<const uint64_t *>(&element.at(0));
675       uint64_t *output = reinterpret_cast<uint64_t *>(&result->at(0));
676       p_ntt->ComputeForward(output, input, 1, 1);
677       result->SetModulus(modulus);
678 
679     } else {
680       NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse(
681           element, m_rootOfUnityReverseTableByModulus[modulus],
682           m_rootOfUnityPreconReverseTableByModulus[modulus], result);
683     }
684 #else
685     NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse(
686         element, m_rootOfUnityReverseTableByModulus[modulus],
687         m_rootOfUnityPreconReverseTableByModulus[modulus], result);
688 #endif
689   } else {
690     NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse(
691         element, m_rootOfUnityReverseTableByModulus[modulus], result);
692   }
693 
694   return;
695 }
696 
// In-place inverse transform of `element` (length CycloOrder / 2), taking
// bit-reversed-order input, using the tables cached for element's modulus.
// The tables are (re)built through PreCompute when absent or built for a
// different length.
//
// The scaling factor passed to the NTT is entry msb = log2(CycloOrder / 2)
// of the cyclo-order-inverse table, i.e. (CycloOrder / 2)^{-1} mod q.
//
// If rootOfUnity is 0 or 1 the element is left unchanged.
// @throws math_error if CycloOrder is not a power of two or the element
//         length is not CycloOrder / 2.
template <typename VecType>
void ChineseRemainderTransformFTT<
    VecType>::InverseTransformFromBitReverseInPlace(const IntType &rootOfUnity,
                                                    const usint CycloOrder,
                                                    VecType *element) {
  if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) {
    return;
  }

  if (!IsPowerOfTwo(CycloOrder)) {
    PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two");
  }

  usint CycloOrderHf = (CycloOrder >> 1);
  if (element->GetLength() != CycloOrderHf) {
    PALISADE_THROW(math_error,
                   "element size must be equal to CyclotomicOrder / 2");
  }

  IntType modulus = element->GetModulus();

  // A cached table of the wrong length forces a rebuild, which also
  // invalidates any cached HEXL NTT object.
  bool reCompute = false;
  PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON
  auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
  if (mapSearch == m_rootOfUnityReverseTableByModulus.end() ||
      mapSearch->second.GetLength() != CycloOrderHf) {
    PreCompute(rootOfUnity, CycloOrder, modulus);
    reCompute = true;
  }

  // Index into the cyclo-order-inverse table: log2(CycloOrder / 2).
  usint msb = GetMSB64(CycloOrderHf - 1);
  if (typeid(IntType) == typeid(NativeInteger)) {
#ifdef WITH_INTEL_HEXL
    if (std::is_same<VecType, NativeVector64>::value) {
      std::pair<uint64_t, uint64_t> key{element->GetLength(),
                                        modulus.ConvertToInt()};
      intel::hexl::NTT *p_ntt;
      std::unique_lock<std::mutex> lock(m_mtxIntelNTT);
      auto ntt_it = m_IntelNtt.find(key);
      if (reCompute || ntt_it == m_IntelNtt.end()) {
        intel::hexl::NTT ntt(element->GetLength(), modulus.ConvertToInt(),
                             rootOfUnity.ConvertToInt());
        m_IntelNtt[key] = std::move(ntt);
        ntt_it = m_IntelNtt.find(key);
      }
      p_ntt = &ntt_it->second;
      lock.unlock();
      // HEXL works directly on the coefficient buffer, in place.
      auto *data = reinterpret_cast<uint64_t *>(&element->at(0));
      p_ntt->ComputeInverse(data, data, 1, 1);
      element->SetModulus(modulus);
    } else {
      NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
          m_rootOfUnityInverseReverseTableByModulus[modulus],
          m_rootOfUnityInversePreconReverseTableByModulus[modulus],
          m_cycloOrderInverseTableByModulus[modulus][msb],
          m_cycloOrderInversePreconTableByModulus[modulus][msb], element);
    }
#else
    NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
        m_rootOfUnityInverseReverseTableByModulus[modulus],
        m_rootOfUnityInversePreconReverseTableByModulus[modulus],
        m_cycloOrderInverseTableByModulus[modulus][msb],
        m_cycloOrderInversePreconTableByModulus[modulus][msb], element);
#endif
  } else {
    NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
        m_rootOfUnityInverseReverseTableByModulus[modulus],
        m_cycloOrderInverseTableByModulus[modulus][msb], element);
  }
}
767 
768 template <typename VecType>
769 void ChineseRemainderTransformFTT<VecType>::InverseTransformFromBitReverse(
770     const VecType &element, const IntType &rootOfUnity, const usint CycloOrder,
771     VecType *result) {
772   if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) {
773     *result = element;
774     return;
775   }
776 
777   if (!IsPowerOfTwo(CycloOrder)) {
778     PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two");
779   }
780 
781   usint CycloOrderHf = (CycloOrder >> 1);
782   if (result->GetLength() != CycloOrderHf) {
783     PALISADE_THROW(math_error,
784                    "result size must be equal to CyclotomicOrder / 2");
785   }
786 
787   IntType modulus = element.GetModulus();
788 
789   bool reCompute = false;
790   (void)reCompute;  // Avoid unused variable
791   auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
792   if (mapSearch == m_rootOfUnityReverseTableByModulus.end() ||
793       mapSearch->second.GetLength() != CycloOrderHf) {
794     PreCompute(rootOfUnity, CycloOrder, modulus);
795     reCompute = true;
796   }
797 
798   usint n = element.GetLength();
799   result->SetModulus(element.GetModulus());
800   for (usint i = 0; i < n; i++) {
801     (*result)[i] = element[i];
802   }
803 
804   usint msb = GetMSB64(CycloOrderHf - 1);
805   if (typeid(IntType) == typeid(NativeInteger)) {
806 #ifdef WITH_INTEL_HEXL
807     if (std::is_same<VecType, NativeVector64>::value) {
808       std::pair<uint64_t, uint64_t> key{element.GetLength(),
809                                         modulus.ConvertToInt()};
810       intel::hexl::NTT *p_ntt;
811       std::unique_lock<std::mutex> lock(m_mtxIntelNTT);
812       auto ntt_it = m_IntelNtt.find(key);
813       if (reCompute || ntt_it == m_IntelNtt.end()) {
814         intel::hexl::NTT ntt(element.GetLength(), modulus.ConvertToInt(),
815                              rootOfUnity.ConvertToInt());
816         m_IntelNtt[key] = std::move(ntt);
817         ntt_it = m_IntelNtt.find(key);
818       }
819       p_ntt = &ntt_it->second;
820       lock.unlock();
821       auto *input = reinterpret_cast<const uint64_t *>(&result->at(0));
822       uint64_t *output = reinterpret_cast<uint64_t *>(&result->at(0));
823       p_ntt->ComputeInverse(output, input, 1, 1);
824       result->SetModulus(modulus);
825     } else {
826       NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
827           m_rootOfUnityInverseReverseTableByModulus[modulus],
828           m_rootOfUnityInversePreconReverseTableByModulus[modulus],
829           m_cycloOrderInverseTableByModulus[modulus][msb],
830           m_cycloOrderInversePreconTableByModulus[modulus][msb], result);
831     }
832 #else
833     NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
834         m_rootOfUnityInverseReverseTableByModulus[modulus],
835         m_rootOfUnityInversePreconReverseTableByModulus[modulus],
836         m_cycloOrderInverseTableByModulus[modulus][msb],
837         m_cycloOrderInversePreconTableByModulus[modulus][msb], result);
838 #endif
839   } else {
840     NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace(
841         m_rootOfUnityInverseReverseTableByModulus[modulus],
842         m_cycloOrderInverseTableByModulus[modulus][msb], result);
843   }
844 
845   return;
846 }
847 
848 template <typename VecType>
849 void ChineseRemainderTransformFTT<VecType>::PreCompute(
850     const IntType &rootOfUnity, const usint CycloOrder,
851     const IntType &modulus) {
852   // Half of cyclo order
853   usint CycloOrderHf = (CycloOrder >> 1);
854 
855   auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus);
856   if (mapSearch == m_rootOfUnityReverseTableByModulus.end() ||
857       mapSearch->second.GetLength() != CycloOrderHf) {
858 #pragma omp critical
859     {
860       IntType x(1), xinv(1);
861       usint msb = GetMSB64(CycloOrderHf - 1);
862       IntType mu = modulus.ComputeMu();
863       VecType Table(CycloOrderHf, modulus);
864       VecType TableI(CycloOrderHf, modulus);
865       IntType rootOfUnityInverse = rootOfUnity.ModInverse(modulus);
866       usint iinv;
867       for (usint i = 0; i < CycloOrderHf; i++) {
868         iinv = ReverseBits(i, msb);
869         Table[iinv] = x;
870         TableI[iinv] = xinv;
871         x.ModMulEq(rootOfUnity, modulus, mu);
872         xinv.ModMulEq(rootOfUnityInverse, modulus, mu);
873       }
874       m_rootOfUnityReverseTableByModulus[modulus] = Table;
875       m_rootOfUnityInverseReverseTableByModulus[modulus] = TableI;
876 
877       VecType TableCOI(msb + 1, modulus);
878       for (usint i = 0; i < msb + 1; i++) {
879         IntType coInv(IntType(1 << i).ModInverse(modulus));
880         TableCOI[i] = coInv;
881       }
882       m_cycloOrderInverseTableByModulus[modulus] = TableCOI;
883 
884       if (typeid(IntType) == typeid(NativeInteger)) {
885         NativeInteger nativeModulus = modulus.ConvertToInt();
886         NativeVector preconTable(CycloOrderHf, nativeModulus);
887         NativeVector preconTableI(CycloOrderHf, nativeModulus);
888 
889         for (usint i = 0; i < CycloOrderHf; i++) {
890           preconTable[i] =
891               NativeInteger(
892                   m_rootOfUnityReverseTableByModulus[modulus][i].ConvertToInt())
893                   .PrepModMulConst(nativeModulus);
894           preconTableI[i] =
895               NativeInteger(
896                   m_rootOfUnityInverseReverseTableByModulus[modulus][i]
897                       .ConvertToInt())
898                   .PrepModMulConst(nativeModulus);
899         }
900 
901         NativeVector preconTableCOI(msb + 1, nativeModulus);
902         for (usint i = 0; i < msb + 1; i++) {
903           preconTableCOI[i] =
904               NativeInteger(
905                   m_cycloOrderInverseTableByModulus[modulus][i].ConvertToInt())
906                   .PrepModMulConst(nativeModulus);
907         }
908 
909         m_rootOfUnityPreconReverseTableByModulus[modulus] = preconTable;
910         m_rootOfUnityInversePreconReverseTableByModulus[modulus] = preconTableI;
911         m_cycloOrderInversePreconTableByModulus[modulus] = preconTableCOI;
912       }
913     }
914   }
915 }
916 
917 template <typename VecType>
918 void ChineseRemainderTransformFTT<VecType>::PreCompute(
919     std::vector<IntType> &rootOfUnity, const usint CycloOrder,
920     std::vector<IntType> &moduliiChain) {
921   usint numOfRootU = rootOfUnity.size();
922   usint numModulii = moduliiChain.size();
923 
924   if (numOfRootU != numModulii) {
925     PALISADE_THROW(
926         math_error,
927         "size of root of unity and size of moduli chain not of same size");
928   }
929 
930   for (usint i = 0; i < numOfRootU; ++i) {
931     IntType currentRoot(rootOfUnity[i]);
932     IntType currentMod(moduliiChain[i]);
933     PreCompute(currentRoot, CycloOrder, currentMod);
934   }
935 }
936 
937 template <typename VecType>
938 void ChineseRemainderTransformFTT<VecType>::Reset() {
939   m_cycloOrderInverseTableByModulus.clear();
940   m_cycloOrderInversePreconTableByModulus.clear();
941   m_rootOfUnityReverseTableByModulus.clear();
942   m_rootOfUnityInverseReverseTableByModulus.clear();
943   m_rootOfUnityPreconReverseTableByModulus.clear();
944   m_rootOfUnityInversePreconReverseTableByModulus.clear();
945 }
946 
947 template <typename VecType>
948 void BluesteinFFT<VecType>::PreComputeDefaultNTTModulusRoot(
949     usint cycloOrder, const IntType &modulus) {
950   usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1)));
951   const auto nttModulus =
952       FirstPrime<IntType>(log2(nttDim) + 2 * modulus.GetMSB(), nttDim);
953   const auto nttRoot = RootOfUnity(nttDim, nttModulus);
954   const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot};
955   m_defaultNTTModulusRoot[modulus] = nttModulusRoot;
956 
957   PreComputeRootTableForNTT(cycloOrder, nttModulusRoot);
958 }
959 
960 template <typename VecType>
961 void BluesteinFFT<VecType>::PreComputeRootTableForNTT(
962     usint cyclotoOrder, const ModulusRoot<IntType> &nttModulusRoot) {
963   usint nttDim = pow(2, ceil(log2(2 * cyclotoOrder - 1)));
964   const auto &nttModulus = nttModulusRoot.first;
965   const auto &nttRoot = nttModulusRoot.second;
966 
967   IntType root(nttRoot);
968 
969   auto rootInv = root.ModInverse(nttModulus);
970 
971   usint nttDimHf = (nttDim >> 1);
972   VecType rootTable(nttDimHf, nttModulus);
973   VecType rootTableInverse(nttDimHf, nttModulus);
974 
975   IntType x(1);
976   for (usint i = 0; i < nttDimHf; i++) {
977     rootTable[i] = x;
978     x = x.ModMul(root, nttModulus);
979   }
980 
981   x = 1;
982   for (usint i = 0; i < nttDimHf; i++) {
983     rootTableInverse[i] = x;
984     x = x.ModMul(rootInv, nttModulus);
985   }
986 
987   m_rootOfUnityTableByModulusRoot[nttModulusRoot] = rootTable;
988   m_rootOfUnityInverseTableByModulusRoot[nttModulusRoot] = rootTableInverse;
989 }
990 
991 template <typename VecType>
992 void BluesteinFFT<VecType>::PreComputePowers(
993     usint cycloOrder, const ModulusRoot<IntType> &modulusRoot) {
994   const auto &modulus = modulusRoot.first;
995   const auto &root = modulusRoot.second;
996 
997   VecType powers(cycloOrder, modulus);
998   powers[0] = 1;
999   for (usint i = 1; i < cycloOrder; i++) {
1000     auto iSqr = (i * i) % (2 * cycloOrder);
1001     auto val = root.ModExp(IntType(iSqr), modulus);
1002     powers[i] = val;
1003   }
1004   m_powersTableByModulusRoot[modulusRoot] = powers;
1005 }
1006 
1007 template <typename VecType>
1008 void BluesteinFFT<VecType>::PreComputeRBTable(
1009     usint cycloOrder, const ModulusRootPair<IntType> &modulusRootPair) {
1010   const auto &modulusRoot = modulusRootPair.first;
1011   const auto &modulus = modulusRoot.first;
1012   const auto &root = modulusRoot.second;
1013   const auto rootInv = root.ModInverse(modulus);
1014 
1015   const auto &nttModulusRoot = modulusRootPair.second;
1016   const auto &nttModulus = nttModulusRoot.first;
1017   // const auto &nttRoot = nttModulusRoot.second;
1018   // assumes rootTable is precomputed
1019   const auto &rootTable = m_rootOfUnityTableByModulusRoot[nttModulusRoot];
1020   usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1)));
1021 
1022   VecType b(2 * cycloOrder - 1, modulus);
1023   b[cycloOrder - 1] = 1;
1024   for (usint i = 1; i < cycloOrder; i++) {
1025     auto iSqr = (i * i) % (2 * cycloOrder);
1026     auto val = rootInv.ModExp(IntType(iSqr), modulus);
1027     b[cycloOrder - 1 + i] = val;
1028     b[cycloOrder - 1 - i] = val;
1029   }
1030 
1031   auto Rb = PadZeros(b, nttDim);
1032   Rb.SetModulus(nttModulus);
1033 
1034   VecType RB(nttDim);
1035   NumberTheoreticTransform<VecType>::ForwardTransformIterative(Rb, rootTable,
1036                                                                &RB);
1037   m_RBTableByModulusRootPair[modulusRootPair] = RB;
1038 }
1039 
1040 template <typename VecType>
1041 VecType BluesteinFFT<VecType>::ForwardTransform(const VecType &element,
1042                                                 const IntType &root,
1043                                                 const usint cycloOrder) {
1044   const auto &modulus = element.GetModulus();
1045   const auto &nttModulusRoot = m_defaultNTTModulusRoot[modulus];
1046 
1047   return ForwardTransform(element, root, cycloOrder, nttModulusRoot);
1048 }
1049 
1050 template <typename VecType>
1051 VecType BluesteinFFT<VecType>::ForwardTransform(
1052     const VecType &element, const IntType &root, const usint cycloOrder,
1053     const ModulusRoot<IntType> &nttModulusRoot) {
1054   if (element.GetLength() != cycloOrder) {
1055     PALISADE_THROW(
1056         math_error,
1057         "expected size of element vector should be equal to cyclotomic order");
1058   }
1059 
1060   const auto &modulus = element.GetModulus();
1061   const ModulusRoot<IntType> modulusRoot = {modulus, root};
1062   const VecType &powers = m_powersTableByModulusRoot[modulusRoot];
1063 
1064   const auto &nttModulus = nttModulusRoot.first;
1065   // assumes rootTable is precomputed
1066   const auto &rootTable = m_rootOfUnityTableByModulusRoot[nttModulusRoot];
1067   const auto &rootTableInverse = m_rootOfUnityInverseTableByModulusRoot
1068       [nttModulusRoot];  // assumes rootTableInverse is precomputed
1069   VecType x = element.ModMul(powers);
1070 
1071   usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1)));
1072   auto Ra = PadZeros(x, nttDim);
1073   Ra.SetModulus(nttModulus);
1074   VecType RA(nttDim);
1075   NumberTheoreticTransform<VecType>::ForwardTransformIterative(Ra, rootTable,
1076                                                                &RA);
1077 
1078   const ModulusRootPair<IntType> modulusRootPair = {modulusRoot,
1079                                                     nttModulusRoot};
1080   const auto &RB = m_RBTableByModulusRootPair[modulusRootPair];
1081 
1082   auto RC = RA.ModMul(RB);
1083   VecType Rc(nttDim);
1084   NumberTheoreticTransform<VecType>::InverseTransformIterative(
1085       RC, rootTableInverse, &Rc);
1086   auto resizeRc = Resize(Rc, cycloOrder - 1, 2 * (cycloOrder - 1));
1087   resizeRc.SetModulus(modulus);
1088   resizeRc.ModEq(modulus);
1089   auto result = resizeRc.ModMul(powers);
1090 
1091   return result;
1092 }
1093 
1094 template <typename VecType>
1095 VecType BluesteinFFT<VecType>::PadZeros(const VecType &a,
1096                                         const usint finalSize) {
1097   usint s = a.GetLength();
1098   VecType result(finalSize, a.GetModulus());
1099 
1100   for (usint i = 0; i < s; i++) {
1101     result[i] = a[i];
1102   }
1103 
1104   for (usint i = a.GetLength(); i < finalSize; i++) {
1105     result[i] = IntType(0);
1106   }
1107 
1108   return result;
1109 }
1110 
1111 template <typename VecType>
1112 VecType BluesteinFFT<VecType>::Resize(const VecType &a, usint lo, usint hi) {
1113   VecType result(hi - lo + 1, a.GetModulus());
1114 
1115   for (usint i = lo, j = 0; i <= hi; i++, j++) {
1116     result[j] = a[i];
1117   }
1118 
1119   return result;
1120 }
1121 
1122 template <typename VecType>
1123 void BluesteinFFT<VecType>::Reset() {
1124   m_rootOfUnityTableByModulusRoot.clear();
1125   m_rootOfUnityInverseTableByModulusRoot.clear();
1126   m_powersTableByModulusRoot.clear();
1127   m_RBTableByModulusRootPair.clear();
1128   m_defaultNTTModulusRoot.clear();
1129 }
1130 
1131 template <typename VecType>
1132 void ChineseRemainderTransformArb<VecType>::SetCylotomicPolynomial(
1133     const VecType &poly, const IntType &mod) {
1134   m_cyclotomicPolyMap[mod] = poly;
1135 }
1136 
1137 template <typename VecType>
1138 void ChineseRemainderTransformArb<VecType>::PreCompute(const usint cyclotoOrder,
1139                                                        const IntType &modulus) {
1140   BluesteinFFT<VecType>::PreComputeDefaultNTTModulusRoot(cyclotoOrder, modulus);
1141 }
1142 
1143 template <typename VecType>
1144 void ChineseRemainderTransformArb<VecType>::SetPreComputedNTTModulus(
1145     usint cyclotoOrder, const IntType &modulus, const IntType &nttModulus,
1146     const IntType &nttRoot) {
1147   const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot};
1148   BluesteinFFT<VecType>::PreComputeRootTableForNTT(cyclotoOrder,
1149                                                    nttModulusRoot);
1150 }
1151 
// Precomputes the tables Drop() uses to reduce a length-cyclotoOrder vector
// modulo the cyclotomic polynomial registered for `modulus`
// (m_cyclotomicPolyMap[modulus]) by NTT-based polynomial division:
//   - m_nttDivisionDim[cyclotoOrder] = 2 * 2^ceil(log2(power)), where
//     power = cyclotoOrder - GetTotient(cyclotoOrder);
//   - m_DivisionNTTModulus[modulus] = nttMod and
//     m_DivisionNTTRootOfUnity[modulus] = the root derived from nttRootBig;
//   - m_rootOfUnityDivisionTableByModulus[nttMod] and
//     m_rootOfUnityDivisionInverseTableByModulus[nttMod]: the first nttDim/2
//     powers of that root and of its inverse;
//   - m_cyclotomicPolyReverseNTTMap[modulus]: forward NTT of
//     InversePolyMod(cyclotomic polynomial, modulus, power), zero-padded;
//   - m_cyclotomicPolyNTTMap[modulus]: forward NTT of the zero-padded
//     cyclotomic polynomial.
// NOTE(review): assumes SetCylotomicPolynomial was called for `modulus`;
// otherwise map operator[] supplies a default-constructed polynomial.
//
// @param nttRootBig root of unity mod nttMod for the Bluestein NTT of size
//        2^ceil(log2(2*cyclotoOrder-1)); it is raised to
//        (that size / division NTT size) to get the division-NTT root
//        (assumed primitive — TODO confirm).
template <typename VecType>
void ChineseRemainderTransformArb<VecType>::SetPreComputedNTTDivisionModulus(
    usint cyclotoOrder, const IntType &modulus, const IntType &nttMod,
    const IntType &nttRootBig) {
  DEBUG_FLAG(false);

  usint n = GetTotient(cyclotoOrder);
  DEBUG("GetTotient(" << cyclotoOrder << ")= " << n);

  // power = number of coefficients beyond degree n-1 in a length-m input;
  // the division NTT is twice the next power of two above it.
  usint power = cyclotoOrder - n;
  m_nttDivisionDim[cyclotoOrder] = 2 * std::pow(2, ceil(log2(power)));

  usint nttDimBig = std::pow(2, ceil(log2(2 * cyclotoOrder - 1)));

  // Computes the root of unity for the division NTT based on the root of unity
  // for regular NTT
  IntType nttRoot = nttRootBig.ModExp(
      IntType(nttDimBig / m_nttDivisionDim[cyclotoOrder]), nttMod);

  m_DivisionNTTModulus[modulus] = nttMod;
  m_DivisionNTTRootOfUnity[modulus] = nttRoot;
  // part0 setting of rootTable and inverse rootTable
  usint nttDim = m_nttDivisionDim[cyclotoOrder];
  IntType root(nttRoot);
  auto rootInv = root.ModInverse(nttMod);

  usint nttDimHf = (nttDim >> 1);
  VecType rootTable(nttDimHf, nttMod);
  VecType rootTableInverse(nttDimHf, nttMod);

  // rootTable[i] = nttRoot^i mod nttMod
  IntType x(1);
  for (usint i = 0; i < nttDimHf; i++) {
    rootTable[i] = x;
    x = x.ModMul(root, nttMod);
  }

  // rootTableInverse[i] = nttRoot^{-i} mod nttMod
  x = 1;
  for (usint i = 0; i < nttDimHf; i++) {
    rootTableInverse[i] = x;
    x = x.ModMul(rootInv, nttMod);
  }

  m_rootOfUnityDivisionTableByModulus[nttMod] = rootTable;
  m_rootOfUnityDivisionInverseTableByModulus[nttMod] = rootTableInverse;

  // end of part0
  // part1: inverse of the cyclotomic polynomial mod x^power (computed over
  // `modulus`), zero-padded to nttDim and relabelled with the NTT modulus
  const auto &RevCPM =
      InversePolyMod(m_cyclotomicPolyMap[modulus], modulus, power);
  auto RevCPMPadded = BluesteinFFT<VecType>::PadZeros(RevCPM, nttDim);
  RevCPMPadded.SetModulus(nttMod);
  // end of part1

  VecType RA(nttDim);
  NumberTheoreticTransform<VecType>::ForwardTransformIterative(RevCPMPadded,
                                                               rootTable, &RA);
  m_cyclotomicPolyReverseNTTMap[modulus] = RA;

  const auto &cycloPoly = m_cyclotomicPolyMap[modulus];

  // Copy the cyclotomic polynomial into a zero vector of length nttDim over
  // nttMod, then forward transform it.
  VecType QForwardTransform(nttDim, nttMod);
  for (usint i = 0; i < cycloPoly.GetLength(); i++) {
    QForwardTransform[i] = cycloPoly[i];
  }

  VecType QFwdResult(nttDim);
  NumberTheoreticTransform<VecType>::ForwardTransformIterative(
      QForwardTransform, rootTable, &QFwdResult);

  m_cyclotomicPolyNTTMap[modulus] = QFwdResult;
}
1223 
1224 template <typename VecType>
1225 VecType ChineseRemainderTransformArb<VecType>::InversePolyMod(
1226     const VecType &cycloPoly, const IntType &modulus, usint power) {
1227   VecType result(power, modulus);
1228   usint r = ceil(log2(power));
1229   VecType h(1, modulus);  // h is a unit polynomial
1230   h[0] = 1;
1231 
1232   // Precompute the Barrett mu parameter
1233   IntType mu = modulus.ComputeMu();
1234 
1235   for (usint i = 0; i < r; i++) {
1236     usint qDegree = std::pow(2, i + 1);
1237     VecType q(qDegree + 1, modulus);  // q = x^(2^i+1)
1238     q[qDegree] = 1;
1239     auto hSquare = PolynomialMultiplication(h, h);
1240 
1241     auto a = h * IntType(2);
1242     auto b = PolynomialMultiplication(hSquare, cycloPoly);
1243     // b = 2h - gh^2
1244     for (usint j = 0; j < b.GetLength(); j++) {
1245       if (j < a.GetLength()) {
1246         b[j] = a[j].ModSub(b[j], modulus, mu);
1247       } else {
1248         b[j] = modulus.ModSub(b[j], modulus, mu);
1249       }
1250     }
1251     h = PolyMod(b, q, modulus);
1252   }
1253   // take modulo x^power
1254   for (usint i = 0; i < power; i++) {
1255     result[i] = h[i];
1256   }
1257 
1258   return result;
1259 }
1260 
1261 template <typename VecType>
1262 VecType ChineseRemainderTransformArb<VecType>::ForwardTransform(
1263     const VecType &element, const IntType &root, const IntType &nttModulus,
1264     const IntType &nttRoot, const usint cycloOrder) {
1265   usint phim = GetTotient(cycloOrder);
1266   if (element.GetLength() != phim) {
1267     PALISADE_THROW(math_error, "element size should be equal to phim");
1268   }
1269 
1270   const auto &modulus = element.GetModulus();
1271   const ModulusRoot<IntType> modulusRoot = {modulus, root};
1272 
1273   const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot};
1274   const ModulusRootPair<IntType> modulusRootPair = {modulusRoot,
1275                                                     nttModulusRoot};
1276 
1277 #pragma omp critical
1278   {
1279     if (BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot[nttModulusRoot]
1280             .GetLength() == 0) {
1281       BluesteinFFT<VecType>::PreComputeRootTableForNTT(cycloOrder,
1282                                                        nttModulusRoot);
1283     }
1284 
1285     if (BluesteinFFT<VecType>::m_powersTableByModulusRoot[modulusRoot]
1286             .GetLength() == 0) {
1287       BluesteinFFT<VecType>::PreComputePowers(cycloOrder, modulusRoot);
1288     }
1289 
1290     if (BluesteinFFT<VecType>::m_RBTableByModulusRootPair[modulusRootPair]
1291             .GetLength() == 0) {
1292       BluesteinFFT<VecType>::PreComputeRBTable(cycloOrder, modulusRootPair);
1293     }
1294   }
1295 
1296   VecType inputToBluestein = Pad(element, cycloOrder, true);
1297   auto outputBluestein = BluesteinFFT<VecType>::ForwardTransform(
1298       inputToBluestein, root, cycloOrder, nttModulusRoot);
1299   VecType output = Drop(outputBluestein, cycloOrder, true, nttModulus, nttRoot);
1300 
1301   return output;
1302 }
1303 
1304 template <typename VecType>
1305 VecType ChineseRemainderTransformArb<VecType>::InverseTransform(
1306     const VecType &element, const IntType &root, const IntType &nttModulus,
1307     const IntType &nttRoot, const usint cycloOrder) {
1308   usint phim = GetTotient(cycloOrder);
1309   if (element.GetLength() != phim) {
1310     PALISADE_THROW(math_error, "element size should be equal to phim");
1311   }
1312 
1313   const auto &modulus = element.GetModulus();
1314   auto rootInverse(root.ModInverse(modulus));
1315   const ModulusRoot<IntType> modulusRootInverse = {modulus, rootInverse};
1316 
1317   const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot};
1318   const ModulusRootPair<IntType> modulusRootPair = {modulusRootInverse,
1319                                                     nttModulusRoot};
1320 
1321 #pragma omp critical
1322   {
1323     if (BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot[nttModulusRoot]
1324             .GetLength() == 0) {
1325       BluesteinFFT<VecType>::PreComputeRootTableForNTT(cycloOrder,
1326                                                        nttModulusRoot);
1327     }
1328 
1329     if (BluesteinFFT<VecType>::m_powersTableByModulusRoot[modulusRootInverse]
1330             .GetLength() == 0) {
1331       BluesteinFFT<VecType>::PreComputePowers(cycloOrder, modulusRootInverse);
1332     }
1333 
1334     if (BluesteinFFT<VecType>::m_RBTableByModulusRootPair[modulusRootPair]
1335             .GetLength() == 0) {
1336       BluesteinFFT<VecType>::PreComputeRBTable(cycloOrder, modulusRootPair);
1337     }
1338   }
1339   VecType inputToBluestein = Pad(element, cycloOrder, false);
1340   auto outputBluestein = BluesteinFFT<VecType>::ForwardTransform(
1341       inputToBluestein, rootInverse, cycloOrder, nttModulusRoot);
1342   auto cyclotomicInverse((IntType(cycloOrder)).ModInverse(modulus));
1343   outputBluestein = outputBluestein * cyclotomicInverse;
1344   VecType output =
1345       Drop(outputBluestein, cycloOrder, false, nttModulus, nttRoot);
1346   return output;
1347 }
1348 
1349 template <typename VecType>
1350 VecType ChineseRemainderTransformArb<VecType>::Pad(const VecType &element,
1351                                                    const usint cycloOrder,
1352                                                    bool forward) {
1353   usint n = GetTotient(cycloOrder);
1354 
1355   const auto &modulus = element.GetModulus();
1356   VecType inputToBluestein(cycloOrder, modulus);
1357 
1358   if (forward) {  // Forward transform padding
1359     for (usint i = 0; i < n; i++) {
1360       inputToBluestein[i] = element[i];
1361     }
1362   } else {  // Inverse transform padding
1363     auto tList = GetTotientList(cycloOrder);
1364     usint i = 0;
1365     for (auto &coprime : tList) {
1366       inputToBluestein[coprime] = element[i++];
1367     }
1368   }
1369 
1370   return inputToBluestein;
1371 }
1372 
// Maps a length-cycloOrder Bluestein output to a vector of length
// n = GetTotient(cycloOrder) with element's modulus.
//
// forward == true: keeps the entries at the indices returned by
//   GetTotientList(cycloOrder), in that order.
// forward == false: treats element as polynomial coefficients (element[k] is
//   the coefficient of x^k) and reduces them modulo the cyclotomic polynomial:
//   - cycloOrder prime (n + 1 == cycloOrder): subtract the x^n coefficient
//     from every lower coefficient;
//   - cycloOrder == 2 * prime: reduce mod x^(n+1)+1, then mod Phi with
//     alternating signs;
//   - otherwise: NTT-based polynomial division using the tables from
//     SetPreComputedNTTDivisionModulus (built here on demand).
//
// @param bigMod, bigRoot NTT modulus and root forwarded to
//        SetPreComputedNTTDivisionModulus when the division tables are missing
//        or were built for a different NTT modulus; used only in the general
//        inverse case.
template <typename VecType>
VecType ChineseRemainderTransformArb<VecType>::Drop(const VecType &element,
                                                    const usint cycloOrder,
                                                    bool forward,
                                                    const IntType &bigMod,
                                                    const IntType &bigRoot) {
  usint n = GetTotient(cycloOrder);

  const auto &modulus = element.GetModulus();
  VecType output(n, modulus);

  if (forward) {  // Forward transform drop
    auto tList = GetTotientList(cycloOrder);
    for (usint i = 0; i < n; i++) {
      output[i] = element[tList[i]];
    }
  } else {  // Inverse transform drop
    if ((n + 1) == cycloOrder) {
      IntType mu = modulus.ComputeMu();  // Precompute the Barrett mu parameter
      // cycloOrder is prime: Reduce mod Phi_{n+1}(x)
      // Reduction involves subtracting the coeff of x^n from all terms
      auto coeff_n = element[n];
      for (usint i = 0; i < n; i++) {
        output[i] = element[i].ModSub(coeff_n, modulus, mu);
      }
    } else if ((n + 1) * 2 == cycloOrder) {
      IntType mu = modulus.ComputeMu();  // Precompute the Barrett mu parameter
      // cycloOrder is 2*prime: 2 Step reduction
      // First reduce mod x^(n+1)+1 (=(x+1)*Phi_{2*(n+1)}(x))
      // Subtract co-efficient of x^(i+n+1) from x^(i)
      for (usint i = 0; i < n; i++) {
        auto coeff_i = element[i];
        auto coeff_ip = element[i + n + 1];
        output[i] = coeff_i.ModSub(coeff_ip, modulus, mu);
      }
      auto coeff_n = element[n].ModSub(element[2 * n + 1], modulus, mu);
      // Now reduce mod Phi_{2*(n+1)}(x)
      // Similar to the prime case but with alternating signs
      for (usint i = 0; i < n; i++) {
        if (i % 2 == 0) {
          output[i].ModSubEq(coeff_n, modulus, mu);
        } else {
          output[i].ModAddEq(coeff_n, modulus, mu);
        }
      }
    } else {
      // precompute root of unity tables for division NTT
      if ((m_rootOfUnityDivisionTableByModulus[bigMod].GetLength() == 0) ||
          (m_DivisionNTTModulus[modulus] != bigMod)) {
        SetPreComputedNTTDivisionModulus(cycloOrder, modulus, bigMod, bigRoot);
      }

      // cycloOrder is arbitrary
      // auto output = PolyMod(element, this->m_cyclotomicPolyMap[modulus],
      // modulus);

      const auto &nttMod = m_DivisionNTTModulus[modulus];
      const auto &rootTable = m_rootOfUnityDivisionTableByModulus[nttMod];
      VecType aPadded2(m_nttDivisionDim[cycloOrder], nttMod);
      // perform mod operation
      // Copy the high-order coefficients (degrees n..len-1) into aPadded2 in
      // reversed order.
      usint power = cycloOrder - n;
      for (usint i = n; i < element.GetLength(); i++) {
        aPadded2[power - (i - n) - 1] = element[i];
      }
      // NTT-domain product with the precomputed transform of
      // InversePolyMod(cyclotomic polynomial, modulus, power).
      VecType A(m_nttDivisionDim[cycloOrder]);
      NumberTheoreticTransform<VecType>::ForwardTransformIterative(
          aPadded2, rootTable, &A);
      auto AB = A * m_cyclotomicPolyReverseNTTMap[modulus];
      const auto &rootTableInverse =
          m_rootOfUnityDivisionInverseTableByModulus[nttMod];
      VecType a(m_nttDivisionDim[cycloOrder]);
      NumberTheoreticTransform<VecType>::InverseTransformIterative(
          AB, rootTableInverse, &a);

      // Keep the first `power` coefficients (the quotient, truncated mod
      // x^power), reduce them mod `modulus`, and relabel them with nttMod for
      // the next NTT.
      VecType quotient(m_nttDivisionDim[cycloOrder], modulus);
      for (usint i = 0; i < power; i++) {
        quotient[i] = a[i];
      }
      quotient.ModEq(modulus);
      quotient.SetModulus(nttMod);

      // Multiply the quotient by the cyclotomic polynomial in the NTT domain.
      VecType newQuotient(m_nttDivisionDim[cycloOrder]);
      NumberTheoreticTransform<VecType>::ForwardTransformIterative(
          quotient, rootTable, &newQuotient);
      newQuotient *= m_cyclotomicPolyNTTMap[modulus];

      VecType newQuotient2(m_nttDivisionDim[cycloOrder]);
      NumberTheoreticTransform<VecType>::InverseTransformIterative(
          newQuotient, rootTableInverse, &newQuotient2);
      newQuotient2.SetModulus(modulus);
      newQuotient2.ModEq(modulus);

      IntType mu = modulus.ComputeMu();  // Precompute the Barrett mu parameter

      // Remainder = element - quotient * Phi. NOTE(review): the product is
      // read at index cycloOrder - 1 - i, i.e. in reversed coefficient order;
      // this presumably matches the coefficient order stored by
      // SetCylotomicPolynomial — confirm.
      for (usint i = 0; i < n; i++) {
        output[i] =
            element[i].ModSub(newQuotient2[cycloOrder - 1 - i], modulus, mu);
      }
    }
  }
  return output;
}
1475 
1476 template <typename VecType>
1477 void ChineseRemainderTransformArb<VecType>::Reset() {
1478   m_cyclotomicPolyMap.clear();
1479   m_cyclotomicPolyReverseNTTMap.clear();
1480   m_cyclotomicPolyNTTMap.clear();
1481   m_rootOfUnityDivisionTableByModulus.clear();
1482   m_rootOfUnityDivisionInverseTableByModulus.clear();
1483   m_DivisionNTTModulus.clear();
1484   m_DivisionNTTRootOfUnity.clear();
1485   m_nttDivisionDim.clear();
1486   BluesteinFFT<VecType>::Reset();
1487 }
1488 
1489 }  // namespace lbcrypto
1490