1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12
13 #include <helib/PAlgebra.h>
14 #include <helib/hypercube.h>
15 #include <helib/timing.h>
16 #include <helib/range.h>
17
18 #include <NTL/ZZXFactoring.h>
19 #include <NTL/GF2EXFactoring.h>
20 #include <NTL/lzz_pEXFactoring.h>
21 #include <NTL/BasicThreadPool.h>
22
23 #include <algorithm> // defines count(...), min(...)
24 #include <cmath>
25 #include <mutex> // std::mutex, std::unique_lock
26
27 namespace helib {
28
29 // polynomials are sorted lexicographically, with the
30 // constant term being the "most significant"
31
// Lexicographic "less than" on polynomials, with the constant term being the
// most significant coefficient. Declared here so that the less_than overloads
// below can delegate to it before its definition.
template <typename RX>
bool poly_comp(const RX& a, const RX& b);
34
less_than(NTL::GF2 a,NTL::GF2 b)35 bool less_than(NTL::GF2 a, NTL::GF2 b) { return rep(a) < rep(b); }
less_than(NTL::zz_p a,NTL::zz_p b)36 bool less_than(NTL::zz_p a, NTL::zz_p b) { return rep(a) < rep(b); }
37
less_than(const NTL::GF2X & a,const NTL::GF2X & b)38 bool less_than(const NTL::GF2X& a, const NTL::GF2X& b)
39 {
40 return poly_comp(a, b);
41 }
less_than(const NTL::zz_pX & a,const NTL::zz_pX & b)42 bool less_than(const NTL::zz_pX& a, const NTL::zz_pX& b)
43 {
44 return poly_comp(a, b);
45 }
46
less_than(const NTL::GF2E & a,const NTL::GF2E & b)47 bool less_than(const NTL::GF2E& a, const NTL::GF2E& b)
48 {
49 return less_than(rep(a), rep(b));
50 }
less_than(const NTL::zz_pE & a,const NTL::zz_pE & b)51 bool less_than(const NTL::zz_pE& a, const NTL::zz_pE& b)
52 {
53 return less_than(rep(a), rep(b));
54 }
55
less_than(const NTL::GF2EX & a,const NTL::GF2EX & b)56 bool less_than(const NTL::GF2EX& a, const NTL::GF2EX& b)
57 {
58 return poly_comp(a, b);
59 }
less_than(const NTL::zz_pEX & a,const NTL::zz_pEX & b)60 bool less_than(const NTL::zz_pEX& a, const NTL::zz_pEX& b)
61 {
62 return poly_comp(a, b);
63 }
64
// Returns true iff a < b in the lexicographic order that treats the constant
// term as most significant. Coefficients are compared with less_than at the
// first position where they differ. If one polynomial's coefficients are a
// prefix of the other's, the shorter one (lower degree) is smaller.
template <typename RX>
bool poly_comp(const RX& a, const RX& b)
{
  const long lenA = deg(a) + 1;
  const long lenB = deg(b) + 1;

  for (long k = 0; k < lenA && k < lenB; ++k) {
    if (!(coeff(a, k) == coeff(b, k)))
      return less_than(coeff(a, k), coeff(b, k));
  }
  return lenA < lenB;
}
80
operator ==(const PAlgebra & other) const81 bool PAlgebra::operator==(const PAlgebra& other) const
82 {
83 if (m != other.m)
84 return false;
85 if (p != other.p)
86 return false;
87
88 return true;
89 }
90
exponentiate(const std::vector<long> & exps,bool onlySameOrd) const91 long PAlgebra::exponentiate(const std::vector<long>& exps,
92 bool onlySameOrd) const
93 {
94 if (isDryRun())
95 return 1;
96 long t = 1;
97 long n = std::min(exps.size(), gens.size());
98 for (long i = 0; i < n; i++) {
99 if (onlySameOrd && !SameOrd(i))
100 continue;
101 long g = NTL::PowerMod(gens[i], exps[i], m);
102 t = NTL::MulMod(t, g, m);
103 }
104 return t;
105 }
106
printout(std::ostream & out) const107 void PAlgebra::printout(std::ostream& out) const
108 {
109 out << "m = " << m << ", p = " << p;
110 if (isDryRun()) {
111 out << " (dry run)" << std::endl;
112 return;
113 }
114
115 out << ", phi(m) = " << phiM << std::endl;
116 out << " ord(p) = " << ordP << std::endl;
117 out << " normBnd = " << normBnd << std::endl;
118 out << " polyNormBnd = " << polyNormBnd << std::endl;
119
120 std::vector<long> facs;
121 factorize(facs, m);
122 out << " factors = " << facs << std::endl;
123
124 for (std::size_t i = 0; i < gens.size(); i++)
125 if (gens[i]) {
126 // FIXME: is it really possible that gens[i] can be 0?
127 // There is very likely some code here and there that
128 // would break if that happens.
129
130 out << " generator " << gens[i] << " has order (";
131 if (FrobPerturb(i) == 0)
132 out << "=";
133 else if (FrobPerturb(i) > 0)
134 out << "!";
135 else
136 out << "!!";
137 out << "= Z_m^*) of ";
138 out << OrderOf(i) << std::endl;
139 }
140
141 if (cube.getSize() < 40) {
142 out << " T = [ ";
143 for (auto const& t : T)
144 out << t << " ";
145 out << "]" << std::endl;
146 }
147 }
148
printAll(std::ostream & out) const149 void PAlgebra::printAll(std::ostream& out) const
150 {
151 printout(out);
152 if (cube.getSize() < 40) {
153 out << " Tidx = [ ";
154 for (const auto& x : Tidx)
155 out << x << " ";
156 out << "]\n";
157 out << " zmsIdx = [ ";
158 for (const auto& x : zmsIdx)
159 out << x << " ";
160 out << "]\n";
161 out << " zmsRep = [ ";
162 for (const auto& x : zmsRep)
163 out << x << " ";
164 out << "]\n";
165 }
166 }
167
// Cotangent of x (in radians), computed as the reciprocal of tan(x).
static double cotan(double x)
{
  const double t = std::tan(x);
  return 1.0 / t;
}
169
half_FFT(long m)170 half_FFT::half_FFT(long m) : fft(m / 2)
171 {
172 typedef std::complex<double> cmplx_t;
173 typedef long double ldbl;
174
175 pow.resize(m / 2);
176 for (long i : range(m / 2)) {
177 // pow[i] = 2^{2*pi*I*(i/m)}
178 ldbl angle = -((2.0L * PI) * (ldbl(i) / ldbl(m)));
179 pow[i] = cmplx_t(std::cos(angle), std::sin(angle));
180 }
181 }
182
quarter_FFT(long m)183 quarter_FFT::quarter_FFT(long m) : fft(m / 4)
184 {
185 typedef std::complex<double> cmplx_t;
186 typedef long double ldbl;
187
188 pow1.resize(m / 4);
189 pow2.resize(m / 4);
190 for (long i : range(m / 2)) {
191 // pow[i] = 2^{2*pi*I*(i/m)}
192 ldbl angle = -((2.0L * PI) * (ldbl(i) / ldbl(m)));
193 if (i % 2)
194 pow1[i >> 1] = cmplx_t(std::cos(angle), std::sin(angle));
195 else
196 pow2[i >> 1] = cmplx_t(std::cos(angle), std::sin(angle));
197 }
198 }
199
// Plain complex product (x+iy)(u+iv) = (xu - yv) + i(xv + yu), without the
// NaN/infinity recovery that std::complex's operator* may perform.
static inline std::complex<double> MUL(std::complex<double> a,
                                       std::complex<double> b)
{
  const double re = a.real() * b.real() - a.imag() * b.imag();
  const double im = a.real() * b.imag() + a.imag() * b.real();
  return {re, im};
}
206
// Modulus of a complex number as sqrt(re^2 + im^2), with no overflow-avoiding
// scaling (unlike std::abs/std::hypot).
static inline double ABS(std::complex<double> a)
{
  const double sumSq = a.real() * a.real() + a.imag() * a.imag();
  return std::sqrt(sumSq);
}
212
calcPolyNormBnd(long m)213 double calcPolyNormBnd(long m)
214 {
215 assertTrue(m >= 1, "m >= 1");
216
217 typedef std::complex<double> cmplx_t;
218 typedef long double ldbl;
219
220 // first, remove 2's
221 while (m % 2 == 0)
222 m /= 2;
223
224 if (m == 1) {
225 return 1;
226 }
227
228 std::vector<long> fac;
229 factorize(fac, m);
230
231 long radm = 1;
232 for (long p : fac)
233 radm *= p;
234
235 if (fac.size() == 1) {
236 long u = fac[0];
237 return 2.0L * cotan(PI / (2.0L * u)) / u;
238 }
239
240 m = radm;
241
242 long n = phi_N(m);
243
244 NTL::ZZX PhiPoly = Cyclotomic(m);
245
246 std::vector<double> a(n);
247 for (long i : range(n))
248 conv(a[i], PhiPoly[i]);
249 // a does not include the leading coefficient 1
250 // NOTE: according to the Arnold and Monogan paper
251 // (Table 6) the least m such that the coefficients of Phi_m
252 // do not fit in 53-bits is m=43,730,115.
253
254 std::vector<cmplx_t> roots(m);
255 std::vector<cmplx_t> x(n);
256
257 for (long i : range(m)) {
258 ldbl re = std::cos(2.0L * PI * (ldbl(i) / ldbl(m)));
259 ldbl im = std::sin(2.0L * PI * (ldbl(i) / ldbl(m)));
260 roots[i] = cmplx_t(re, im);
261 }
262
263 std::vector<long> res_tab(n);
264
265 long row_num = 0;
266 for (long i : range(1, m)) {
267 if (NTL::GCD(i, m) != 1)
268 continue;
269 x[row_num] = roots[i];
270 res_tab[row_num] = i;
271 row_num++;
272 }
273
274 std::vector<double> dist_tab_vec(2 * m - 1);
275 std::vector<int> dist_exp_tab_vec(2 * m);
276
277 double* dist_tab = &dist_tab_vec[m - 1];
278 int* dist_exp_tab = &dist_exp_tab_vec[m - 1];
279
280 const double sqrt2_inv = 1.0 / std::sqrt(ldbl(2));
281 constexpr long FREXP_ITER = 1600;
282
283 for (long i : range(1, m)) {
284 dist_tab[i] = std::frexp(double(2.0L * std::sin(PI * (ldbl(i) / ldbl(m)))),
285 &dist_exp_tab[i]);
286
287 if (dist_tab[i] < sqrt2_inv) {
288 dist_tab[i] *= 2.0;
289 dist_exp_tab[i]--;
290 }
291
292 dist_tab[-i] = dist_tab[i];
293 dist_exp_tab[-i] = dist_exp_tab[i];
294 }
295
296 dist_tab[0] = 1;
297 dist_exp_tab[0] = 0;
298
299 std::vector<double> global_norm_col(n);
300 for (long i : range(n))
301 global_norm_col[i] = 0;
302 std::mutex global_norm_col_mutex;
303
304 NTL_EXEC_RANGE(n, first, last)
305
306 std::vector<double> norm_col(n);
307 for (long i : range(n))
308 norm_col[i] = 0;
309
310 long j = first;
311
312 for (; j <= last - 2; j += 2) {
313 // process columns j and j+1 of inverse matrix
314 // NOTE: processing columns two at a time gives an almost 2x speedup
315
316 long res_j = res_tab[j];
317 long res_j_1 = res_tab[j + 1];
318
319 double prod = 1;
320 double prod_1 = 1;
321 long e_total = 0;
322 long e_total_1 = 0;
323 int e;
324 int e_1;
325
326 {
327 long i = 0;
328 while (i <= n - FREXP_ITER) {
329 for (long k = 0; k < FREXP_ITER; k++) {
330 long res_i = res_tab[i + k];
331 prod *= dist_tab[res_i - res_j];
332 prod_1 *= dist_tab[res_i - res_j_1];
333 e_total += dist_exp_tab[res_i - res_j];
334 e_total_1 += dist_exp_tab[res_i - res_j_1];
335 }
336 prod = std::frexp(prod, &e);
337 prod_1 = std::frexp(prod_1, &e_1);
338 e_total += e;
339 e_total_1 += e_1;
340
341 i += FREXP_ITER;
342 }
343 while (i < n) {
344 long res_i = res_tab[i];
345 prod *= dist_tab[res_i - res_j];
346 prod_1 *= dist_tab[res_i - res_j_1];
347 e_total += dist_exp_tab[res_i - res_j];
348 e_total_1 += dist_exp_tab[res_i - res_j_1];
349 i++;
350 }
351 }
352
353 prod = std::ldexp(prod, e_total);
354 prod_1 = std::ldexp(prod_1, e_total_1);
355
356 double inv_prod = 1.0 / prod;
357 double inv_prod_1 = 1.0 / prod_1;
358
359 cmplx_t xj = x[j];
360 cmplx_t xj_1 = x[j + 1];
361 cmplx_t q = 1;
362 cmplx_t q_1 = 1;
363
364 norm_col[0] += (inv_prod + inv_prod_1);
365
366 for (long i : range(1, n)) {
367 q = MUL(q, xj) + a[n - i];
368 q_1 = MUL(q_1, xj_1) + a[n - i];
369 norm_col[i] += (ABS(q) * inv_prod + ABS(q_1) * inv_prod_1);
370 }
371 }
372
373 if (j == last - 1) {
374 // process column j of inverse matrix
375
376 long res_j = res_tab[j];
377
378 double prod = 1;
379 long e_total = 0;
380 int e;
381
382 {
383 long i = 0;
384 while (i <= n - FREXP_ITER) {
385 for (long k = 0; k < FREXP_ITER; k++) {
386 long res_i = res_tab[i + k];
387 prod *= dist_tab[res_i - res_j];
388 e_total += dist_exp_tab[res_i - res_j];
389 }
390 prod = std::frexp(prod, &e);
391 e_total += e;
392 i += FREXP_ITER;
393 }
394 while (i < n) {
395 long res_i = res_tab[i];
396 prod *= dist_tab[res_i - res_j];
397 e_total += dist_exp_tab[res_i - res_j];
398 i++;
399 }
400 }
401
402 prod = std::ldexp(prod, e_total);
403
404 double inv_prod = 1.0 / prod;
405
406 cmplx_t xj = x[j];
407 cmplx_t q = 1;
408 norm_col[0] += inv_prod;
409 for (long i : range(1, n)) {
410 q = MUL(q, xj) + a[n - i];
411 norm_col[i] += ABS(q) * inv_prod;
412 }
413 }
414
415 std::lock_guard<std::mutex> guard(global_norm_col_mutex);
416
417 for (long i : range(n))
418 global_norm_col[i] += norm_col[i];
419
420 NTL_EXEC_INDEX_END
421
422 double max_norm = 0;
423 for (long i : range(n)) {
424 if (max_norm < global_norm_col[i])
425 max_norm = global_norm_col[i];
426 }
427
428 return max_norm;
429 }
430
// Builds the structure of Z_m^* / <p> for m = mm and plaintext prime pp.
// pp == -1 selects the complex (CKKS) plaintext space; the member p keeps
// the value -1 and the local pp is replaced by m - 1.
// If _gens and _ords are both non-empty, have the same length, and this is
// not a dry run, they are used as the generators and orders. Otherwise
// _gens is passed to findGenerators as candidates.
//
// Throws InvalidArgument in these cases:
//  - mm is not in [2, NTL_SP_BOUND);
//  - pp is neither -1 nor a (probable) prime;
//  - pp divides mm;
//  - pp == -1 and mm is not a power of two.
PAlgebra::PAlgebra(long mm,
                   long pp,
                   const std::vector<long>& _gens,
                   const std::vector<long>& _ords) :
    m(mm), p(pp), cM(1.0) // default value for the ring constant
{
  assertInRange<InvalidArgument>(mm,
                                 2l,
                                 NTL_SP_BOUND,
                                 "mm is not in [2, NTL_SP_BOUND)");
  if (pp == -1) // pp==-1 signals using the complex field for plaintext
    pp = m - 1; // i.e. -1 mod m; the member p itself stays -1
  else {
    assertTrue<InvalidArgument>((bool)NTL::ProbPrime(pp),
                                "Modulus pp is not prime (nor -1)");
    assertNeq<InvalidArgument>(mm % pp, 0l, "Modulus pp divides mm");
  }

  // pow2 = k when mm == 2^k, otherwise 0 (CKKS rejects non-powers of two)
  long k = NTL::NextPowerOfTwo(mm);
  if (static_cast<unsigned long>(mm) == (1UL << k)) // m is a power of two
    pow2 = k;
  else if (p != -1) // is not power of two, set to zero (even if m is even!)
    pow2 = 0;
  else // CKKS requires m to be a power of two. Throw if not.
    throw InvalidArgument("CKKS scheme only supports m as a power of two.");

  // For dry-run, use a tiny m value for the PAlgebra tables.
  // NOTE(review): below, some computations use the member m (shrunk here in
  // dry-run) and others use the argument mm; confirm this mix is intended
  // for dry-run mode.
  if (isDryRun())
    m = (p == 3) ? 4 : 3;

  // Compute the generators for (Z/mZ)^* (defined in NumbTh.cpp)

  std::vector<long> tmpOrds;
  if (_gens.size() > 0 && _gens.size() == _ords.size() && !isDryRun()) {
    // externally supplied generator,orders
    tmpOrds = _ords;
    this->gens = _gens;
    this->ordP = multOrd(pp, mm);
  } else
    // treat externally supplied generators (if any) as candidates
    this->ordP = findGenerators(this->gens, tmpOrds, mm, pp, _gens);

  // Record for each generator gi whether it has the same order in
  // ZM* as in Zm* /(p,g1,...,g_{i-1})

  resize(native, lsize(tmpOrds));
  resize(frob_perturb, lsize(tmpOrds));
  // p_subgp[x] = e if x == p^e mod m for some e (e = 0 for x = 1),
  // and -1 if x is not in the subgroup generated by p.
  std::vector<long> p_subgp(mm);
  for (long i : range(mm))
    p_subgp[i] = -1;
  long pmodm = pp % mm;
  p_subgp[1] = 0;
  for (long i = 1, p2i = pmodm; p2i != 1; i++, p2i = NTL::MulMod(p2i, pmodm, m))
    p_subgp[p2i] = i;
  for (long j : range(tmpOrds.size())) {
    tmpOrds[j] = std::abs(tmpOrds[j]);
    // for backward compatibility, a user supplied
    // ords value could be negative, but we ignore that here.
    // For testing and debugging, we may want to not ignore this...

    // i = g_j^{ord_j} mod m
    long i = NTL::PowerMod(this->gens[j], tmpOrds[j], m);

    native[j] = (i == 1);         // "native" iff g_j^{ord_j} == 1 mod m
    frob_perturb[j] = p_subgp[i]; // e with g_j^{ord_j} == p^e, or -1
  }

  cube.initSignature(tmpOrds); // set hypercube with these dimensions

  phiM = ordP * getNSlots();

  // factors presumably holds (prime, exponent) pairs of mm
  NTL::Vec<NTL::Pair<long, long>> factors;
  factorize(factors, mm);
  nfactors = factors.length();

  // radm = product of the primes dividing mm
  radm = 1;
  for (long i : range(nfactors))
    radm *= factors[i].a;

  // normBnd = prod over primes u | mm of 2*cot(pi/(2u))/u
  normBnd = 1.0;
  for (long i : range(nfactors)) {
    long u = factors[i].a;
    normBnd *= 2.0L * cotan(PI / (2.0L * u)) / u;
  }

  polyNormBnd = calcPolyNormBnd(mm);

  // Allocate space for the various arrays
  resize(T, getNSlots());
  Tidx.assign(mm, -1);   // allocate m slots, initialize them to -1
  zmsIdx.assign(mm, -1); // allocate m slots, initialize them to -1
  resize(zmsRep, phiM);
  // zmsIdx[i] = rank of the unit i among the units of Z_mm (in increasing
  // order); zmsRep is the inverse map from rank to unit.
  long i, idx;
  for (i = idx = 0; i < mm; i++) {
    if (NTL::GCD(i, mm) == 1) {
      zmsIdx[i] = idx++;
      zmsRep[zmsIdx[i]] = i;
    }
  }

  // Now fill the Tidx translation table. We identify an element t \in T
  // with its representation t = \prod_{i=0}^n gi^{ei} mod m (where the
  // gi's are the generators in gens[]) , represent t by the vector of
  // exponents *in reverse order* (en,...,e1,e0), and order these vectors
  // in lexicographic order.

  // FIXME: is the comment above about reverse order true?
  // It doesn't seem like it to me, VJS.
  // The comment about reverse order is correct, SH.

  // buffer is initialized to all-zero, which represents 1=\prod_i gi^0
  std::vector<long> buffer(gens.size()); // temporary holds exponents
  i = idx = 0;
  long ctr = 0;
  do {
    ctr++;
    long t = exponentiate(buffer);

    // sanity check for user-supplied gens
    assertEq(NTL::GCD(t, mm), 1l, "Bad user-supplied generator");
    assertEq(Tidx[t], -1l, "Slot at index t has already been assigned");

    T[i] = t;      // The i'th element in T is t
    Tidx[t] = i++; // the index of t in T is i

    // increment buffer by one (in lexicographic order)
  } while (nextExpVector(buffer)); // until we cover all the group

  // sanity check for user-supplied gens
  assertEq(ctr, getNSlots(), "Bad user-supplied generator set");

  PhimX = Cyclotomic(mm); // compute and store Phi_m(X)
  // pp_factorize(mFactors,mm); // prime-power factorization from NumbTh.cpp

  // FFT helpers: half_FFT for even m, PGFFT for odd m,
  // and additionally quarter_FFT when 4 | m.
  if (mm % 2 == 0)
    half_fftInfo = std::make_shared<half_FFT>(mm);
  else
    fftInfo = std::make_shared<PGFFT>(mm);

  // fftInfo = std::make_shared<PGFFT>(mm); // Need this for some
  // debugging/timing

  if (mm % 4 == 0)
    quarter_fftInfo = std::make_shared<quarter_FFT>(mm);
}
575
comparePAlgebra(const PAlgebra & palg,unsigned long m,unsigned long p,UNUSED unsigned long r,const std::vector<long> & gens,const std::vector<long> & ords)576 bool comparePAlgebra(const PAlgebra& palg,
577 unsigned long m,
578 unsigned long p,
579 UNUSED unsigned long r,
580 const std::vector<long>& gens,
581 const std::vector<long>& ords)
582 {
583 if (static_cast<unsigned long>(palg.getM()) != m ||
584 static_cast<unsigned long>(palg.getP()) != p ||
585 static_cast<std::size_t>(palg.numOfGens()) != gens.size() ||
586 static_cast<std::size_t>(palg.numOfGens()) != ords.size())
587 return false;
588
589 for (long i = 0; i < (long)gens.size(); i++) {
590 if (long(palg.ZmStarGen(i)) != gens[i])
591 return false;
592
593 if ((palg.SameOrd(i) && palg.OrderOf(i) != ords[i]) ||
594 (!palg.SameOrd(i) && palg.OrderOf(i) != -ords[i]))
595 return false;
596 }
597 return true;
598 }
599
frobeniusPow(long j) const600 long PAlgebra::frobeniusPow(long j) const
601 {
602 return NTL::PowerMod(mcMod(p, m), j, m);
603 // Don't forget to reduce p mod m!!
604 }
605
genToPow(long i,long j) const606 long PAlgebra::genToPow(long i, long j) const
607 {
608 long sz = gens.size();
609
610 if (i == sz) {
611 assertTrue(j == 0, "PAlgebra::genToPow: i == sz but j != 0");
612 return 1;
613 }
614
615 assertTrue(i >= -1 && i < LONG(gens.size()), "PAlgebra::genToPow: bad dim");
616
617 long res;
618 if (i == -1)
619 res = frobeniusPow(j);
620 else
621 res = NTL::PowerMod(gens[i], j, m);
622
623 return res;
624 }
625
626 /***********************************************************************
627
628 PAlgebraMod stuff....
629
630 ************************************************************************/
631
buildPAlgebraMod(const PAlgebra & zMStar,long r)632 PAlgebraModBase* buildPAlgebraMod(const PAlgebra& zMStar, long r)
633 {
634 long p = zMStar.getP();
635
636 if (p == -1) // complex plaintext space
637 return new PAlgebraModCx(zMStar, r);
638
639 assertTrue<InvalidArgument>(p >= 2,
640 "Modulus p is less than 2 (nor -1 for CKKS)");
641 assertTrue<InvalidArgument>(r > 0, "Hensel lifting r is less than 1");
642 if (p == 2 && r == 1)
643 return new PAlgebraModDerived<PA_GF2>(zMStar, r);
644 else
645 return new PAlgebraModDerived<PA_zz_p>(zMStar, r);
646 }
647
// Lifts the mod-p factors lfactors of phimx to mod p^r, returning the lifted
// factors and their CRT coefficients in factors and crtc. Declared here
// because the PAlgebraModDerived constructor calls it; the generic version
// and the NTL::vec_zz_pX specialization are defined further below.
template <typename T>
void PAlgebraLift(const NTL::ZZX& phimx,
                  const T& lfactors,
                  T& factors,
                  T& crtc,
                  long r);
654
655 // Missing NTL functionality
656
EDF(NTL::vec_zz_pX & v,const NTL::zz_pX & f,long d)657 void EDF(NTL::vec_zz_pX& v, const NTL::zz_pX& f, long d)
658 {
659 EDF(v, f, PowerXMod(NTL::zz_p::modulus(), f), d);
660 }
661
// Returns X^q mod F, where q = zz_pE::cardinality() is the size of the
// current extension field.
NTL::zz_pEX FrobeniusMap(const NTL::zz_pEXModulus& F)
{
  const NTL::ZZ& q = NTL::zz_pE::cardinality();
  return PowerXMod(q, F);
}
666
// Builds the mod-p^r view of zMStar.
// It factors Phi_m(X) mod p into nSlots factors of degree ordP. Factor 0 is
// the lexicographically smallest. Factor i (i >= 1) is the minimal
// polynomial of X^{1/t} mod factor 0, with t = zMStar.ith_rep(i).
// When r > 1 the factors are lifted to mod p^r. Finally the CRT
// coefficients are computed and genCrtTable()/genMaskTable() are called.
// Throws InvalidArgument if r < 1; asserts that p^r fits in a single-precision long.
template <typename type>
PAlgebraModDerived<type>::PAlgebraModDerived(const PAlgebra& _zMStar, long _r) :
    zMStar(_zMStar), r(_r)

{
  long p = zMStar.getP();
  long m = zMStar.getM();

  // For dry-run, use a tiny m value for the PAlgebra tables
  if (isDryRun())
    m = (p == 3) ? 4 : 3;

  assertTrue<InvalidArgument>(r > 0l, "Hensel lifting r is less than 1");

  // p^r must fit in a single-precision long
  NTL::ZZ BigPPowR = NTL::power_ZZ(p, r);
  assertTrue((bool)BigPPowR.SinglePrecision(),
             "BigPPowR is not SinglePrecision");
  pPowR = to_long(BigPPowR);

  long nSlots = zMStar.getNSlots();

  // Save the caller's modulus, then work mod p.
  // NOTE(review): assumes RBak restores the saved modulus when it goes out
  // of scope (NTL *Bak semantics); confirm RBak's definition.
  RBak bak;
  bak.save();
  SetModulus(p);

  // Compute the factors Ft of Phi_m(X) mod p, for all t \in T

  RX phimxmod;

  conv(phimxmod, zMStar.getPhimX()); // Phi_m(X) mod p

  vec_RX localFactors;

  EDF(localFactors, phimxmod, zMStar.getOrdP()); // equal-degree factorization

  RX* first = &localFactors[0];
  RX* last = first + lsize(localFactors);
  RX* smallest =
      std::min_element(first,
                       last,
                       static_cast<bool (*)(const RX&, const RX&)>(less_than));
  swap(*first, *smallest);

  // We make the lexicographically smallest factor have index 0.
  // The remaining factors are ordered according to their representatives.

  RXModulus F1(localFactors[0]);
  for (long i = 1; i < nSlots; i++) {
    long t = zMStar.ith_rep(i);      // Ft is minimal poly of x^{1/t} mod F1
    long tInv = NTL::InvMod(t, m);   // tInv = t^{-1} mod m
    RX X2tInv = PowerXMod(tInv, F1); // X2tInv = X^{1/t} mod F1
    NTL::IrredPolyMod(localFactors[i], X2tInv, F1);
    // IrredPolyMod(X,P,Q) returns in X the minimal polynomial of P mod Q
  }
  /* Debugging sanity-check #1: we should have Ft= GCD(F1(X^t),Phi_m(X))
     for (i=1; i<nSlots; i++) {
       long t = T[i];
       RX X2t = PowerXMod(t,phimxmod);  // X2t = X^t mod Phi_m(X)
       RX Ft = GCD(CompMod(F1,X2t,phimxmod),phimxmod);
       if (Ft != localFactors[i]) {
         cout << "Ft != F1(X^t) mod Phi_m(X), t=" << t << endl;
         exit(0);
       }
     }*******************************************************************/

  if (r == 1) {
    // No lifting needed: everything lives mod p.
    build(PhimXMod, phimxmod);
    factors = localFactors;
    pPowRContext.save();

    // Compute the CRT coefficients for the Ft's
    resize(crtCoeffs, nSlots);
    for (long i = 0; i < nSlots; i++) {
      RX te = phimxmod / factors[i];        // \prod_{j\ne i} Fj
      te %= factors[i];                     // \prod_{j\ne i} Fj mod Fi
      InvMod(crtCoeffs[i], te, factors[i]); // \prod_{j\ne i} Fj^{-1} mod Fi
    }
  } else {
    // Hensel-lift factors and CRT coefficients to mod p^r; PAlgebraLift
    // leaves the current modulus at p^r, which is what is saved below.
    PAlgebraLift(zMStar.getPhimX(), localFactors, factors, crtCoeffs, r);
    RX phimxmod1;
    conv(phimxmod1, zMStar.getPhimX());
    build(PhimXMod, phimxmod1);
    pPowRContext.save();
  }

  // set factorsOverZZ
  resize(factorsOverZZ, nSlots);
  for (long i = 0; i < nSlots; i++)
    conv(factorsOverZZ[i], factors[i]);

  genCrtTable();
  genMaskTable();
}
760
// Computes S = F^{-1} mod G over Z_{p^r} via Hensel lifting.
// First it solves s*f + t*g = 1 mod p, then lifts (s, t) one power of p at a
// time up to p^r.
// Assumes the current zz_p modulus is p^r on entry. It is temporarily
// switched to p and restored before S is formed. F must be invertible
// modulo G mod p.
// Fails an assertion if an intermediate identity does not hold or if the
// final check (S * F) % G == 1 fails.
void InvModpr(NTL::zz_pX& S,
              const NTL::zz_pX& F,
              const NTL::zz_pX& G,
              long p,
              long r)
{
  NTL::ZZX ff, gg, ss, tt;

  // Integer lifts of F and G, taken while the modulus is still p^r.
  ff = to_ZZX(F);
  gg = to_ZZX(G);

  NTL::zz_pBak bak;
  bak.save();
  NTL::zz_p::init(p);

  // Base case mod p: s = f^{-1} mod g and t = (1 - s*f) / g.
  NTL::zz_pX f, g, s, t;
  f = to_zz_pX(ff);
  g = to_zz_pX(gg);
  s = InvMod(f, g);
  t = (1 - s * f) / g;
  assertTrue(static_cast<bool>(s * f + t * g == 1l),
             "Arithmetic error during Hensel lifting");
  ss = to_ZZX(s);
  tt = to_ZZX(t);

  NTL::ZZ pk = NTL::to_ZZ(1);

  for (long k = 1; k < r; k++) {
    // lift from p^k to p^{k+1}
    pk = pk * p;

    assertTrue((bool)divide(ss * ff + tt * gg - 1, pk),
               "Arithmetic error during Hensel lifting");

    // d = (1 - (ss*ff + tt*gg)) / p^k reduced mod p. The correction
    // (s1, t1) satisfies s1*f + t1*g == d mod p; the division by g is exact
    // because s1*f == s*d*f == d mod g.
    NTL::zz_pX d = to_zz_pX((1 - (ss * ff + tt * gg)) / pk);
    NTL::zz_pX s1, t1;
    s1 = (s * d) % g;
    t1 = (d - s1 * f) / g;
    ss = ss + pk * to_ZZX(s1);
    tt = tt + pk * to_ZZX(t1);
  }

  // Back to the caller's modulus (p^r) before reducing ss.
  bak.restore();

  S = to_zz_pX(ss);

  assertTrue(static_cast<bool>((S * F) % G == 1),
             "Hensel lifting failed to find solutions");
}
812
813 // FIXME: Consider changing this function to something non-templated.
814 #pragma GCC diagnostic push
815 #pragma GCC diagnostic ignored "-Wunused-parameter"
816 template <typename T>
PAlgebraLift(const NTL::ZZX & phimx,const T & lfactors,T & factors,T & crtc,long r)817 void PAlgebraLift(const NTL::ZZX& phimx,
818 const T& lfactors,
819 T& factors,
820 T& crtc,
821 long r)
822 {
823 throw LogicError("Uninstantiated version of PAlgebraLift");
824 }
825 #pragma GCC diagnostic pop
826
// This specialized version of PAlgebraLift does the Hensel
// lifting needed to finish off the initialization.
// It assumes the zz_p modulus is initialized to p
// when called, and leaves it set to p^r.
// On return:
//  - factors[i] is a factor of phimx lifted to mod p^r (presumably in the
//    same order as lfactors, via NTL's MultiLift);
//  - crtc[i] = ((phimx / factors[i]) mod factors[i])^{-1} mod factors[i],
//    computed mod p^r with InvModpr.

template <>
void PAlgebraLift(const NTL::ZZX& phimx,
                  const NTL::vec_zz_pX& lfactors,
                  NTL::vec_zz_pX& factors,
                  NTL::vec_zz_pX& crtc,
                  long r)
{
  long p = NTL::zz_p::modulus();
  long nSlots = lsize(lfactors);

  NTL::vec_ZZX vzz; // need to go via ZZX

  // lift the factors of Phi_m(X) from mod-p to mod-p^r
  if (lsize(lfactors) > 1)
    MultiLift(vzz, lfactors, phimx, r); // defined in NTL::ZZXFactoring
  else {
    // a single factor is phimx itself
    resize(vzz, 1);
    vzz[0] = phimx;
  }

  // Compute the zz_pContext object for mod p^r arithmetic
  NTL::zz_p::init(NTL::power_long(p, r));

  NTL::zz_pX phimxmod = to_zz_pX(phimx);
  resize(factors, nSlots);
  for (long i = 0; i < nSlots; i++) // Convert from ZZX to zz_pX
    conv(factors[i], vzz[i]);

  // Finally compute the CRT coefficients for the factors
  resize(crtc, nSlots);
  for (long i = 0; i < nSlots; i++) {
    NTL::zz_pX& fct = factors[i];
    NTL::zz_pX te = phimxmod / fct;   // \prod_{j\ne i} Fj
    te %= fct;                        // \prod_{j\ne i} Fj mod Fi
    InvModpr(crtc[i], te, fct, p, r); // \prod_{j\ne i} Fj^{-1} mod Fi
  }
}
869
870 // Returns a vector crt[] such that crt[i] = p mod Ft (with t = T[i])
871 template <typename type>
CRT_decompose(std::vector<RX> & crt,const RX & H) const872 void PAlgebraModDerived<type>::CRT_decompose(std::vector<RX>& crt,
873 const RX& H) const
874 {
875 long nSlots = zMStar.getNSlots();
876
877 if (isDryRun()) {
878 crt.clear();
879 return;
880 }
881 resize(crt, nSlots);
882 for (long i = 0; i < nSlots; i++)
883 rem(crt[i], H, factors[i]); // crt[i] = H % factors[i]
884 }
885
886 template <typename type>
embedInAllSlots(RX & H,const RX & alpha,const MappingData<type> & mappingData) const887 void PAlgebraModDerived<type>::embedInAllSlots(
888 RX& H,
889 const RX& alpha,
890 const MappingData<type>& mappingData) const
891 {
892 if (isDryRun()) {
893 H = RX::zero();
894 return;
895 }
896 HELIB_TIMER_START;
897 long nSlots = zMStar.getNSlots();
898
899 std::vector<RX> crt(nSlots); // allocate space for CRT components
900
901 // The i'th CRT component is (H mod F_t) = alpha(maps[i]) mod F_t,
902 // where with t=T[i].
903
904 if (IsX(mappingData.G) || deg(alpha) <= 0) {
905 // special case...no need for CompMod, which is
906 // is not optimized for this case
907
908 for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
909 crt[i] = ConstTerm(alpha);
910 } else {
911 // general case...
912
913 // FIXME: should update this to use matrix_maps, but this routine
914 // isn't actually used anywhere
915
916 for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
917 CompMod(crt[i], alpha, mappingData.maps[i], factors[i]);
918 }
919
920 CRT_reconstruct(H, crt); // interpolate to get H
921 HELIB_TIMER_STOP;
922 }
923
924 template <typename type>
embedInSlots(RX & H,const std::vector<RX> & alphas,const MappingData<type> & mappingData) const925 void PAlgebraModDerived<type>::embedInSlots(
926 RX& H,
927 const std::vector<RX>& alphas,
928 const MappingData<type>& mappingData) const
929 {
930 if (isDryRun()) {
931 H = RX::zero();
932 return;
933 }
934 HELIB_TIMER_START;
935
936 long nSlots = zMStar.getNSlots();
937 // assert(lsize(alphas) == nSlots);
938 assertEq(
939 lsize(alphas),
940 nSlots,
941 "Cannot embed in slots: alphas size is different than number of slots");
942
943 long d = mappingData.degG;
944 for (long i = 0; i < nSlots; i++)
945 assertTrue(deg(alphas[i]) < d,
946 "Bad alpha element at index i: its degree is greater or "
947 "equal than mappingData.degG");
948
949 std::vector<RX> crt(nSlots); // allocate space for CRT components
950
951 // The i'th CRT component is (H mod F_t) = alphas[i](maps[i]) mod F_t,
952 // where with t=T[i].
953
954 if (IsX(mappingData.G)) {
955 // special case...no need for CompMod, which is
956 // is not optimized for this case
957
958 for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
959 crt[i] = ConstTerm(alphas[i]);
960 } else {
961 // general case...still try to avoid CompMod when possible,
962 // which is the common case for encoding masks
963
964 HELIB_NTIMER_START(CompMod);
965
966 #if 0
967 for (long i: range(nSlots)) {
968 if (deg(alphas[i]) <= 0)
969 crt[i] = alphas[i];
970 else
971 CompMod(crt[i], alphas[i], mappingData.maps[i], factors[i]);
972 }
973 #else
974 vec_R in, out;
975
976 for (long i : range(nSlots)) {
977 if (deg(alphas[i]) <= 0)
978 crt[i] = alphas[i];
979 else {
980 VectorCopy(in, alphas[i], d);
981 mul(out, in, mappingData.matrix_maps[i]);
982 conv(crt[i], out);
983 }
984 }
985 #endif
986 }
987
988 CRT_reconstruct(H, crt); // interpolate to get p
989
990 HELIB_TIMER_STOP;
991 }
992
993 template <typename type>
CRT_reconstruct(RX & H,std::vector<RX> & crt) const994 void PAlgebraModDerived<type>::CRT_reconstruct(RX& H,
995 std::vector<RX>& crt) const
996 {
997 if (isDryRun()) {
998 H = RX::zero();
999 return;
1000 }
1001 HELIB_TIMER_START;
1002 long nslots = zMStar.getNSlots();
1003
1004 const std::vector<RX>& ctab = crtTable;
1005
1006 clear(H);
1007 RX tmp1, tmp2;
1008
1009 bool easy = true;
1010 for (long i = 0; i < nslots; i++)
1011 if (!IsZero(crt[i]) && !IsOne(crt[i])) {
1012 easy = false;
1013 break;
1014 }
1015
1016 if (easy) {
1017 // VJS-FIXME: this looks easy, but could
1018 // be slower asymptotically
1019 for (long i = 0; i < nslots; i++)
1020 if (!IsZero(crt[i]))
1021 H += ctab[i];
1022 } else {
1023 std::vector<RX> crt1;
1024 resize(crt1, nslots);
1025 for (long i = 0; i < nslots; i++)
1026 MulMod(crt1[i], crt[i], crtCoeffs[i], factors[i]);
1027
1028 evalTree(H, crtTree, crt1, 0, nslots);
1029 }
1030 HELIB_TIMER_STOP;
1031 }
1032
// Computes in w a root of G in the extension Z_p[X]/F_t(X), where F_t is
// factors[i] and i = zMStar.indexOfRep(t). w is cleared if t is not a
// representative (indexOfRep(t) < 0), and is zero in dry-run mode.
//
// If rF1 is null, the root is computed directly:
//  - if G == F_t, w = X;
//  - if deg(G) == 1, w = -G(0) (the root when G is monic);
//  - otherwise (requires r == 1) all roots of G over Z_p[X]/F_t(X) are
//    found, and the smallest under less_than is chosen as a canonical root.
// If rF1 is given, w = rF1(X^t) mod F_t.
template <typename type>
void PAlgebraModDerived<type>::mapToFt(RX& w,
                                       const RX& G,
                                       long t,
                                       const RX* rF1) const
{
  if (isDryRun()) {
    w = RX::zero();
    return;
  }
  long i = zMStar.indexOfRep(t);
  if (i < 0) {
    clear(w);
    return;
  }

  if (rF1 == nullptr) { // Compute the representation "from scratch"
    // special case
    if (G == factors[i]) {
      SetX(w);
      return;
    }

    // special case
    if (deg(G) == 1) {
      w = -ConstTerm(G);
      return;
    }

    // the general case: currently only works when r == 1
    assertEq(r, 1l, "Bad Hensel lifting value in general case: r is not 1");

    // Save the current extension-field context before switching to F_t.
    REBak bak;
    bak.save();
    RE::init(factors[i]); // work with the extension field GF_p[X]/Ft(X)
    REX Ga;
    conv(Ga, G); // G as a polynomial over the extension field

    vec_RE roots;
    FindRoots(roots, Ga); // Find roots of G in this field
    RE* first = &roots[0];
    RE* last = first + lsize(roots);
    RE* smallest = std::min_element(
        first,
        last,
        static_cast<bool (*)(const RE&, const RE&)>(less_than));
    // make a canonical choice
    w = rep(*smallest);
    return;
  }
  // if rF1 is set, then use it instead, setting w = rF1(X^t) mod Ft(X)
  RXModulus Ft(factors[i]);
  // long tInv = InvMod(t,m);
  RX X2t = PowerXMod(t, Ft);  // X2t = X^t mod Ft
  w = CompMod(*rF1, X2t, Ft); // w = F1(X2t) mod Ft

  /* Debugging sanity-check: G(w)=0 in the extension field (Z/2Z)[X]/Ft(X)
     RE::init(factors[i]);
     REX Ga;
     conv(Ga, G); // G as a polynomial over the extension field
     RE ra;
     conv(ra, w);         // w is an element in the extension field
     eval(ra,Ga,ra);  // ra = Ga(ra)
     if (!IsZero(ra)) {// check that Ga(w)=0 in this extension field
       cout << "rF1(X^t) mod Ft(X) != root of G mod Ft, t=" << t << endl;
       exit(0);
     }*******************************************************************/
}
1101
template <typename type>
// Populate mappingData with everything needed to move between plaintext
// slots represented modulo factors[i] and the base polynomial G.
//
// Preconditions (checked below): deg(G) > 0, deg(G) divides ordP, and G is
// monic. On return:
//   - maps[i]: a polynomial mod factors[i]; maps[0] comes from mapToF1 and
//     maps[i>0] from mapToFt(..., ith_rep(i), &maps[0]).
//   - matrix_maps[i]: a d x ordp matrix whose row j holds the coefficients of
//     maps[i]^j mod factors[i], for j = 0..d-1 (d = deg(G)).
//   - contextForG: the saved RE context for R[X]/G(X).
//   - rmaps[i] (only filled when deg(G) > 1): a polynomial over R[X]/G(X)
//     that decodePlaintext reduces each CRT component by.
// The caller's RE context is restored on return (via REBak).
void PAlgebraModDerived<type>::mapToSlots(MappingData<type>& mappingData,
                                          const RX& G) const
{
  assertTrue<InvalidArgument>(
      deg(G) > 0,
      "Polynomial G is constant (has degree less than one)");
  assertEq(zMStar.getOrdP() % deg(G),
           0l,
           "Degree of polynomial G does not divide zMStar.getOrdP()");
  assertTrue<InvalidArgument>(static_cast<bool>(LeadCoeff(G) == 1l),
                              "Polynomial G is not monic");
  mappingData.G = G;
  mappingData.degG = deg(mappingData.G);
  long d = deg(G);
  long ordp = zMStar.getOrdP();

  long nSlots = zMStar.getNSlots();
  long m = zMStar.getM();

  resize(mappingData.maps, nSlots);

  mapToF1(mappingData.maps[0], mappingData.G); // mapping from base-G to base-F1
  // The remaining maps are derived from maps[0] (passed as rF1).
  for (long i = 1; i < nSlots; i++)
    mapToFt(mappingData.maps[i],
            mappingData.G,
            zMStar.ith_rep(i),
            &(mappingData.maps[0]));

  // create matrices to streamline CompMod operations:
  // row j of matrix_maps[i] = coefficients of maps[i]^j mod factors[i]
  resize(mappingData.matrix_maps, nSlots);
  for (long i : range(nSlots)) {
    mat_R& mat = mappingData.matrix_maps[i];
    mat.SetDims(d, ordp);
    RX pow;
    pow = 1;
    for (long j : range(d)) {
      VectorCopy(mat[j], pow, ordp);
      if (j < d - 1) // skip the unused final power
        MulMod(pow, pow, mappingData.maps[i], factors[i]);
    }
  }

  // Switch to R[X]/G(X) and remember that context for later use; bak
  // restores the caller's context when this function returns.
  REBak bak;
  bak.save();
  RE::init(mappingData.G);
  mappingData.contextForG.save();

  // With a linear G no reverse maps are needed (decodePlaintext returns the
  // CRT components directly in that case).
  if (deg(mappingData.G) == 1)
    return;

  resize(mappingData.rmaps, nSlots);

  if (G == factors[0]) {
    // an important special case: rmaps[i] = Y - (X^{tInv} mod G),
    // where t = ith_rep(i) and tInv = t^{-1} mod m

    for (long i = 0; i < nSlots; i++) {
      long t = zMStar.ith_rep(i);
      long tInv = NTL::InvMod(t, m);

      RX ct_rep;
      PowerXMod(ct_rep, tInv, G);

      RE ct;
      conv(ct, ct_rep);

      REX Qi;
      SetCoeff(Qi, 1, 1);
      SetCoeff(Qi, 0, -ct);

      mappingData.rmaps[i] = Qi;
    }
  } else {
    // the general case: currently only works when r == 1

    assertEq(r, 1l, "Bad Hensel lifting value in general case: r is not 1");

    // Factors of factors[0] over R[X]/G(X); computed once (at i == 0) and
    // reused for every later slot.
    vec_REX FRts;
    for (long i = 0; i < nSlots; i++) {
      // We need to lift Fi from R[Y] to (R[X]/G(X))[Y]
      REX Qi;
      // t is only assigned (and only read) when i > 0.
      long t, tInv = 0;

      if (i == 0) {
        conv(Qi, factors[i]);
        // factor Fi over GF(p)[X]/G(X) into factors of degree
        // deg(Qi) / deg(G) (equal-degree factorization)
        FRts = EDF(Qi, FrobeniusMap(Qi), deg(Qi) / deg(G));
      } else {
        t = zMStar.ith_rep(i);
        tInv = NTL::InvMod(t, m);
      }

      // need to choose the right factor, the one that gives us back X
      long j;
      for (j = 0; j < lsize(FRts); j++) {
        // lift maps[i] to (R[X]/G(X))[Y] and reduce mod j'th factor of Fi

        REX FRtsj;
        if (i == 0)
          FRtsj = FRts[j];
        else {
          // minimal polynomial of Y^{tInv} modulo the j'th factor of F0
          REX X2tInv = PowerXMod(tInv, FRts[j]);
          IrredPolyMod(FRtsj, X2tInv, FRts[j]);
        }

        // FRtsj is the jth factor of factors[i] over the extension field.
        // For j > 0, we save some time by computing it from the jth factor
        // of factors[0] via a minimal polynomial computation.

        REX GRti;
        conv(GRti, mappingData.maps[i]);
        GRti %= FRtsj;

        // Only the constant term is inspected. NOTE(review): this relies on
        // GRti reducing to a constant (a root of G, all of which lie in
        // R[X]/G when G is irreducible) -- confirm.
        if (IsX(rep(ConstTerm(GRti)))) { // is GRti == X?
          Qi = FRtsj; // If so, we found the right factor
          break;
        } // If this does not happen then move to the next factor of Fi
      }

      assertTrue(j < lsize(FRts),
                 "Cannot find the right factor Qi. Loop did not "
                 "terminate before visiting all elements");
      mappingData.rmaps[i] = Qi;
    }
  }
}
1228
1229 template <typename type>
decodePlaintext(std::vector<RX> & alphas,const RX & ptxt,const MappingData<type> & mappingData) const1230 void PAlgebraModDerived<type>::decodePlaintext(
1231 std::vector<RX>& alphas,
1232 const RX& ptxt,
1233 const MappingData<type>& mappingData) const
1234 {
1235 long nSlots = zMStar.getNSlots();
1236 if (isDryRun()) {
1237 alphas.assign(nSlots, RX::zero());
1238 return;
1239 }
1240
1241 // First decompose p into CRT components
1242 std::vector<RX> CRTcomps(nSlots); // allocate space for CRT component
1243 CRT_decompose(CRTcomps, ptxt); // CRTcomps[i] = p mod facors[i]
1244
1245 if (mappingData.degG == 1) {
1246 alphas = CRTcomps;
1247 return;
1248 }
1249
1250 resize(alphas, nSlots);
1251
1252 REBak bak;
1253 bak.save();
1254 mappingData.contextForG.restore();
1255
1256 for (long i = 0; i < nSlots; i++) {
1257 REX te;
1258 conv(te, CRTcomps[i]); // lift i'th CRT component to mod G(X)
1259 te %= mappingData
1260 .rmaps[i]; // reduce CRTcomps[i](Y) mod Qi(Y), over (Z_2[X]/G(X))
1261
1262 // the free term (no Y component) should be our answer (as a poly(X))
1263 alphas[i] = rep(ConstTerm(te));
1264 }
1265 }
1266
1267 template <typename type>
buildLinPolyCoeffs(std::vector<RX> & C,const std::vector<RX> & L,const MappingData<type> & mappingData) const1268 void PAlgebraModDerived<type>::buildLinPolyCoeffs(
1269 std::vector<RX>& C,
1270 const std::vector<RX>& L,
1271 const MappingData<type>& mappingData) const
1272 {
1273 REBak bak;
1274 bak.save();
1275 mappingData.contextForG.restore();
1276
1277 long d = RE::degree();
1278 long p = zMStar.getP();
1279
1280 assertEq(lsize(L), d, "Vector L size is different than RE::degree()");
1281
1282 vec_RE LL;
1283 resize(LL, d);
1284
1285 for (long i = 0; i < d; i++)
1286 conv(LL[i], L[i]);
1287
1288 vec_RE CC;
1289 ::helib::buildLinPolyCoeffs(CC, LL, p, r);
1290
1291 resize(C, d);
1292 for (long i = 0; i < d; i++)
1293 C[i] = rep(CC[i]);
1294 }
1295
// Code for generating mask tables.
// Currently, this is done when the PAlgebraMod
// object is constructed.

// VJS-FIXME: what were we thinking? These tables
// can be huge, so building them eagerly is costly.
1302
1303 template <typename type>
genMaskTable()1304 void PAlgebraModDerived<type>::genMaskTable()
1305 {
1306 // This is only called by the constructor, which has already
1307 // set the zz_p context and the crtTable
1308 resize(maskTable, zMStar.numOfGens());
1309 for (long i = 0; i < (long)zMStar.numOfGens(); i++) {
1310 long ord = zMStar.OrderOf(i);
1311 resize(maskTable[i], ord + 1);
1312 maskTable[i][ord] = 0;
1313 for (long j = ord - 1; j >= 1; j--) {
1314 // initialize mask that is 1 whenever the ith coordinate is at least j
1315 // Note: maskTable[i][0] = constant 1, maskTable[i][ord] = constant 0
1316 maskTable[i][j] = maskTable[i][j + 1];
1317 for (long k = 0; k < (long)zMStar.getNSlots(); k++) {
1318 if (zMStar.coordinate(i, k) == j) {
1319 add(maskTable[i][j], maskTable[i][j], crtTable[k]);
1320 }
1321 }
1322 }
1323 maskTable[i][0] = 1;
1324 }
1325 }
1326
// Code for generating CRT tables.
// Currently, this is done when the PAlgebraMod
// object is constructed.

// VJS-FIXME: what were we thinking? These tables
// can be huge, so building them eagerly is costly.
1333
1334 template <typename type>
genCrtTable()1335 void PAlgebraModDerived<type>::genCrtTable()
1336 {
1337 // This is only called by the constructor, which has already
1338 // set the zz_p context
1339
1340 long nslots = zMStar.getNSlots();
1341 resize(crtTable, nslots);
1342 for (long i = 0; i < nslots; i++) {
1343 RX allBut_i = PhimXMod / factors[i]; // = \prod_{j \ne i }Fj
1344 allBut_i *= crtCoeffs[i]; // = 1 mod Fi and = 0 mod Fj for j \ne i
1345 crtTable[i] = allBut_i;
1346 }
1347
1348 buildTree(crtTree, 0, nslots);
1349 }
1350
1351 template <typename type>
buildTree(std::shared_ptr<TNode<RX>> & res,long offset,long extent) const1352 void PAlgebraModDerived<type>::buildTree(std::shared_ptr<TNode<RX>>& res,
1353 long offset,
1354 long extent) const
1355 {
1356 if (extent == 1)
1357 res = buildTNode<RX>(nullTNode<RX>(), nullTNode<RX>(), factors[offset]);
1358 else {
1359 long half = extent / 2;
1360 std::shared_ptr<TNode<RX>> left, right;
1361 buildTree(left, offset, half);
1362 buildTree(right, offset + half, extent - half);
1363 RX data = left->data * right->data;
1364 res = buildTNode<RX>(left, right, data);
1365 }
1366 }
1367
1368 template <typename type>
evalTree(RX & res,std::shared_ptr<TNode<RX>> tree,const std::vector<RX> & crt1,long offset,long extent) const1369 void PAlgebraModDerived<type>::evalTree(RX& res,
1370 std::shared_ptr<TNode<RX>> tree,
1371 const std::vector<RX>& crt1,
1372 long offset,
1373 long extent) const
1374 {
1375 if (extent == 1)
1376 res = crt1[offset];
1377 else {
1378 long half = extent / 2;
1379 RX lres, rres;
1380 evalTree(lres, tree->left, crt1, offset, half);
1381 evalTree(rres, tree->right, crt1, offset + half, extent - half);
1382 RX tmp1, tmp2;
1383 mul(tmp1, lres, tree->right->data);
1384 mul(tmp2, rres, tree->left->data);
1385 add(tmp1, tmp1, tmp2);
1386 res = tmp1;
1387 }
1388 }
1389
// Explicit instantiation of PAlgebraModDerived for PA_GF2 and PA_zz_p, so the
// template member definitions above are emitted in this translation unit.

template class PAlgebraModDerived<PA_GF2>;
template class PAlgebraModDerived<PA_zz_p>;
1394
1395 } // namespace helib
1396