1 /*
2 * Copyright (c) by CryptoLab inc.
3 * This program is licensed under a
4 * Creative Commons Attribution-NonCommercial 3.0 Unported License.
5 * You should have received a copy of the license along with this
6 * work.  If not, see <http://creativecommons.org/licenses/by-nc/3.0/>.
7 */
8 
9 #include "RingMultiplier.h"
10 
11 #include <NTL/BasicThreadPool.h>
12 #include <NTL/tools.h>
13 #include <cmath>
14 #include <cstdlib>
15 #include <iterator>
16 
RingMultiplier()17 RingMultiplier::RingMultiplier() {
18 
19 	uint64_t primetest = (1ULL << pbnd) + 1;
20 	for (long i = 0; i < nprimes; ++i) {
21 		while(true) {
22 			primetest += M;
23 			if(primeTest(primetest)) {
24 				pVec[i] = primetest;
25 				break;
26 			}
27 		}
28 	}
29 
30 	for (long i = 0; i < nprimes; ++i) {
31 		red_ss_array[i] = _ntl_general_rem_one_struct_build(pVec[i]);
32 		pInvVec[i] = inv(pVec[i]);
33 		prVec[i] = (static_cast<unsigned __int128>(1) << kbar2) / pVec[i];
34 		uint64_t root = findMthRootOfUnity(M, pVec[i]);
35 		uint64_t rootinv = invMod(root, pVec[i]);
36 		uint64_t NInv = invMod(N, pVec[i]);
37 		mulMod(scaledNInv[i], NInv, (1ULL << 32), pVec[i]);
38 		mulMod(scaledNInv[i], scaledNInv[i], (1ULL << 32), pVec[i]);
39 		scaledRootPows[i] = new uint64_t[N]();
40 		scaledRootInvPows[i] = new uint64_t[N]();
41 		uint64_t power = 1;
42 		uint64_t powerInv = 1;
43 		for (long j = 0; j < N; ++j) {
44 			uint32_t jprime = bitReverse(static_cast<uint32_t>(j)) >> (32 - logN);
45 			uint64_t rootpow = power;
46 			mulMod(scaledRootPows[i][jprime], rootpow,(1ULL << 32), pVec[i]);
47 			mulMod(scaledRootPows[i][jprime], scaledRootPows[i][jprime], (1ULL << 32), pVec[i]);
48 			uint64_t rootpowInv = powerInv;
49 			mulMod(scaledRootInvPows[i][jprime], rootpowInv, (1ULL << 32), pVec[i]);
50 			mulMod(scaledRootInvPows[i][jprime], scaledRootInvPows[i][jprime], (1ULL << 32), pVec[i]);
51 			mulMod(power, power, root, pVec[i]);
52 			mulMod(powerInv, powerInv, rootinv, pVec[i]);
53 		}
54 	}
55 
56 	for (long i = 0; i < nprimes; ++i) {
57 		coeffpinv_array[i] = new mulmod_precon_t[i + 1];
58 		pProd[i] = (i == 0) ? to_ZZ((long) pVec[i]) : pProd[i - 1] * (long) pVec[i];
59 		pProdh[i] = pProd[i] / 2;
60 		pHat[i] = new ZZ[i + 1];
61 		pHatInvModp[i] = new uint64_t[i + 1];
62 		for (long j = 0; j < i + 1; ++j) {
63 			pHat[i][j] = ZZ(1);
64 			for (long k = 0; k < j; ++k) {
65 				pHat[i][j] *= (long) pVec[k];
66 			}
67 			for (long k = j + 1; k < i + 1; ++k) {
68 				pHat[i][j] *= (long) pVec[k];
69 			}
70 			pHatInvModp[i][j] = to_long(pHat[i][j] % (long) pVec[j]);
71 			pHatInvModp[i][j] = invMod(pHatInvModp[i][j], pVec[j]);
72 			coeffpinv_array[i][j] = PrepMulModPrecon(pHatInvModp[i][j], pVec[j]);
73 		}
74 	}
75 }
76 
primeTest(uint64_t p)77 bool RingMultiplier::primeTest(uint64_t p) {
78 	if(p < 2) return false;
79 	if(p != 2 && p % 2 == 0) return false;
80 	uint64_t s = p - 1;
81 	while(s % 2 == 0) {
82 		s /= 2;
83 	}
84 	for(long i = 0; i < 200; i++) {
85 		uint64_t temp1 = rand();
86 		temp1  = (temp1 << 32) | rand();
87 		temp1 = temp1 % (p - 1) + 1;
88 		uint64_t temp2 = s;
89 		uint64_t mod = powMod(temp1,temp2,p);
90 		while (temp2 != p - 1 && mod != 1 && mod != p - 1) {
91 			mulMod(mod, mod, mod, p);
92 		    temp2 *= 2;
93 		}
94 		if (mod != p - 1 && temp2 % 2 == 0) return false;
95 	}
96 	return true;
97 }
98 
NTT(uint64_t * a,long index)99 void RingMultiplier::NTT(uint64_t* a, long index) {
100 	long t = N;
101 	long logt1 = logN + 1;
102 	uint64_t p = pVec[index];
103 	uint64_t pInv = pInvVec[index];
104 	for (long m = 1; m < N; m <<= 1) {
105 		t >>= 1;
106 		logt1 -= 1;
107 		for (long i = 0; i < m; i++) {
108 			long j1 = i << logt1;
109 			long j2 = j1 + t - 1;
110 			uint64_t W = scaledRootPows[index][m + i];
111 			for (long j = j1; j <= j2; j++) {
112 				butt(a[j], a[j+t], W, p, pInv);
113 			}
114 		}
115 	}
116 }
117 
INTT(uint64_t * a,long index)118 void RingMultiplier::INTT(uint64_t* a, long index) {
119 	uint64_t p = pVec[index];
120 	uint64_t pInv = pInvVec[index];
121 	long t = 1;
122 	for (long m = N; m > 1; m >>= 1) {
123 		long j1 = 0;
124 		long h = m >> 1;
125 		for (long i = 0; i < h; i++) {
126 			long j2 = j1 + t - 1;
127 			uint64_t W = scaledRootInvPows[index][h + i];
128 			for (long j = j1; j <= j2; j++) {
129 				ibutt(a[j], a[j+t], W, p, pInv);
130 			}
131 			j1 += (t << 1);
132 		}
133 		t <<= 1;
134 	}
135 
136 	uint64_t NScale = scaledNInv[index];
137 	for (long i = 0; i < N; i++) {
138 		idivN(a[i], NScale, p, pInv);
139 	}
140 }
141 
142 //----------------------------------------------------------------------------------
143 //   FFT
144 //----------------------------------------------------------------------------------
145 
CRT(uint64_t * rx,ZZ * x,const long np)146 void RingMultiplier::CRT(uint64_t* rx, ZZ* x, const long np) {
147 	NTL_EXEC_RANGE(np, first, last);
148 	for (long i = first; i < last; ++i) {
149 		uint64_t* rxi = rx + (i << logN);
150 		uint64_t pi = pVec[i];
151 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
152 		for (long n = 0; n < N; ++n) {
153 			rxi[n] = _ntl_general_rem_one_struct_apply(x[n].rep, pi, red_ss);
154 		}
155 		NTT(rxi, i);
156 	}
157 	NTL_EXEC_RANGE_END;
158 }
159 
addNTTAndEqual(uint64_t * ra,uint64_t * rb,const long np)160 void RingMultiplier::addNTTAndEqual(uint64_t* ra, uint64_t* rb, const long np) {
161 	for (long i = 0; i < np; ++i) {
162 		uint64_t* rai = ra + (i << logN);
163 		uint64_t* rbi = rb + (i << logN);
164 		uint64_t pi = pVec[i];
165 		for (long n = 0; n < N; ++n) {
166 			rai[n] += rbi[n];
167 			if(rai[n] > pi) rai[n] -= pi;
168 		}
169 	}
170 }
171 
reconstruct(ZZ * x,uint64_t * rx,long np,const ZZ & q)172 void RingMultiplier::reconstruct(ZZ* x, uint64_t* rx, long np, const ZZ& q) {
173 	ZZ* pHatnp = pHat[np - 1];
174 	uint64_t* pHatInvModpnp = pHatInvModp[np - 1];
175 	mulmod_precon_t* coeffpinv_arraynp = coeffpinv_array[np - 1];
176 	ZZ& pProdnp = pProd[np - 1];
177 	ZZ& pProdhnp = pProdh[np - 1];
178 	NTL_EXEC_RANGE(N, first, last);
179 	for (long n = first; n < last; ++n) {
180 		ZZ& acc = x[n];
181 		QuickAccumBegin(acc, pProdnp.size());
182 		for (long i = 0; i < np; i++) {
183 			long p = pVec[i];
184 			long tt = pHatInvModpnp[i];
185 			mulmod_precon_t ttpinv = coeffpinv_arraynp[i];
186 			long s = MulModPrecon(rx[n + (i << logN)], tt, p, ttpinv);
187 			QuickAccumMulAdd(acc, pHatnp[i], s);
188 		}
189 		QuickAccumEnd(acc);
190 		rem(x[n], x[n], pProdnp);
191 		if (x[n] > pProdhnp) x[n] -= pProdnp;
192 		x[n] %= q;
193 	}
194 	NTL_EXEC_RANGE_END;
195 }
196 
mult(ZZ * x,ZZ * a,ZZ * b,long np,const ZZ & mod)197 void RingMultiplier::mult(ZZ* x, ZZ* a, ZZ* b, long np, const ZZ& mod) {
198 	uint64_t* ra = new uint64_t[np << logN]();
199 	uint64_t* rb = new uint64_t[np << logN]();
200 	uint64_t* rx = new uint64_t[np << logN]();
201 
202 	NTL_EXEC_RANGE(np, first, last);
203 	for (long i = first; i < last; ++i) {
204 		uint64_t* rai = ra + (i << logN);
205 		uint64_t* rbi = rb + (i << logN);
206 		uint64_t* rxi = rx + (i << logN);
207 		uint64_t pi = pVec[i];
208 		uint64_t pri = prVec[i];
209 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
210 		for (long n = 0; n < N; ++n) {
211 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
212 			rbi[n] = _ntl_general_rem_one_struct_apply(b[n].rep, pi, red_ss);
213 		}
214 		NTT(rai, i);
215 		NTT(rbi, i);
216 		for (long n = 0; n < N; ++n) {
217 			mulModBarrett(rxi[n], rai[n], rbi[n], pi, pri);
218 		}
219 		INTT(rxi, i);
220 	}
221 	NTL_EXEC_RANGE_END;
222 
223 	reconstruct(x, rx, np, mod);
224 
225 	delete[] ra;
226 	delete[] rb;
227 	delete[] rx;
228 }
229 
multNTT(ZZ * x,ZZ * a,uint64_t * rb,long np,const ZZ & mod)230 void RingMultiplier::multNTT(ZZ* x, ZZ* a, uint64_t* rb, long np, const ZZ& mod) {
231 	uint64_t* ra = new uint64_t[np << logN]();
232 	uint64_t* rx = new uint64_t[np << logN]();
233 	NTL_EXEC_RANGE(np, first, last);
234 	for (long i = first; i < last; ++i) {
235 		uint64_t* rai = ra + (i << logN);
236 		uint64_t* rbi = rb + (i << logN);
237 		uint64_t* rxi = rx + (i << logN);
238 		uint64_t pi = pVec[i];
239 		uint64_t pri = prVec[i];
240 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
241 		for (long n = 0; n < N; ++n) {
242 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
243 		}
244 		NTT(rai, i);
245 		for (long n = 0; n < N; ++n) {
246 			mulModBarrett(rxi[n], rai[n], rbi[n], pi, pri);
247 		}
248 		INTT(rxi, i);
249 	}
250 	NTL_EXEC_RANGE_END;
251 
252 	reconstruct(x, rx, np, mod);
253 
254 	delete[] ra;
255 	delete[] rx;
256 }
257 
multDNTT(ZZ * x,uint64_t * ra,uint64_t * rb,long np,const ZZ & mod)258 void RingMultiplier::multDNTT(ZZ* x, uint64_t* ra, uint64_t* rb, long np, const ZZ& mod) {
259 	uint64_t* rx = new uint64_t[np << logN]();
260 
261 	NTL_EXEC_RANGE(np, first, last);
262 	for (long i = first; i < last; ++i) {
263 		uint64_t* rai = ra + (i << logN);
264 		uint64_t* rbi = rb + (i << logN);
265 		uint64_t* rxi = rx + (i << logN);
266 		uint64_t pi = pVec[i];
267 		uint64_t pri = prVec[i];
268 		for (long n = 0; n < N; ++n) {
269 			mulModBarrett(rxi[n], rai[n], rbi[n], pi, pri);
270 		}
271 		INTT(rxi, i);
272 	}
273 	NTL_EXEC_RANGE_END;
274 
275 	reconstruct(x, rx, np, mod);
276 
277 	delete[] rx;
278 }
279 
multAndEqual(ZZ * a,ZZ * b,long np,const ZZ & mod)280 void RingMultiplier::multAndEqual(ZZ* a, ZZ* b, long np, const ZZ& mod) {
281 	uint64_t* ra = new uint64_t[np << logN]();
282 	uint64_t* rb = new uint64_t[np << logN]();
283 
284 	NTL_EXEC_RANGE(np, first, last);
285 	for (long i = first; i < last; ++i) {
286 		uint64_t* rai = ra + (i << logN);
287 		uint64_t* rbi = rb + (i << logN);
288 		uint64_t pi = pVec[i];
289 		uint64_t pri = prVec[i];
290 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
291 		for (long n = 0; n < N; ++n) {
292 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
293 			rbi[n] = _ntl_general_rem_one_struct_apply(b[n].rep, pi, red_ss);
294 		}
295 		NTT(rai, i);
296 		NTT(rbi, i);
297 		for (long n = 0; n < N; ++n) {
298 			mulModBarrett(rai[n], rai[n], rbi[n], pi, pri);
299 		}
300 		INTT(rai, i);
301 	}
302 	NTL_EXEC_RANGE_END;
303 
304 	ZZ* pHatnp = pHat[np - 1];
305 	uint64_t* pHatInvModpnp = pHatInvModp[np - 1];
306 
307 	reconstruct(a, ra, np, mod);
308 
309 	delete[] ra;
310 	delete[] rb;
311 }
312 
multNTTAndEqual(ZZ * a,uint64_t * rb,long np,const ZZ & mod)313 void RingMultiplier::multNTTAndEqual(ZZ* a, uint64_t* rb, long np, const ZZ& mod) {
314 	uint64_t* ra = new uint64_t[np << logN]();
315 
316 	NTL_EXEC_RANGE(np, first, last);
317 	for (long i = first; i < last; ++i) {
318 		uint64_t* rai = ra + (i << logN);
319 		uint64_t* rbi = rb + (i << logN);
320 		uint64_t pi = pVec[i];
321 		uint64_t pri = prVec[i];
322 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
323 		for (long n = 0; n < N; ++n) {
324 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
325 		}
326 		NTT(rai, i);
327 		for (long n = 0; n < N; ++n) {
328 			mulModBarrett(rai[n], rai[n], rbi[n], pi, pri);
329 		}
330 		INTT(rai, i);
331 	}
332 	NTL_EXEC_RANGE_END;
333 
334 	ZZ* pHatnp = pHat[np - 1];
335 	uint64_t* pHatInvModpnp = pHatInvModp[np - 1];
336 
337 	reconstruct(a, ra, np, mod);
338 
339 	delete[] ra;
340 }
341 
342 
square(ZZ * x,ZZ * a,long np,const ZZ & mod)343 void RingMultiplier::square(ZZ* x, ZZ* a, long np, const ZZ& mod) {
344 	uint64_t* ra = new uint64_t[np << logN]();
345 	uint64_t* rx = new uint64_t[np << logN]();
346 
347 	NTL_EXEC_RANGE(np, first, last);
348 	for (long i = first; i < last; ++i) {
349 		uint64_t* rai = ra + (i << logN);
350 		uint64_t* rxi = rx + (i << logN);
351 		uint64_t pi = pVec[i];
352 		uint64_t pri = prVec[i];
353 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
354 		for (long n = 0; n < N; ++n) {
355 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
356 		}
357 		NTT(rai, i);
358 		for (long n = 0; n < N; ++n) {
359 			mulModBarrett(rxi[n], rai[n], rai[n], pi, pri);
360 		}
361 		INTT(rxi, i);
362 	}
363 	NTL_EXEC_RANGE_END;
364 
365 	ZZ* pHatnp = pHat[np - 1];
366 	uint64_t* pHatInvModpnp = pHatInvModp[np - 1];
367 
368 	reconstruct(x, rx, np, mod);
369 
370 	delete[] ra;
371 	delete[] rx;
372 }
373 
squareNTT(ZZ * x,uint64_t * ra,long np,const ZZ & mod)374 void RingMultiplier::squareNTT(ZZ* x, uint64_t* ra, long np, const ZZ& mod) {
375 	uint64_t* rx = new uint64_t[np << logN]();
376 
377 	NTL_EXEC_RANGE(np, first, last);
378 	for (long i = first; i < last; ++i) {
379 		uint64_t* rai = ra + (i << logN);
380 		uint64_t* rxi = rx + (i << logN);
381 		uint64_t pi = pVec[i];
382 		uint64_t pri = prVec[i];
383 		for (long n = 0; n < N; ++n) {
384 			mulModBarrett(rxi[n], rai[n], rai[n], pi, pri);
385 		}
386 		INTT(rxi, i);
387 	}
388 	NTL_EXEC_RANGE_END;
389 
390 	reconstruct(x, rx, np, mod);
391 
392 	delete[] rx;
393 }
394 
squareAndEqual(ZZ * a,long np,const ZZ & mod)395 void RingMultiplier::squareAndEqual(ZZ* a, long np, const ZZ& mod) {
396 	uint64_t* ra = new uint64_t[np << logN]();
397 
398 	NTL_EXEC_RANGE(np, first, last);
399 	for (long i = first; i < last; ++i) {
400 		uint64_t* rai = ra + (i << logN);
401 		uint64_t pi = pVec[i];
402 		uint64_t pri = prVec[i];
403 		_ntl_general_rem_one_struct* red_ss = red_ss_array[i];
404 		for (long n = 0; n < N; ++n) {
405 			rai[n] = _ntl_general_rem_one_struct_apply(a[n].rep, pi, red_ss);
406 		}
407 		NTT(rai, i);
408 		for (long n = 0; n < N; ++n) {
409 			mulModBarrett(rai[n], rai[n], rai[n], pi, pri);
410 		}
411 		INTT(rai, i);
412 	}
413 	NTL_EXEC_RANGE_END;
414 
415 	reconstruct(a, ra, np, mod);
416 
417 	delete[] ra;
418 }
419 
mulMod(uint64_t & r,uint64_t a,uint64_t b,uint64_t m)420 void RingMultiplier::mulMod(uint64_t &r, uint64_t a, uint64_t b, uint64_t m) {
421 	unsigned __int128 mul = static_cast<unsigned __int128>(a) * b;
422 	mul %= static_cast<unsigned __int128>(m);
423 	r = static_cast<uint64_t>(mul);
424 }
425 
mulModBarrett(uint64_t & r,uint64_t a,uint64_t b,uint64_t p,uint64_t pr)426 void RingMultiplier::mulModBarrett(uint64_t& r, uint64_t a, uint64_t b, uint64_t p, uint64_t pr) {
427 	unsigned __int128 mul = static_cast<unsigned __int128>(a) * b;
428 	uint64_t abot = static_cast<uint64_t>(mul);
429 	uint64_t atop = static_cast<uint64_t>(mul >> 64);
430 	unsigned __int128 tmp = static_cast<unsigned __int128>(abot) * pr;
431 	tmp >>= 64;
432 	tmp += static_cast<unsigned __int128>(atop) * pr;
433 	tmp >>= kbar2 - 64;
434 	tmp *= p;
435 	tmp = mul - tmp;
436 	r = static_cast<uint64_t>(tmp);
437 	if(r >= p) r -= p;
438 }
439 
butt(uint64_t & a,uint64_t & b,uint64_t W,uint64_t p,uint64_t pInv)440 void RingMultiplier::butt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv) {
441 	unsigned __int128 U = static_cast<unsigned __int128>(b) * W;
442 	uint64_t U0 = static_cast<uint64_t>(U);
443 	uint64_t U1 = U >> 64;
444 	uint64_t Q = U0 * pInv;
445 	unsigned __int128 Hx = static_cast<unsigned __int128>(Q) * p;
446 	uint64_t H = Hx >> 64;
447 	uint64_t V = U1 < H ? U1 + p - H : U1 - H;
448 	b = a < V ? a + p - V : a - V;
449 	a += V;
450 	if (a > p) a -= p;
451 }
452 
ibutt(uint64_t & a,uint64_t & b,uint64_t W,uint64_t p,uint64_t pInv)453 void RingMultiplier::ibutt(uint64_t& a, uint64_t& b, uint64_t W, uint64_t p, uint64_t pInv) {
454 	uint64_t T = a < b ? a + p - b : a - b;
455 	a += b;
456 	if (a > p) a -= p;
457 	unsigned __int128 UU = static_cast<unsigned __int128>(T) * W;
458 	uint64_t U0 = static_cast<uint64_t>(UU);
459 	uint64_t U1 = UU >> 64;
460 	uint64_t Q = U0 * pInv;
461 	unsigned __int128 Hx = static_cast<unsigned __int128>(Q) * p;
462 	uint64_t H = Hx >> 64;
463 	b = (U1 < H) ? U1 + p - H : U1 - H;
464 }
465 
idivN(uint64_t & a,uint64_t NScale,uint64_t p,uint64_t pInv)466 void RingMultiplier::idivN(uint64_t& a, uint64_t NScale, uint64_t p, uint64_t pInv) {
467 	unsigned __int128 U = static_cast<unsigned __int128>(a) * NScale;
468 	uint64_t U0 = static_cast<uint64_t>(U);
469 	uint64_t U1 = U >> 64;
470 	uint64_t Q = U0 * pInv;
471 	unsigned __int128 Hx = static_cast<unsigned __int128>(Q) * p;
472 	uint64_t H = Hx >> 64;
473 	a = (U1 < H) ? U1 + p - H : U1 - H;
474 }
475 
invMod(uint64_t x,uint64_t m)476 uint64_t RingMultiplier::invMod(uint64_t x, uint64_t m) {
477 	return powMod(x, m - 2, m);
478 }
479 
powMod(uint64_t x,uint64_t y,uint64_t modulus)480 uint64_t RingMultiplier::powMod(uint64_t x, uint64_t y, uint64_t modulus) {
481 	uint64_t res = 1;
482 	while (y > 0) {
483 		if (y & 1) {
484 			mulMod(res, res, x, modulus);
485 		}
486 		y = y >> 1;
487 		mulMod(x, x, x, modulus);
488 	}
489 	return res;
490 }
491 
inv(uint64_t x)492 uint64_t RingMultiplier::inv(uint64_t x) {
493 	return pow(x, static_cast<uint64_t>(-1));
494 }
495 
pow(uint64_t x,uint64_t y)496 uint64_t RingMultiplier::pow(uint64_t x, uint64_t y) {
497 	uint64_t res = 1;
498 	while (y > 0) {
499 		if (y & 1) {
500 			res *= x;
501 		}
502 		y = y >> 1;
503 		x *= x;
504 	}
505 	return res;
506 }
507 
bitReverse(uint32_t x)508 uint32_t RingMultiplier::bitReverse(uint32_t x) {
509 	x = (((x & 0xaaaaaaaa) >> 1) | ((x & 0x55555555) << 1));
510 	x = (((x & 0xcccccccc) >> 2) | ((x & 0x33333333) << 2));
511 	x = (((x & 0xf0f0f0f0) >> 4) | ((x & 0x0f0f0f0f) << 4));
512 	x = (((x & 0xff00ff00) >> 8) | ((x & 0x00ff00ff) << 8));
513 	return ((x >> 16) | (x << 16));
514 }
515 
findPrimeFactors(vector<uint64_t> & s,uint64_t number)516 void RingMultiplier::findPrimeFactors(vector<uint64_t> &s, uint64_t number) {
517 	while (number % 2 == 0) {
518 		s.push_back(2);
519 		number /= 2;
520 	}
521 	for (uint64_t i = 3; i < sqrt(number); i++) {
522 		while (number % i == 0) {
523 			s.push_back(i);
524 			number /= i;
525 		}
526 	}
527 	if (number > 2) {
528 		s.push_back(number);
529 	}
530 }
531 
findPrimitiveRoot(uint64_t modulus)532 uint64_t RingMultiplier::findPrimitiveRoot(uint64_t modulus) {
533 	vector<uint64_t> s;
534 	uint64_t phi = modulus - 1;
535 	findPrimeFactors(s, phi);
536 	for (uint64_t r = 2; r <= phi; r++) {
537 		bool flag = false;
538 		for (auto it = s.begin(); it != s.end(); it++) {
539 			if (powMod(r, phi / (*it), modulus) == 1) {
540 				flag = true;
541 				break;
542 			}
543 		}
544 		if (flag == false) {
545 			return r;
546 		}
547 	}
548 	return -1;
549 }
550 
findMthRootOfUnity(uint64_t M,uint64_t mod)551 uint64_t RingMultiplier::findMthRootOfUnity(uint64_t M, uint64_t mod) {
552     uint64_t res;
553     res = findPrimitiveRoot(mod);
554     if((mod - 1) % M == 0) {
555         uint64_t factor = (mod - 1) / M;
556         res = powMod(res, factor, mod);
557         return res;
558     }
559     else {
560         return -1;
561     }
562 }
563 
564 
565