1 /* 2 * Copyright (c) by CryptoLab inc. 3 * This program is licensed under a 4 * Creative Commons Attribution-NonCommercial 3.0 Unported License. 5 * You should have received a copy of the license along with this 6 * work. If not, see <http://creativecommons.org/licenses/by-nc/3.0/>. 7 */ 8 #ifndef HEAAN_RINGMULTIPLIER_H_ 9 #define HEAAN_RINGMULTIPLIER_H_ 10 11 #include <cstdint> 12 #include <vector> 13 #include <NTL/ZZ.h> 14 #include "Params.h" 15 16 using namespace std; 17 using namespace NTL; 18 19 class RingMultiplier { 20 public: 21 22 uint64_t* pVec = new uint64_t[nprimes]; 23 uint64_t* prVec = new uint64_t[nprimes]; 24 uint64_t* pInvVec = new uint64_t[nprimes]; 25 uint64_t** scaledRootPows = new uint64_t*[nprimes]; 26 uint64_t** scaledRootInvPows = new uint64_t*[nprimes]; 27 uint64_t* scaledNInv = new uint64_t[nprimes]; 28 _ntl_general_rem_one_struct* red_ss_array[nprimes]; 29 mulmod_precon_t* coeffpinv_array[nprimes]; 30 31 ZZ* pProd = new ZZ[nprimes]; 32 ZZ* pProdh = new ZZ[nprimes]; 33 ZZ** pHat = new ZZ*[nprimes]; 34 uint64_t** pHatInvModp = new uint64_t*[nprimes]; 35 36 RingMultiplier(); 37 38 bool primeTest(uint64_t p); 39 40 void NTT(uint64_t* a, long index); 41 void INTT(uint64_t* a, long index); 42 43 void CRT(uint64_t* rx, ZZ* x, const long np); 44 45 void addNTTAndEqual(uint64_t* ra, uint64_t* rb, const long np); 46 47 void reconstruct(ZZ* x, uint64_t* rx, long np, const ZZ& QQ); 48 49 void mult(ZZ* x, ZZ* a, ZZ* b, long np, const ZZ& QQ); 50 51 void multNTT(ZZ* x, ZZ* a, uint64_t* rb, long np, const ZZ& QQ); 52 53 void multDNTT(ZZ* x, uint64_t* ra, uint64_t* rb, long np, const ZZ& QQ); 54 55 void multAndEqual(ZZ* a, ZZ* b, long np, const ZZ& QQ); 56 57 void multNTTAndEqual(ZZ* a, uint64_t* rb, long np, const ZZ& QQ); 58 59 void square(ZZ* x, ZZ* a, long np, const ZZ& QQ); 60 61 void squareNTT(ZZ* x, uint64_t* ra, long np, const ZZ& QQ); 62 63 void squareAndEqual(ZZ* a, long np, const ZZ& QQ); 64 65 void mulMod(uint64_t& r, uint64_t a, uint64_t b, uint64_t p); 66 67 void mulModBarrett(uint64_t& r, uint64_t a, uint64_t b, uint64_t p, uint64_t pr); 68 void butt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv); 69 void ibutt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv); 70 void idivN(uint64_t& a, uint64_t NScale, uint64_t p, uint64_t pInv); 71 72 uint64_t invMod(uint64_t x, uint64_t p); 73 74 uint64_t powMod(uint64_t x, uint64_t y, uint64_t p); 75 76 uint64_t inv(uint64_t x); 77 78 uint64_t pow(uint64_t x, uint64_t y); 79 80 uint32_t bitReverse(uint32_t x); 81 82 void findPrimeFactors(vector<uint64_t> &s, uint64_t number); 83 84 uint64_t findPrimitiveRoot(uint64_t m); 85 86 uint64_t findMthRootOfUnity(uint64_t M, uint64_t p); 87 88 }; 89 90 #endif /* RINGMULTIPLIER_H_ */ 91