1 /*
2  * Copyright (c) by CryptoLab inc.
3  * This program is licensed under a
4  * Creative Commons Attribution-NonCommercial 3.0 Unported License.
5  * You should have received a copy of the license along with this
6  * work.  If not, see <http://creativecommons.org/licenses/by-nc/3.0/>.
7  */
8 #ifndef HEAAN_RINGMULTIPLIER_H_
9 #define HEAAN_RINGMULTIPLIER_H_
10 
11 #include <cstdint>
12 #include <vector>
13 #include <NTL/ZZ.h>
14 #include "Params.h"
15 
16 using namespace std;
17 using namespace NTL;
18 
19 class RingMultiplier {
20 public:
21 
22 	uint64_t* pVec = new uint64_t[nprimes];
23 	uint64_t* prVec = new uint64_t[nprimes];
24 	uint64_t* pInvVec = new uint64_t[nprimes];
25 	uint64_t** scaledRootPows = new uint64_t*[nprimes];
26 	uint64_t** scaledRootInvPows = new uint64_t*[nprimes];
27 	uint64_t* scaledNInv = new uint64_t[nprimes];
28 	_ntl_general_rem_one_struct* red_ss_array[nprimes];
29 	mulmod_precon_t* coeffpinv_array[nprimes];
30 
31 	ZZ* pProd = new ZZ[nprimes];
32 	ZZ* pProdh = new ZZ[nprimes];
33 	ZZ** pHat = new ZZ*[nprimes];
34 	uint64_t** pHatInvModp = new uint64_t*[nprimes];
35 
36 	RingMultiplier();
37 
38 	bool primeTest(uint64_t p);
39 
40 	void NTT(uint64_t* a, long index);
41 	void INTT(uint64_t* a, long index);
42 
43 	void CRT(uint64_t* rx, ZZ* x, const long np);
44 
45 	void addNTTAndEqual(uint64_t* ra, uint64_t* rb, const long np);
46 
47 	void reconstruct(ZZ* x, uint64_t* rx, long np, const ZZ& QQ);
48 
49 	void mult(ZZ* x, ZZ* a, ZZ* b, long np, const ZZ& QQ);
50 
51 	void multNTT(ZZ* x, ZZ* a, uint64_t* rb, long np, const ZZ& QQ);
52 
53 	void multDNTT(ZZ* x, uint64_t* ra, uint64_t* rb, long np, const ZZ& QQ);
54 
55 	void multAndEqual(ZZ* a, ZZ* b, long np, const ZZ& QQ);
56 
57 	void multNTTAndEqual(ZZ* a, uint64_t* rb, long np, const ZZ& QQ);
58 
59 	void square(ZZ* x, ZZ* a, long np, const ZZ& QQ);
60 
61 	void squareNTT(ZZ* x, uint64_t* ra, long np, const ZZ& QQ);
62 
63 	void squareAndEqual(ZZ* a, long np, const ZZ& QQ);
64 
65 	void mulMod(uint64_t& r, uint64_t a, uint64_t b, uint64_t p);
66 
67 	void mulModBarrett(uint64_t& r, uint64_t a, uint64_t b, uint64_t p, uint64_t pr);
68 	void butt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv);
69 	void ibutt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv);
70 	void idivN(uint64_t& a, uint64_t NScale, uint64_t p, uint64_t pInv);
71 
72 	uint64_t invMod(uint64_t x, uint64_t p);
73 
74 	uint64_t powMod(uint64_t x, uint64_t y, uint64_t p);
75 
76 	uint64_t inv(uint64_t x);
77 
78 	uint64_t pow(uint64_t x, uint64_t y);
79 
80 	uint32_t bitReverse(uint32_t x);
81 
82 	void findPrimeFactors(vector<uint64_t> &s, uint64_t number);
83 
84 	uint64_t findPrimitiveRoot(uint64_t m);
85 
86 	uint64_t findMthRootOfUnity(uint64_t M, uint64_t p);
87 
88 };
89 
90 #endif /* RINGMULTIPLIER_H_ */
91