1 // @file transfrm.cpp This file contains the linear transform interface 2 // functionality. 3 // @author TPOC: contact@palisade-crypto.org 4 // 5 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT) 6 // All rights reserved. 7 // Redistribution and use in source and binary forms, with or without 8 // modification, are permitted provided that the following conditions are met: 9 // 1. Redistributions of source code must retain the above copyright notice, 10 // this list of conditions and the following disclaimer. 11 // 2. Redistributions in binary form must reproduce the above copyright notice, 12 // this list of conditions and the following disclaimer in the documentation 13 // and/or other materials provided with the distribution. THIS SOFTWARE IS 14 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR 15 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 16 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO 17 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 18 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 21 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 25 #include "math/transfrm.h" 26 #include "utils/defines.h" 27 28 #ifdef WITH_INTEL_HEXL 29 #include "hexl/hexl.hpp" 30 #endif 31 32 namespace lbcrypto { 33 34 template <typename VecType> 35 std::map<typename VecType::Integer, VecType> 36 ChineseRemainderTransformFTT<VecType>::m_cycloOrderInverseTableByModulus; 37 38 template <typename VecType> 39 std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT< 40 VecType>::m_cycloOrderInversePreconTableByModulus; 41 42 template <typename VecType> 43 std::map<typename VecType::Integer, VecType> 44 ChineseRemainderTransformFTT<VecType>::m_rootOfUnityReverseTableByModulus; 45 46 template <typename VecType> 47 std::map<typename VecType::Integer, VecType> ChineseRemainderTransformFTT< 48 VecType>::m_rootOfUnityInverseReverseTableByModulus; 49 50 template <typename VecType> 51 std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT< 52 VecType>::m_rootOfUnityPreconReverseTableByModulus; 53 54 template <typename VecType> 55 std::map<typename VecType::Integer, NativeVector> ChineseRemainderTransformFTT< 56 VecType>::m_rootOfUnityInversePreconReverseTableByModulus; 57 58 #ifdef WITH_INTEL_HEXL 59 template <typename VecType> 60 // N, modulus 61 std::unordered_map<std::pair<uint64_t, uint64_t>, intel::hexl::NTT, HashPair> 62 ChineseRemainderTransformFTT<VecType>::m_IntelNtt; 63 template <typename VecType> 64 std::mutex ChineseRemainderTransformFTT<VecType>::m_mtxIntelNTT; 65 #endif 66 67 template <typename VecType> 68 std::map<typename VecType::Integer, VecType> 69 ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyMap; 70 71 template <typename VecType> 72 std::map<typename VecType::Integer, VecType> 73 ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyReverseNTTMap; 74 75 template <typename VecType> 76 std::map<typename VecType::Integer, VecType> 77 ChineseRemainderTransformArb<VecType>::m_cyclotomicPolyNTTMap; 78 79 template <typename VecType> 80 std::map<ModulusRoot<typename VecType::Integer>, VecType> 81 BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot; 82 83 template <typename VecType> 84 std::map<ModulusRoot<typename VecType::Integer>, VecType> 85 BluesteinFFT<VecType>::m_rootOfUnityInverseTableByModulusRoot; 86 87 template <typename VecType> 88 std::map<ModulusRoot<typename VecType::Integer>, VecType> 89 BluesteinFFT<VecType>::m_powersTableByModulusRoot; 90 91 template <typename VecType> 92 std::map<ModulusRootPair<typename VecType::Integer>, VecType> 93 BluesteinFFT<VecType>::m_RBTableByModulusRootPair; 94 95 template <typename VecType> 96 std::map<typename VecType::Integer, ModulusRoot<typename VecType::Integer>> 97 BluesteinFFT<VecType>::m_defaultNTTModulusRoot; 98 99 template <typename VecType> 100 std::map<typename VecType::Integer, VecType> 101 ChineseRemainderTransformArb<VecType>::m_rootOfUnityDivisionTableByModulus; 102 103 template <typename VecType> 104 std::map<typename VecType::Integer, VecType> ChineseRemainderTransformArb< 105 VecType>::m_rootOfUnityDivisionInverseTableByModulus; 106 107 template <typename VecType> 108 std::map<typename VecType::Integer, typename VecType::Integer> 109 ChineseRemainderTransformArb<VecType>::m_DivisionNTTModulus; 110 111 template <typename VecType> 112 std::map<typename VecType::Integer, typename VecType::Integer> 113 ChineseRemainderTransformArb<VecType>::m_DivisionNTTRootOfUnity; 114 115 template <typename VecType> 116 std::map<usint, usint> ChineseRemainderTransformArb<VecType>::m_nttDivisionDim; 117 118 template <typename VecType> 119 void NumberTheoreticTransform<VecType>::ForwardTransformIterative( 120 const VecType &element, const VecType &rootOfUnityTable, VecType *result) { 121 usint n = element.GetLength(); 122 if (result->GetLength() != n) { 123 PALISADE_THROW( 124 math_error, 125 "size of input element and size of output element not of same size"); 126 } 127 128 auto modulus = element.GetModulus(); 129 IntType mu = modulus.ComputeMu(); 130 result->SetModulus(modulus); 131 132 usint msb = GetMSB64(n - 1); 133 for (size_t i = 0; i < n; i++) { 134 (*result)[i] = element[ReverseBits(i, msb)]; 135 } 136 137 IntType omega, omegaFactor, oddVal, evenVal; 138 usint logm, i, j, indexEven, indexOdd; 139 140 usint logn = GetMSB64(n - 1); 141 for (logm = 1; logm <= logn; logm++) { 142 // calculate the i indexes into the root table one time per loop 143 vector<usint> indexes(1 << (logm - 1)); 144 for (i = 0; i < (usint)(1 << (logm - 1)); i++) { 145 indexes[i] = (i << (logn - logm)); 146 } 147 148 for (j = 0; j < n; j = j + (1 << logm)) { 149 for (i = 0; i < (usint)(1 << (logm - 1)); i++) { 150 omega = rootOfUnityTable[indexes[i]]; 151 indexEven = j + i; 152 indexOdd = indexEven + (1 << (logm - 1)); 153 oddVal = (*result)[indexOdd]; 154 155 omegaFactor = omega.ModMul(oddVal, modulus, mu); 156 evenVal = (*result)[indexEven]; 157 oddVal = evenVal; 158 oddVal += omegaFactor; 159 if (oddVal >= modulus) { 160 oddVal -= modulus; 161 } 162 163 if (evenVal < omegaFactor) { 164 evenVal += modulus; 165 } 166 evenVal -= omegaFactor; 167 168 (*result)[indexEven] = oddVal; 169 (*result)[indexOdd] = evenVal; 170 } 171 } 172 } 173 return; 174 } 175 176 template <typename VecType> 177 void NumberTheoreticTransform<VecType>::InverseTransformIterative( 178 const VecType &element, const VecType &rootOfUnityInverseTable, 179 VecType *result) { 180 usint n = element.GetLength(); 181 182 IntType modulus = element.GetModulus(); 183 IntType mu = modulus.ComputeMu(); 184 185 NumberTheoreticTransform<VecType>::ForwardTransformIterative( 186 element, rootOfUnityInverseTable, result); 187 IntType cycloOrderInv(IntType(n).ModInverse(modulus)); 188 for (usint i = 0; i < n; i++) { 189 (*result)[i].ModMulEq(cycloOrderInv, modulus, mu); 190 } 191 return; 192 } 193 194 template <typename VecType> 195 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace( 196 const VecType &rootOfUnityTable, VecType *element) { 197 usint n = element->GetLength(); 198 IntType modulus = element->GetModulus(); 199 IntType mu = modulus.ComputeMu(); 200 201 usint i, m, j1, j2, indexOmega, indexLo, indexHi; 202 IntType omega, omegaFactor, loVal, hiVal, zero(0); 203 204 usint t = (n >> 1); 205 usint logt1 = GetMSB64(t); 206 for (m = 1; m < n; m <<= 1) { 207 for (i = 0; i < m; ++i) { 208 j1 = i << logt1; 209 j2 = j1 + t; 210 indexOmega = m + i; 211 omega = rootOfUnityTable[indexOmega]; 212 for (indexLo = j1; indexLo < j2; ++indexLo) { 213 indexHi = indexLo + t; 214 loVal = (*element)[indexLo]; 215 omegaFactor = (*element)[indexHi]; 216 omegaFactor.ModMulFastEq(omega, modulus, mu); 217 218 hiVal = loVal + omegaFactor; 219 if (hiVal >= modulus) { 220 hiVal -= modulus; 221 } 222 223 if (loVal < omegaFactor) { 224 loVal += modulus; 225 } 226 loVal -= omegaFactor; 227 228 (*element)[indexLo] = hiVal; 229 (*element)[indexHi] = loVal; 230 } 231 } 232 t >>= 1; 233 logt1--; 234 } 235 return; 236 } 237 238 template <typename VecType> 239 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse( 240 const VecType &element, const VecType &rootOfUnityTable, VecType *result) { 241 usint n = element.GetLength(); 242 if (result->GetLength() != n) { 243 PALISADE_THROW( 244 math_error, 245 "size of input element and size of output element not of same size"); 246 } 247 248 IntType modulus = element.GetModulus(); 249 IntType mu = modulus.ComputeMu(); 250 result->SetModulus(modulus); 251 252 usint i, m, j1, j2, indexOmega, indexLo, indexHi; 253 IntType omega, omegaFactor, loVal, hiVal, zero(0); 254 255 for (i = 0; i < n; ++i) { 256 (*result)[i] = element[i]; 257 } 258 259 usint t = (n >> 1); 260 usint logt1 = GetMSB64(t); 261 for (m = 1; m < n; m <<= 1) { 262 for (i = 0; i < m; ++i) { 263 j1 = i << logt1; 264 j2 = j1 + t; 265 indexOmega = m + i; 266 omega = rootOfUnityTable[indexOmega]; 267 for (indexLo = j1; indexLo < j2; ++indexLo) { 268 indexHi = indexLo + t; 269 loVal = (*result)[indexLo]; 270 omegaFactor = (*result)[indexHi]; 271 if (omegaFactor != zero) { 272 omegaFactor.ModMulFastEq(omega, modulus, mu); 273 274 hiVal = loVal + omegaFactor; 275 if (hiVal >= modulus) { 276 hiVal -= modulus; 277 } 278 279 if (loVal < omegaFactor) { 280 loVal += modulus; 281 } 282 loVal -= omegaFactor; 283 284 (*result)[indexLo] = hiVal; 285 (*result)[indexHi] = loVal; 286 } else { 287 (*result)[indexHi] = loVal; 288 } 289 } 290 } 291 t >>= 1; 292 logt1--; 293 } 294 return; 295 } 296 297 template <typename VecType> 298 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace( 299 const VecType &rootOfUnityTable, const NativeVector &preconRootOfUnityTable, 300 VecType *element) { 301 usint n = element->GetLength(); 302 IntType modulus = element->GetModulus(); 303 304 uint32_t indexOmega, indexHi; 305 NativeInteger preconOmega; 306 IntType omega, omegaFactor, loVal, hiVal, zero(0); 307 308 usint t = (n >> 1); 309 usint logt1 = GetMSB64(t); 310 for (uint32_t m = 1; m < n; m <<= 1, t >>= 1, --logt1) { 311 uint32_t j1, j2; 312 for (uint32_t i = 0; i < m; ++i) { 313 j1 = i << logt1; 314 j2 = j1 + t; 315 indexOmega = m + i; 316 omega = rootOfUnityTable[indexOmega]; 317 preconOmega = preconRootOfUnityTable[indexOmega]; 318 for (uint32_t indexLo = j1; indexLo < j2; ++indexLo) { 319 indexHi = indexLo + t; 320 loVal = (*element)[indexLo]; 321 omegaFactor = (*element)[indexHi]; 322 omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); 323 324 hiVal = loVal + omegaFactor; 325 if (hiVal >= modulus) { 326 hiVal -= modulus; 327 } 328 329 if (loVal < omegaFactor) { 330 loVal += modulus; 331 } 332 loVal -= omegaFactor; 333 334 (*element)[indexLo] = hiVal; 335 (*element)[indexHi] = loVal; 336 } 337 } 338 } 339 return; 340 } 341 342 template <typename VecType> 343 void NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse( 344 const VecType &element, const VecType &rootOfUnityTable, 345 const NativeVector &preconRootOfUnityTable, VecType *result) { 346 usint n = element.GetLength(); 347 348 if (result->GetLength() != n) { 349 PALISADE_THROW( 350 math_error, 351 "size of input element and size of output element not of same size"); 352 } 353 354 IntType modulus = element.GetModulus(); 355 356 result->SetModulus(modulus); 357 358 for (uint32_t i = 0; i < n; ++i) { 359 (*result)[i] = element[i]; 360 } 361 362 uint32_t indexOmega, indexHi; 363 NativeInteger preconOmega; 364 IntType omega, omegaFactor, loVal, hiVal, zero(0); 365 366 usint t = (n >> 1); 367 usint logt1 = GetMSB64(t); 368 for (uint32_t m = 1; m < n; m <<= 1, t >>= 1, --logt1) { 369 uint32_t j1, j2; 370 for (uint32_t i = 0; i < m; ++i) { 371 j1 = i << logt1; 372 j2 = j1 + t; 373 indexOmega = m + i; 374 omega = rootOfUnityTable[indexOmega]; 375 preconOmega = preconRootOfUnityTable[indexOmega]; 376 for (uint32_t indexLo = j1; indexLo < j2; ++indexLo) { 377 indexHi = indexLo + t; 378 loVal = (*result)[indexLo]; 379 omegaFactor = (*result)[indexHi]; 380 if (omegaFactor != zero) { 381 omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); 382 383 hiVal = loVal + omegaFactor; 384 if (hiVal >= modulus) { 385 hiVal -= modulus; 386 } 387 388 if (loVal < omegaFactor) { 389 loVal += modulus; 390 } 391 loVal -= omegaFactor; 392 393 (*result)[indexLo] = hiVal; 394 (*result)[indexHi] = loVal; 395 } else { 396 (*result)[indexHi] = loVal; 397 } 398 } 399 } 400 } 401 return; 402 } 403 404 template <typename VecType> 405 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 406 const VecType &rootOfUnityInverseTable, const IntType &cycloOrderInv, 407 VecType *element) { 408 usint n = element->GetLength(); 409 IntType modulus = element->GetModulus(); 410 IntType mu = modulus.ComputeMu(); 411 412 IntType loVal, hiVal, omega, omegaFactor; 413 usint i, m, j1, j2, indexOmega, indexLo, indexHi; 414 415 usint t = 1; 416 usint logt1 = 1; 417 for (m = (n >> 1); m >= 1; m >>= 1) { 418 for (i = 0; i < m; ++i) { 419 j1 = i << logt1; 420 j2 = j1 + t; 421 indexOmega = m + i; 422 omega = rootOfUnityInverseTable[indexOmega]; 423 424 for (indexLo = j1; indexLo < j2; ++indexLo) { 425 indexHi = indexLo + t; 426 427 hiVal = (*element)[indexHi]; 428 loVal = (*element)[indexLo]; 429 430 omegaFactor = loVal; 431 if (omegaFactor < hiVal) { 432 omegaFactor += modulus; 433 } 434 435 omegaFactor -= hiVal; 436 437 loVal += hiVal; 438 if (loVal >= modulus) { 439 loVal -= modulus; 440 } 441 442 omegaFactor.ModMulFastEq(omega, modulus, mu); 443 444 (*element)[indexLo] = loVal; 445 (*element)[indexHi] = omegaFactor; 446 } 447 } 448 t <<= 1; 449 logt1++; 450 } 451 452 for (i = 0; i < n; i++) { 453 (*element)[i].ModMulFastEq(cycloOrderInv, modulus, mu); 454 } 455 return; 456 } 457 458 template <typename VecType> 459 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverse( 460 const VecType &element, const VecType &rootOfUnityInverseTable, 461 const IntType &cycloOrderInv, VecType *result) { 462 usint n = element.GetLength(); 463 464 if (result->GetLength() != n) { 465 PALISADE_THROW( 466 math_error, 467 "size of input element and size of output element not of same size"); 468 } 469 470 result->SetModulus(element.GetModulus()); 471 472 for (usint i = 0; i < n; i++) { 473 (*result)[i] = element[i]; 474 } 475 InverseTransformFromBitReverseInPlace(rootOfUnityInverseTable, cycloOrderInv, 476 result); 477 } 478 479 template <typename VecType> 480 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 481 const VecType &rootOfUnityInverseTable, 482 const NativeVector &preconRootOfUnityInverseTable, 483 const IntType &cycloOrderInv, const NativeInteger &preconCycloOrderInv, 484 VecType *element) { 485 usint n = element->GetLength(); 486 487 IntType modulus = element->GetModulus(); 488 489 IntType loVal, hiVal, omega, omegaFactor; 490 NativeInteger preconOmega; 491 usint i, m, j1, j2, indexOmega, indexLo, indexHi; 492 493 usint t = 1; 494 usint logt1 = 1; 495 for (m = (n >> 1); m >= 1; m >>= 1) { 496 for (i = 0; i < m; ++i) { 497 j1 = i << logt1; 498 j2 = j1 + t; 499 indexOmega = m + i; 500 omega = rootOfUnityInverseTable[indexOmega]; 501 preconOmega = preconRootOfUnityInverseTable[indexOmega]; 502 503 for (indexLo = j1; indexLo < j2; ++indexLo) { 504 indexHi = indexLo + t; 505 506 hiVal = (*element)[indexHi]; 507 loVal = (*element)[indexLo]; 508 509 omegaFactor = loVal; 510 if (omegaFactor < hiVal) { 511 omegaFactor += modulus; 512 } 513 514 omegaFactor -= hiVal; 515 516 loVal += hiVal; 517 if (loVal >= modulus) { 518 loVal -= modulus; 519 } 520 521 omegaFactor.ModMulFastConstEq(omega, modulus, preconOmega); 522 523 (*element)[indexLo] = loVal; 524 (*element)[indexHi] = omegaFactor; 525 } 526 } 527 t <<= 1; 528 logt1++; 529 } 530 531 for (i = 0; i < n; i++) { 532 (*element)[i].ModMulFastConstEq(cycloOrderInv, modulus, 533 preconCycloOrderInv); 534 } 535 } 536 537 template <typename VecType> 538 void NumberTheoreticTransform<VecType>::InverseTransformFromBitReverse( 539 const VecType &element, const VecType &rootOfUnityInverseTable, 540 const NativeVector &preconRootOfUnityInverseTable, 541 const IntType &cycloOrderInv, const NativeInteger &preconCycloOrderInv, 542 VecType *result) { 543 usint n = element.GetLength(); 544 if (result->GetLength() != n) { 545 PALISADE_THROW( 546 math_error, 547 "size of input element and size of output element not of same size"); 548 } 549 550 result->SetModulus(element.GetModulus()); 551 552 for (usint i = 0; i < n; i++) { 553 (*result)[i] = element[i]; 554 } 555 InverseTransformFromBitReverseInPlace( 556 rootOfUnityInverseTable, preconRootOfUnityInverseTable, cycloOrderInv, 557 preconCycloOrderInv, result); 558 559 return; 560 } 561 562 template <typename VecType> 563 void ChineseRemainderTransformFTT<VecType>::ForwardTransformToBitReverseInPlace( 564 const IntType &rootOfUnity, const usint CycloOrder, VecType *element) { 565 if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) { 566 return; 567 } 568 569 if (!IsPowerOfTwo(CycloOrder)) { 570 PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two"); 571 } 572 573 usint CycloOrderHf = (CycloOrder >> 1); 574 if (element->GetLength() != CycloOrderHf) { 575 PALISADE_THROW(math_error, 576 "element size must be equal to CyclotomicOrder / 2"); 577 } 578 579 IntType modulus = element->GetModulus(); 580 581 bool reCompute = false; 582 PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON 583 auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus); 584 if (mapSearch == m_rootOfUnityReverseTableByModulus.end() || 585 mapSearch->second.GetLength() != CycloOrderHf) { 586 PreCompute(rootOfUnity, CycloOrder, modulus); 587 reCompute = true; 588 } 589 590 if (typeid(IntType) == typeid(NativeInteger)) { 591 #ifdef WITH_INTEL_HEXL 592 if (std::is_same<VecType, NativeVector64>::value) { 593 std::pair<uint64_t, uint64_t> key{element->GetLength(), 594 modulus.ConvertToInt()}; 595 intel::hexl::NTT *p_ntt; 596 std::unique_lock<std::mutex> lock(m_mtxIntelNTT); 597 auto ntt_it = m_IntelNtt.find(key); 598 if (reCompute || ntt_it == m_IntelNtt.end()) { 599 intel::hexl::NTT ntt(element->GetLength(), modulus.ConvertToInt(), 600 rootOfUnity.ConvertToInt()); 601 m_IntelNtt[key] = std::move(ntt); 602 ntt_it = m_IntelNtt.find(key); 603 } 604 p_ntt = &ntt_it->second; 605 lock.unlock(); 606 607 auto *data = reinterpret_cast<uint64_t *>(&element->at(0)); 608 p_ntt->ComputeForward(data, data, 1, 1); 609 element->SetModulus(modulus); 610 } else { 611 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace( 612 m_rootOfUnityReverseTableByModulus[modulus], 613 m_rootOfUnityPreconReverseTableByModulus[modulus], element); 614 } 615 #else 616 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace( 617 m_rootOfUnityReverseTableByModulus[modulus], 618 m_rootOfUnityPreconReverseTableByModulus[modulus], element); 619 #endif 620 } else { 621 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverseInPlace( 622 m_rootOfUnityReverseTableByModulus[modulus], element); 623 } 624 } 625 626 template <typename VecType> 627 void ChineseRemainderTransformFTT<VecType>::ForwardTransformToBitReverse( 628 const VecType &element, const IntType &rootOfUnity, const usint CycloOrder, 629 VecType *result) { 630 if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) { 631 *result = element; 632 return; 633 } 634 635 if (!IsPowerOfTwo(CycloOrder)) { 636 PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two"); 637 } 638 639 usint CycloOrderHf = (CycloOrder >> 1); 640 if (result->GetLength() != CycloOrderHf) { 641 PALISADE_THROW(math_error, 642 "result size must be equal to CyclotomicOrder / 2"); 643 } 644 645 IntType modulus = element.GetModulus(); 646 647 bool reCompute = false; 648 PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON 649 auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus); 650 if (mapSearch == m_rootOfUnityReverseTableByModulus.end() || 651 mapSearch->second.GetLength() != CycloOrderHf) { 652 PreCompute(rootOfUnity, CycloOrder, modulus); 653 reCompute = true; 654 } 655 656 if (typeid(IntType) == typeid(NativeInteger)) { 657 #ifdef WITH_INTEL_HEXL 658 if (std::is_same<VecType, NativeVector64>::value) { 659 std::pair<uint64_t, uint64_t> key{element.GetLength(), 660 modulus.ConvertToInt()}; 661 intel::hexl::NTT *p_ntt; 662 std::unique_lock<std::mutex> lock(m_mtxIntelNTT); 663 auto ntt_it = m_IntelNtt.find(key); 664 if (reCompute || ntt_it == m_IntelNtt.end()) { 665 intel::hexl::NTT ntt(element.GetLength(), modulus.ConvertToInt(), 666 rootOfUnity.ConvertToInt()); 667 m_IntelNtt[key] = std::move(ntt); 668 ntt_it = m_IntelNtt.find(key); 669 } 670 p_ntt = &ntt_it->second; 671 lock.unlock(); 672 673 const uint64_t *input = 674 reinterpret_cast<const uint64_t *>(&element.at(0)); 675 uint64_t *output = reinterpret_cast<uint64_t *>(&result->at(0)); 676 p_ntt->ComputeForward(output, input, 1, 1); 677 result->SetModulus(modulus); 678 679 } else { 680 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse( 681 element, m_rootOfUnityReverseTableByModulus[modulus], 682 m_rootOfUnityPreconReverseTableByModulus[modulus], result); 683 } 684 #else 685 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse( 686 element, m_rootOfUnityReverseTableByModulus[modulus], 687 m_rootOfUnityPreconReverseTableByModulus[modulus], result); 688 #endif 689 } else { 690 NumberTheoreticTransform<VecType>::ForwardTransformToBitReverse( 691 element, m_rootOfUnityReverseTableByModulus[modulus], result); 692 } 693 694 return; 695 } 696 697 template <typename VecType> 698 void ChineseRemainderTransformFTT< 699 VecType>::InverseTransformFromBitReverseInPlace(const IntType &rootOfUnity, 700 const usint CycloOrder, 701 VecType *element) { 702 if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) { 703 return; 704 } 705 706 if (!IsPowerOfTwo(CycloOrder)) { 707 PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two"); 708 } 709 710 usint CycloOrderHf = (CycloOrder >> 1); 711 if (element->GetLength() != CycloOrderHf) { 712 PALISADE_THROW(math_error, 713 "element size must be equal to CyclotomicOrder / 2"); 714 } 715 716 IntType modulus = element->GetModulus(); 717 718 bool reCompute = false; 719 PALISADE_UNUSED(reCompute); // Used only when WITH_INTEL_HEXL=ON 720 auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus); 721 if (mapSearch == m_rootOfUnityReverseTableByModulus.end() || 722 mapSearch->second.GetLength() != CycloOrderHf) { 723 PreCompute(rootOfUnity, CycloOrder, modulus); 724 reCompute = true; 725 } 726 727 usint msb = GetMSB64(CycloOrderHf - 1); 728 if (typeid(IntType) == typeid(NativeInteger)) { 729 #ifdef WITH_INTEL_HEXL 730 if (std::is_same<VecType, NativeVector64>::value) { 731 std::pair<uint64_t, uint64_t> key{element->GetLength(), 732 modulus.ConvertToInt()}; 733 intel::hexl::NTT *p_ntt; 734 std::unique_lock<std::mutex> lock(m_mtxIntelNTT); 735 auto ntt_it = m_IntelNtt.find(key); 736 if (reCompute || ntt_it == m_IntelNtt.end()) { 737 intel::hexl::NTT ntt(element->GetLength(), modulus.ConvertToInt(), 738 rootOfUnity.ConvertToInt()); 739 m_IntelNtt[key] = std::move(ntt); 740 ntt_it = m_IntelNtt.find(key); 741 } 742 p_ntt = &ntt_it->second; 743 lock.unlock(); 744 auto *data = reinterpret_cast<uint64_t *>(&element->at(0)); 745 p_ntt->ComputeInverse(data, data, 1, 1); 746 element->SetModulus(modulus); 747 } else { 748 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 749 m_rootOfUnityInverseReverseTableByModulus[modulus], 750 m_rootOfUnityInversePreconReverseTableByModulus[modulus], 751 m_cycloOrderInverseTableByModulus[modulus][msb], 752 m_cycloOrderInversePreconTableByModulus[modulus][msb], element); 753 } 754 #else 755 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 756 m_rootOfUnityInverseReverseTableByModulus[modulus], 757 m_rootOfUnityInversePreconReverseTableByModulus[modulus], 758 m_cycloOrderInverseTableByModulus[modulus][msb], 759 m_cycloOrderInversePreconTableByModulus[modulus][msb], element); 760 #endif 761 } else { 762 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 763 m_rootOfUnityInverseReverseTableByModulus[modulus], 764 m_cycloOrderInverseTableByModulus[modulus][msb], element); 765 } 766 } 767 768 template <typename VecType> 769 void ChineseRemainderTransformFTT<VecType>::InverseTransformFromBitReverse( 770 const VecType &element, const IntType &rootOfUnity, const usint CycloOrder, 771 VecType *result) { 772 if (rootOfUnity == IntType(1) || rootOfUnity == IntType(0)) { 773 *result = element; 774 return; 775 } 776 777 if (!IsPowerOfTwo(CycloOrder)) { 778 PALISADE_THROW(math_error, "CyclotomicOrder is not a power of two"); 779 } 780 781 usint CycloOrderHf = (CycloOrder >> 1); 782 if (result->GetLength() != CycloOrderHf) { 783 PALISADE_THROW(math_error, 784 "result size must be equal to CyclotomicOrder / 2"); 785 } 786 787 IntType modulus = element.GetModulus(); 788 789 bool reCompute = false; 790 (void)reCompute; // Avoid unused variable 791 auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus); 792 if (mapSearch == m_rootOfUnityReverseTableByModulus.end() || 793 mapSearch->second.GetLength() != CycloOrderHf) { 794 PreCompute(rootOfUnity, CycloOrder, modulus); 795 reCompute = true; 796 } 797 798 usint n = element.GetLength(); 799 result->SetModulus(element.GetModulus()); 800 for (usint i = 0; i < n; i++) { 801 (*result)[i] = element[i]; 802 } 803 804 usint msb = GetMSB64(CycloOrderHf - 1); 805 if (typeid(IntType) == typeid(NativeInteger)) { 806 #ifdef WITH_INTEL_HEXL 807 if (std::is_same<VecType, NativeVector64>::value) { 808 std::pair<uint64_t, uint64_t> key{element.GetLength(), 809 modulus.ConvertToInt()}; 810 intel::hexl::NTT *p_ntt; 811 std::unique_lock<std::mutex> lock(m_mtxIntelNTT); 812 auto ntt_it = m_IntelNtt.find(key); 813 if (reCompute || ntt_it == m_IntelNtt.end()) { 814 intel::hexl::NTT ntt(element.GetLength(), modulus.ConvertToInt(), 815 rootOfUnity.ConvertToInt()); 816 m_IntelNtt[key] = std::move(ntt); 817 ntt_it = m_IntelNtt.find(key); 818 } 819 p_ntt = &ntt_it->second; 820 lock.unlock(); 821 auto *input = reinterpret_cast<const uint64_t *>(&result->at(0)); 822 uint64_t *output = reinterpret_cast<uint64_t *>(&result->at(0)); 823 p_ntt->ComputeInverse(output, input, 1, 1); 824 result->SetModulus(modulus); 825 } else { 826 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 827 m_rootOfUnityInverseReverseTableByModulus[modulus], 828 m_rootOfUnityInversePreconReverseTableByModulus[modulus], 829 m_cycloOrderInverseTableByModulus[modulus][msb], 830 m_cycloOrderInversePreconTableByModulus[modulus][msb], result); 831 } 832 #else 833 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 834 m_rootOfUnityInverseReverseTableByModulus[modulus], 835 m_rootOfUnityInversePreconReverseTableByModulus[modulus], 836 m_cycloOrderInverseTableByModulus[modulus][msb], 837 m_cycloOrderInversePreconTableByModulus[modulus][msb], result); 838 #endif 839 } else { 840 NumberTheoreticTransform<VecType>::InverseTransformFromBitReverseInPlace( 841 m_rootOfUnityInverseReverseTableByModulus[modulus], 842 m_cycloOrderInverseTableByModulus[modulus][msb], result); 843 } 844 845 return; 846 } 847 848 template <typename VecType> 849 void ChineseRemainderTransformFTT<VecType>::PreCompute( 850 const IntType &rootOfUnity, const usint CycloOrder, 851 const IntType &modulus) { 852 // Half of cyclo order 853 usint CycloOrderHf = (CycloOrder >> 1); 854 855 auto mapSearch = m_rootOfUnityReverseTableByModulus.find(modulus); 856 if (mapSearch == m_rootOfUnityReverseTableByModulus.end() || 857 mapSearch->second.GetLength() != CycloOrderHf) { 858 #pragma omp critical 859 { 860 IntType x(1), xinv(1); 861 usint msb = GetMSB64(CycloOrderHf - 1); 862 IntType mu = modulus.ComputeMu(); 863 VecType Table(CycloOrderHf, modulus); 864 VecType TableI(CycloOrderHf, modulus); 865 IntType rootOfUnityInverse = rootOfUnity.ModInverse(modulus); 866 usint iinv; 867 for (usint i = 0; i < CycloOrderHf; i++) { 868 iinv = ReverseBits(i, msb); 869 Table[iinv] = x; 870 TableI[iinv] = xinv; 871 x.ModMulEq(rootOfUnity, modulus, mu); 872 xinv.ModMulEq(rootOfUnityInverse, modulus, mu); 873 } 874 m_rootOfUnityReverseTableByModulus[modulus] = Table; 875 m_rootOfUnityInverseReverseTableByModulus[modulus] = TableI; 876 877 VecType TableCOI(msb + 1, modulus); 878 for (usint i = 0; i < msb + 1; i++) { 879 IntType coInv(IntType(1 << i).ModInverse(modulus)); 880 TableCOI[i] = coInv; 881 } 882 m_cycloOrderInverseTableByModulus[modulus] = TableCOI; 883 884 if (typeid(IntType) == typeid(NativeInteger)) { 885 NativeInteger nativeModulus = modulus.ConvertToInt(); 886 NativeVector preconTable(CycloOrderHf, nativeModulus); 887 NativeVector preconTableI(CycloOrderHf, nativeModulus); 888 889 for (usint i = 0; i < CycloOrderHf; i++) { 890 preconTable[i] = 891 NativeInteger( 892 m_rootOfUnityReverseTableByModulus[modulus][i].ConvertToInt()) 893 .PrepModMulConst(nativeModulus); 894 preconTableI[i] = 895 NativeInteger( 896 m_rootOfUnityInverseReverseTableByModulus[modulus][i] 897 .ConvertToInt()) 898 .PrepModMulConst(nativeModulus); 899 } 900 901 NativeVector preconTableCOI(msb + 1, nativeModulus); 902 for (usint i = 0; i < msb + 1; i++) { 903 preconTableCOI[i] = 904 NativeInteger( 905 m_cycloOrderInverseTableByModulus[modulus][i].ConvertToInt()) 906 .PrepModMulConst(nativeModulus); 907 } 908 909 m_rootOfUnityPreconReverseTableByModulus[modulus] = preconTable; 910 m_rootOfUnityInversePreconReverseTableByModulus[modulus] = preconTableI; 911 m_cycloOrderInversePreconTableByModulus[modulus] = preconTableCOI; 912 } 913 } 914 } 915 } 916 917 template <typename VecType> 918 void ChineseRemainderTransformFTT<VecType>::PreCompute( 919 std::vector<IntType> &rootOfUnity, const usint CycloOrder, 920 std::vector<IntType> &moduliiChain) { 921 usint numOfRootU = rootOfUnity.size(); 922 usint numModulii = moduliiChain.size(); 923 924 if (numOfRootU != numModulii) { 925 PALISADE_THROW( 926 math_error, 927 "size of root of unity and size of moduli chain not of same size"); 928 } 929 930 for (usint i = 0; i < numOfRootU; ++i) { 931 IntType currentRoot(rootOfUnity[i]); 932 IntType currentMod(moduliiChain[i]); 933 PreCompute(currentRoot, CycloOrder, currentMod); 934 } 935 } 936 937 template <typename VecType> 938 void ChineseRemainderTransformFTT<VecType>::Reset() { 939 m_cycloOrderInverseTableByModulus.clear(); 940 m_cycloOrderInversePreconTableByModulus.clear(); 941 m_rootOfUnityReverseTableByModulus.clear(); 942 m_rootOfUnityInverseReverseTableByModulus.clear(); 943 m_rootOfUnityPreconReverseTableByModulus.clear(); 944 m_rootOfUnityInversePreconReverseTableByModulus.clear(); 945 } 946 947 template <typename VecType> 948 void BluesteinFFT<VecType>::PreComputeDefaultNTTModulusRoot( 949 usint cycloOrder, const IntType &modulus) { 950 usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1))); 951 const auto nttModulus = 952 FirstPrime<IntType>(log2(nttDim) + 2 * modulus.GetMSB(), nttDim); 953 const auto nttRoot = RootOfUnity(nttDim, nttModulus); 954 const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot}; 955 m_defaultNTTModulusRoot[modulus] = nttModulusRoot; 956 957 PreComputeRootTableForNTT(cycloOrder, nttModulusRoot); 958 } 959 960 template <typename VecType> 961 void BluesteinFFT<VecType>::PreComputeRootTableForNTT( 962 usint cyclotoOrder, const ModulusRoot<IntType> &nttModulusRoot) { 963 usint nttDim = pow(2, ceil(log2(2 * cyclotoOrder - 1))); 964 const auto &nttModulus = nttModulusRoot.first; 965 const auto &nttRoot = nttModulusRoot.second; 966 967 IntType root(nttRoot); 968 969 auto rootInv = root.ModInverse(nttModulus); 970 971 usint nttDimHf = (nttDim >> 1); 972 VecType rootTable(nttDimHf, nttModulus); 973 VecType rootTableInverse(nttDimHf, nttModulus); 974 975 IntType x(1); 976 for (usint i = 0; i < nttDimHf; i++) { 977 rootTable[i] = x; 978 x = x.ModMul(root, nttModulus); 979 } 980 981 x = 1; 982 for (usint i = 0; i < nttDimHf; i++) { 983 rootTableInverse[i] = x; 984 x = x.ModMul(rootInv, nttModulus); 985 } 986 987 m_rootOfUnityTableByModulusRoot[nttModulusRoot] = rootTable; 988 m_rootOfUnityInverseTableByModulusRoot[nttModulusRoot] = rootTableInverse; 989 } 990 991 template <typename VecType> 992 void BluesteinFFT<VecType>::PreComputePowers( 993 usint cycloOrder, const ModulusRoot<IntType> &modulusRoot) { 994 const auto &modulus = modulusRoot.first; 995 const auto &root = modulusRoot.second; 996 997 VecType powers(cycloOrder, modulus); 998 powers[0] = 1; 999 for (usint i = 1; i < cycloOrder; i++) { 1000 auto iSqr = (i * i) % (2 * cycloOrder); 1001 auto val = root.ModExp(IntType(iSqr), modulus); 1002 powers[i] = val; 1003 } 1004 m_powersTableByModulusRoot[modulusRoot] = powers; 1005 } 1006 1007 template <typename VecType> 1008 void BluesteinFFT<VecType>::PreComputeRBTable( 1009 usint cycloOrder, const ModulusRootPair<IntType> &modulusRootPair) { 1010 const auto &modulusRoot = modulusRootPair.first; 1011 const auto &modulus = modulusRoot.first; 1012 const auto &root = modulusRoot.second; 1013 const auto rootInv = root.ModInverse(modulus); 1014 1015 const auto &nttModulusRoot = modulusRootPair.second; 1016 const auto &nttModulus = nttModulusRoot.first; 1017 // const auto &nttRoot = nttModulusRoot.second; 1018 // assumes rootTable is precomputed 1019 const auto &rootTable = m_rootOfUnityTableByModulusRoot[nttModulusRoot]; 1020 usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1))); 1021 1022 VecType b(2 * cycloOrder - 1, modulus); 1023 b[cycloOrder - 1] = 1; 1024 for (usint i = 1; i < cycloOrder; i++) { 1025 auto iSqr = (i * i) % (2 * cycloOrder); 1026 auto val = rootInv.ModExp(IntType(iSqr), modulus); 1027 b[cycloOrder - 1 + i] = val; 1028 b[cycloOrder - 1 - i] = val; 1029 } 1030 1031 auto Rb = PadZeros(b, nttDim); 1032 Rb.SetModulus(nttModulus); 1033 1034 VecType RB(nttDim); 1035 NumberTheoreticTransform<VecType>::ForwardTransformIterative(Rb, rootTable, 1036 &RB); 1037 m_RBTableByModulusRootPair[modulusRootPair] = RB; 1038 } 1039 1040 template <typename VecType> 1041 VecType BluesteinFFT<VecType>::ForwardTransform(const VecType &element, 1042 const IntType &root, 1043 const usint cycloOrder) { 1044 const auto &modulus = element.GetModulus(); 1045 const auto &nttModulusRoot = m_defaultNTTModulusRoot[modulus]; 1046 1047 return ForwardTransform(element, root, cycloOrder, nttModulusRoot); 1048 } 1049 1050 template <typename VecType> 1051 VecType BluesteinFFT<VecType>::ForwardTransform( 1052 const VecType &element, const IntType &root, const usint cycloOrder, 1053 const ModulusRoot<IntType> &nttModulusRoot) { 1054 if (element.GetLength() != cycloOrder) { 1055 PALISADE_THROW( 1056 math_error, 1057 "expected size of element vector should be equal to cyclotomic order"); 1058 } 1059 1060 const auto &modulus = element.GetModulus(); 1061 const ModulusRoot<IntType> modulusRoot = {modulus, root}; 1062 const VecType &powers = m_powersTableByModulusRoot[modulusRoot]; 1063 1064 const auto &nttModulus = nttModulusRoot.first; 1065 // assumes rootTable is precomputed 1066 const auto &rootTable = m_rootOfUnityTableByModulusRoot[nttModulusRoot]; 1067 const auto &rootTableInverse = m_rootOfUnityInverseTableByModulusRoot 1068 [nttModulusRoot]; // assumes rootTableInverse is precomputed 1069 VecType x = element.ModMul(powers); 1070 1071 usint nttDim = pow(2, ceil(log2(2 * cycloOrder - 1))); 1072 auto Ra = PadZeros(x, nttDim); 1073 Ra.SetModulus(nttModulus); 1074 VecType RA(nttDim); 1075 NumberTheoreticTransform<VecType>::ForwardTransformIterative(Ra, rootTable, 1076 &RA); 1077 1078 const ModulusRootPair<IntType> modulusRootPair = {modulusRoot, 1079 nttModulusRoot}; 1080 const auto &RB = m_RBTableByModulusRootPair[modulusRootPair]; 1081 1082 auto RC = RA.ModMul(RB); 1083 VecType Rc(nttDim); 1084 NumberTheoreticTransform<VecType>::InverseTransformIterative( 1085 RC, rootTableInverse, &Rc); 1086 auto resizeRc = Resize(Rc, cycloOrder - 1, 2 * (cycloOrder - 1)); 1087 resizeRc.SetModulus(modulus); 1088 resizeRc.ModEq(modulus); 1089 auto result = resizeRc.ModMul(powers); 1090 1091 return result; 1092 } 1093 1094 template <typename VecType> 1095 VecType BluesteinFFT<VecType>::PadZeros(const VecType &a, 1096 const usint finalSize) { 1097 usint s = a.GetLength(); 1098 VecType result(finalSize, a.GetModulus()); 1099 1100 for (usint i = 0; i < s; i++) { 1101 result[i] = a[i]; 1102 } 1103 1104 for (usint i = a.GetLength(); i < finalSize; i++) { 1105 result[i] = IntType(0); 1106 } 1107 1108 return result; 1109 } 1110 1111 template <typename VecType> 1112 VecType BluesteinFFT<VecType>::Resize(const VecType &a, usint lo, usint hi) { 1113 VecType result(hi - lo + 1, a.GetModulus()); 1114 1115 for (usint i = lo, j = 0; i <= hi; i++, j++) { 1116 result[j] = a[i]; 1117 } 1118 1119 return result; 1120 } 1121 1122 template <typename VecType> 1123 void BluesteinFFT<VecType>::Reset() { 1124 m_rootOfUnityTableByModulusRoot.clear(); 1125 m_rootOfUnityInverseTableByModulusRoot.clear(); 1126 m_powersTableByModulusRoot.clear(); 1127 m_RBTableByModulusRootPair.clear(); 1128 m_defaultNTTModulusRoot.clear(); 1129 } 1130 1131 template <typename VecType> 1132 void ChineseRemainderTransformArb<VecType>::SetCylotomicPolynomial( 1133 const VecType &poly, const IntType &mod) { 1134 m_cyclotomicPolyMap[mod] = poly; 1135 } 1136 1137 template <typename VecType> 1138 void ChineseRemainderTransformArb<VecType>::PreCompute(const usint cyclotoOrder, 1139 const IntType &modulus) { 1140 BluesteinFFT<VecType>::PreComputeDefaultNTTModulusRoot(cyclotoOrder, modulus); 1141 } 1142 1143 template <typename VecType> 1144 void ChineseRemainderTransformArb<VecType>::SetPreComputedNTTModulus( 1145 usint cyclotoOrder, const IntType &modulus, const IntType &nttModulus, 1146 const IntType &nttRoot) { 1147 const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot}; 1148 BluesteinFFT<VecType>::PreComputeRootTableForNTT(cyclotoOrder, 1149 nttModulusRoot); 1150 } 1151 1152 template <typename VecType> 1153 void ChineseRemainderTransformArb<VecType>::SetPreComputedNTTDivisionModulus( 1154 usint cyclotoOrder, const IntType &modulus, const IntType &nttMod, 1155 const IntType &nttRootBig) { 1156 DEBUG_FLAG(false); 1157 1158 usint n = GetTotient(cyclotoOrder); 1159 DEBUG("GetTotient(" << cyclotoOrder << ")= " << n); 1160 1161 usint power = cyclotoOrder - n; 1162 m_nttDivisionDim[cyclotoOrder] = 2 * std::pow(2, ceil(log2(power))); 1163 1164 usint nttDimBig = std::pow(2, ceil(log2(2 * cyclotoOrder - 1))); 1165 1166 // Computes the root of unity for the division NTT based on the root of unity 1167 // for regular NTT 1168 IntType nttRoot = nttRootBig.ModExp( 1169 IntType(nttDimBig / m_nttDivisionDim[cyclotoOrder]), nttMod); 1170 1171 m_DivisionNTTModulus[modulus] = nttMod; 1172 m_DivisionNTTRootOfUnity[modulus] = nttRoot; 1173 // part0 setting of rootTable and inverse rootTable 1174 usint nttDim = m_nttDivisionDim[cyclotoOrder]; 1175 IntType root(nttRoot); 1176 auto rootInv = root.ModInverse(nttMod); 1177 1178 usint nttDimHf = (nttDim >> 1); 1179 VecType rootTable(nttDimHf, nttMod); 1180 VecType rootTableInverse(nttDimHf, nttMod); 1181 1182 IntType x(1); 1183 for (usint i = 0; i < nttDimHf; i++) { 1184 rootTable[i] = x; 1185 x = x.ModMul(root, nttMod); 1186 } 1187 1188 x = 1; 1189 for (usint i = 0; i < nttDimHf; i++) { 1190 rootTableInverse[i] = x; 1191 x = x.ModMul(rootInv, nttMod); 1192 } 1193 1194 m_rootOfUnityDivisionTableByModulus[nttMod] = rootTable; 1195 m_rootOfUnityDivisionInverseTableByModulus[nttMod] = rootTableInverse; 1196 1197 // end of part0 1198 // part1 1199 const auto &RevCPM = 1200 InversePolyMod(m_cyclotomicPolyMap[modulus], modulus, power); 1201 auto RevCPMPadded = BluesteinFFT<VecType>::PadZeros(RevCPM, nttDim); 1202 RevCPMPadded.SetModulus(nttMod); 1203 // end of part1 1204 1205 VecType RA(nttDim); 1206 NumberTheoreticTransform<VecType>::ForwardTransformIterative(RevCPMPadded, 1207 rootTable, &RA); 1208 m_cyclotomicPolyReverseNTTMap[modulus] = RA; 1209 1210 const auto &cycloPoly = m_cyclotomicPolyMap[modulus]; 1211 1212 VecType QForwardTransform(nttDim, nttMod); 1213 for (usint i = 0; i < cycloPoly.GetLength(); i++) { 1214 QForwardTransform[i] = cycloPoly[i]; 1215 } 1216 1217 VecType QFwdResult(nttDim); 1218 NumberTheoreticTransform<VecType>::ForwardTransformIterative( 1219 QForwardTransform, rootTable, &QFwdResult); 1220 1221 m_cyclotomicPolyNTTMap[modulus] = QFwdResult; 1222 } 1223 1224 template <typename VecType> 1225 VecType ChineseRemainderTransformArb<VecType>::InversePolyMod( 1226 const VecType &cycloPoly, const IntType &modulus, usint power) { 1227 VecType result(power, modulus); 1228 usint r = ceil(log2(power)); 1229 VecType h(1, modulus); // h is a unit polynomial 1230 h[0] = 1; 1231 1232 // Precompute the Barrett mu parameter 1233 IntType mu = modulus.ComputeMu(); 1234 1235 for (usint i = 0; i < r; i++) { 1236 usint qDegree = std::pow(2, i + 1); 1237 VecType q(qDegree + 1, modulus); // q = x^(2^i+1) 1238 q[qDegree] = 1; 1239 auto hSquare = PolynomialMultiplication(h, h); 1240 1241 auto a = h * IntType(2); 1242 auto b = PolynomialMultiplication(hSquare, cycloPoly); 1243 // b = 2h - gh^2 1244 for (usint j = 0; j < b.GetLength(); j++) { 1245 if (j < a.GetLength()) { 1246 b[j] = a[j].ModSub(b[j], modulus, mu); 1247 } else { 1248 b[j] = modulus.ModSub(b[j], modulus, mu); 1249 } 1250 } 1251 h = PolyMod(b, q, modulus); 1252 } 1253 // take modulo x^power 1254 for (usint i = 0; i < power; i++) { 1255 result[i] = h[i]; 1256 } 1257 1258 return result; 1259 } 1260 1261 template <typename VecType> 1262 VecType ChineseRemainderTransformArb<VecType>::ForwardTransform( 1263 const VecType &element, const IntType &root, const IntType &nttModulus, 1264 const IntType &nttRoot, const usint cycloOrder) { 1265 usint phim = GetTotient(cycloOrder); 1266 if (element.GetLength() != phim) { 1267 PALISADE_THROW(math_error, "element size should be equal to phim"); 1268 } 1269 1270 const auto &modulus = element.GetModulus(); 1271 const ModulusRoot<IntType> modulusRoot = {modulus, root}; 1272 1273 const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot}; 1274 const ModulusRootPair<IntType> modulusRootPair = {modulusRoot, 1275 nttModulusRoot}; 1276 1277 #pragma omp critical 1278 { 1279 if (BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot[nttModulusRoot] 1280 .GetLength() == 0) { 1281 BluesteinFFT<VecType>::PreComputeRootTableForNTT(cycloOrder, 1282 nttModulusRoot); 1283 } 1284 1285 if (BluesteinFFT<VecType>::m_powersTableByModulusRoot[modulusRoot] 1286 .GetLength() == 0) { 1287 BluesteinFFT<VecType>::PreComputePowers(cycloOrder, modulusRoot); 1288 } 1289 1290 if (BluesteinFFT<VecType>::m_RBTableByModulusRootPair[modulusRootPair] 1291 .GetLength() == 0) { 1292 BluesteinFFT<VecType>::PreComputeRBTable(cycloOrder, modulusRootPair); 1293 } 1294 } 1295 1296 VecType inputToBluestein = Pad(element, cycloOrder, true); 1297 auto outputBluestein = BluesteinFFT<VecType>::ForwardTransform( 1298 inputToBluestein, root, cycloOrder, nttModulusRoot); 1299 VecType output = Drop(outputBluestein, cycloOrder, true, nttModulus, nttRoot); 1300 1301 return output; 1302 } 1303 1304 template <typename VecType> 1305 VecType ChineseRemainderTransformArb<VecType>::InverseTransform( 1306 const VecType &element, const IntType &root, const IntType &nttModulus, 1307 const IntType &nttRoot, const usint cycloOrder) { 1308 usint phim = GetTotient(cycloOrder); 1309 if (element.GetLength() != phim) { 1310 PALISADE_THROW(math_error, "element size should be equal to phim"); 1311 } 1312 1313 const auto &modulus = element.GetModulus(); 1314 auto rootInverse(root.ModInverse(modulus)); 1315 const ModulusRoot<IntType> modulusRootInverse = {modulus, rootInverse}; 1316 1317 const ModulusRoot<IntType> nttModulusRoot = {nttModulus, nttRoot}; 1318 const ModulusRootPair<IntType> modulusRootPair = {modulusRootInverse, 1319 nttModulusRoot}; 1320 1321 #pragma omp critical 1322 { 1323 if (BluesteinFFT<VecType>::m_rootOfUnityTableByModulusRoot[nttModulusRoot] 1324 .GetLength() == 0) { 1325 BluesteinFFT<VecType>::PreComputeRootTableForNTT(cycloOrder, 1326 nttModulusRoot); 1327 } 1328 1329 if (BluesteinFFT<VecType>::m_powersTableByModulusRoot[modulusRootInverse] 1330 .GetLength() == 0) { 1331 BluesteinFFT<VecType>::PreComputePowers(cycloOrder, modulusRootInverse); 1332 } 1333 1334 if (BluesteinFFT<VecType>::m_RBTableByModulusRootPair[modulusRootPair] 1335 .GetLength() == 0) { 1336 BluesteinFFT<VecType>::PreComputeRBTable(cycloOrder, modulusRootPair); 1337 } 1338 } 1339 VecType inputToBluestein = Pad(element, cycloOrder, false); 1340 auto outputBluestein = BluesteinFFT<VecType>::ForwardTransform( 1341 inputToBluestein, rootInverse, cycloOrder, nttModulusRoot); 1342 auto cyclotomicInverse((IntType(cycloOrder)).ModInverse(modulus)); 1343 outputBluestein = outputBluestein * cyclotomicInverse; 1344 VecType output = 1345 Drop(outputBluestein, cycloOrder, false, nttModulus, nttRoot); 1346 return output; 1347 } 1348 1349 template <typename VecType> 1350 VecType ChineseRemainderTransformArb<VecType>::Pad(const VecType &element, 1351 const usint cycloOrder, 1352 bool forward) { 1353 usint n = GetTotient(cycloOrder); 1354 1355 const auto &modulus = element.GetModulus(); 1356 VecType inputToBluestein(cycloOrder, modulus); 1357 1358 if (forward) { // Forward transform padding 1359 for (usint i = 0; i < n; i++) { 1360 inputToBluestein[i] = element[i]; 1361 } 1362 } else { // Inverse transform padding 1363 auto tList = GetTotientList(cycloOrder); 1364 usint i = 0; 1365 for (auto &coprime : tList) { 1366 inputToBluestein[coprime] = element[i++]; 1367 } 1368 } 1369 1370 return inputToBluestein; 1371 } 1372 1373 template <typename VecType> 1374 VecType ChineseRemainderTransformArb<VecType>::Drop(const VecType &element, 1375 const usint cycloOrder, 1376 bool forward, 1377 const IntType &bigMod, 1378 const IntType &bigRoot) { 1379 usint n = GetTotient(cycloOrder); 1380 1381 const auto &modulus = element.GetModulus(); 1382 VecType output(n, modulus); 1383 1384 if (forward) { // Forward transform drop 1385 auto tList = GetTotientList(cycloOrder); 1386 for (usint i = 0; i < n; i++) { 1387 output[i] = element[tList[i]]; 1388 } 1389 } else { // Inverse transform drop 1390 if ((n + 1) == cycloOrder) { 1391 IntType mu = modulus.ComputeMu(); // Precompute the Barrett mu parameter 1392 // cycloOrder is prime: Reduce mod Phi_{n+1}(x) 1393 // Reduction involves subtracting the coeff of x^n from all terms 1394 auto coeff_n = element[n]; 1395 for (usint i = 0; i < n; i++) { 1396 output[i] = element[i].ModSub(coeff_n, modulus, mu); 1397 } 1398 } else if ((n + 1) * 2 == cycloOrder) { 1399 IntType mu = modulus.ComputeMu(); // Precompute the Barrett mu parameter 1400 // cycloOrder is 2*prime: 2 Step reduction 1401 // First reduce mod x^(n+1)+1 (=(x+1)*Phi_{2*(n+1)}(x)) 1402 // Subtract co-efficient of x^(i+n+1) from x^(i) 1403 for (usint i = 0; i < n; i++) { 1404 auto coeff_i = element[i]; 1405 auto coeff_ip = element[i + n + 1]; 1406 output[i] = coeff_i.ModSub(coeff_ip, modulus, mu); 1407 } 1408 auto coeff_n = element[n].ModSub(element[2 * n + 1], modulus, mu); 1409 // Now reduce mod Phi_{2*(n+1)}(x) 1410 // Similar to the prime case but with alternating signs 1411 for (usint i = 0; i < n; i++) { 1412 if (i % 2 == 0) { 1413 output[i].ModSubEq(coeff_n, modulus, mu); 1414 } else { 1415 output[i].ModAddEq(coeff_n, modulus, mu); 1416 } 1417 } 1418 } else { 1419 // precompute root of unity tables for division NTT 1420 if ((m_rootOfUnityDivisionTableByModulus[bigMod].GetLength() == 0) || 1421 (m_DivisionNTTModulus[modulus] != bigMod)) { 1422 SetPreComputedNTTDivisionModulus(cycloOrder, modulus, bigMod, bigRoot); 1423 } 1424 1425 // cycloOrder is arbitrary 1426 // auto output = PolyMod(element, this->m_cyclotomicPolyMap[modulus], 1427 // modulus); 1428 1429 const auto &nttMod = m_DivisionNTTModulus[modulus]; 1430 const auto &rootTable = m_rootOfUnityDivisionTableByModulus[nttMod]; 1431 VecType aPadded2(m_nttDivisionDim[cycloOrder], nttMod); 1432 // perform mod operation 1433 usint power = cycloOrder - n; 1434 for (usint i = n; i < element.GetLength(); i++) { 1435 aPadded2[power - (i - n) - 1] = element[i]; 1436 } 1437 VecType A(m_nttDivisionDim[cycloOrder]); 1438 NumberTheoreticTransform<VecType>::ForwardTransformIterative( 1439 aPadded2, rootTable, &A); 1440 auto AB = A * m_cyclotomicPolyReverseNTTMap[modulus]; 1441 const auto &rootTableInverse = 1442 m_rootOfUnityDivisionInverseTableByModulus[nttMod]; 1443 VecType a(m_nttDivisionDim[cycloOrder]); 1444 NumberTheoreticTransform<VecType>::InverseTransformIterative( 1445 AB, rootTableInverse, &a); 1446 1447 VecType quotient(m_nttDivisionDim[cycloOrder], modulus); 1448 for (usint i = 0; i < power; i++) { 1449 quotient[i] = a[i]; 1450 } 1451 quotient.ModEq(modulus); 1452 quotient.SetModulus(nttMod); 1453 1454 VecType newQuotient(m_nttDivisionDim[cycloOrder]); 1455 NumberTheoreticTransform<VecType>::ForwardTransformIterative( 1456 quotient, rootTable, &newQuotient); 1457 newQuotient *= m_cyclotomicPolyNTTMap[modulus]; 1458 1459 VecType newQuotient2(m_nttDivisionDim[cycloOrder]); 1460 NumberTheoreticTransform<VecType>::InverseTransformIterative( 1461 newQuotient, rootTableInverse, &newQuotient2); 1462 newQuotient2.SetModulus(modulus); 1463 newQuotient2.ModEq(modulus); 1464 1465 IntType mu = modulus.ComputeMu(); // Precompute the Barrett mu parameter 1466 1467 for (usint i = 0; i < n; i++) { 1468 output[i] = 1469 element[i].ModSub(newQuotient2[cycloOrder - 1 - i], modulus, mu); 1470 } 1471 } 1472 } 1473 return output; 1474 } 1475 1476 template <typename VecType> 1477 void ChineseRemainderTransformArb<VecType>::Reset() { 1478 m_cyclotomicPolyMap.clear(); 1479 m_cyclotomicPolyReverseNTTMap.clear(); 1480 m_cyclotomicPolyNTTMap.clear(); 1481 m_rootOfUnityDivisionTableByModulus.clear(); 1482 m_rootOfUnityDivisionInverseTableByModulus.clear(); 1483 m_DivisionNTTModulus.clear(); 1484 m_DivisionNTTRootOfUnity.clear(); 1485 m_nttDivisionDim.clear(); 1486 BluesteinFFT<VecType>::Reset(); 1487 } 1488 1489 } // namespace lbcrypto 1490