1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #include <helib/PAlgebra.h>
14 #include <helib/hypercube.h>
15 #include <helib/timing.h>
16 #include <helib/range.h>
17 
18 #include <NTL/ZZXFactoring.h>
19 #include <NTL/GF2EXFactoring.h>
20 #include <NTL/lzz_pEXFactoring.h>
21 #include <NTL/BasicThreadPool.h>
22 
23 #include <algorithm> // defines count(...), min(...)
24 #include <cmath>
25 #include <mutex> // std::mutex, std::unique_lock
26 
27 namespace helib {
28 
29 // polynomials are sorted lexicographically, with the
30 // constant term being the "most significant"
31 
32 template <typename RX>
33 bool poly_comp(const RX& a, const RX& b);
34 
less_than(NTL::GF2 a,NTL::GF2 b)35 bool less_than(NTL::GF2 a, NTL::GF2 b) { return rep(a) < rep(b); }
less_than(NTL::zz_p a,NTL::zz_p b)36 bool less_than(NTL::zz_p a, NTL::zz_p b) { return rep(a) < rep(b); }
37 
less_than(const NTL::GF2X & a,const NTL::GF2X & b)38 bool less_than(const NTL::GF2X& a, const NTL::GF2X& b)
39 {
40   return poly_comp(a, b);
41 }
less_than(const NTL::zz_pX & a,const NTL::zz_pX & b)42 bool less_than(const NTL::zz_pX& a, const NTL::zz_pX& b)
43 {
44   return poly_comp(a, b);
45 }
46 
less_than(const NTL::GF2E & a,const NTL::GF2E & b)47 bool less_than(const NTL::GF2E& a, const NTL::GF2E& b)
48 {
49   return less_than(rep(a), rep(b));
50 }
less_than(const NTL::zz_pE & a,const NTL::zz_pE & b)51 bool less_than(const NTL::zz_pE& a, const NTL::zz_pE& b)
52 {
53   return less_than(rep(a), rep(b));
54 }
55 
less_than(const NTL::GF2EX & a,const NTL::GF2EX & b)56 bool less_than(const NTL::GF2EX& a, const NTL::GF2EX& b)
57 {
58   return poly_comp(a, b);
59 }
less_than(const NTL::zz_pEX & a,const NTL::zz_pEX & b)60 bool less_than(const NTL::zz_pEX& a, const NTL::zz_pEX& b)
61 {
62   return poly_comp(a, b);
63 }
64 
// Strict lexicographic "less than" on polynomials, scanning coefficients from
// the constant term upward (the constant term is the most significant).
// If one coefficient sequence is a prefix of the other, the shorter one
// (lower degree) sorts first. Coefficients are ordered with less_than.
template <typename RX>
bool poly_comp(const RX& a, const RX& b)
{
  // Number of stored coefficients (deg + 1; 0 for the zero polynomial).
  const long lenA = deg(a) + 1;
  const long lenB = deg(b) + 1;
  const long common = std::min(lenA, lenB);

  // Advance to the first position where the coefficients differ.
  long pos = 0;
  for (; pos < common; ++pos) {
    if (!(coeff(a, pos) == coeff(b, pos)))
      break;
  }

  if (pos < common)
    return less_than(coeff(a, pos), coeff(b, pos));

  // One is a prefix of the other: the shorter one is smaller.
  return lenA < lenB;
}
80 
operator ==(const PAlgebra & other) const81 bool PAlgebra::operator==(const PAlgebra& other) const
82 {
83   if (m != other.m)
84     return false;
85   if (p != other.p)
86     return false;
87 
88   return true;
89 }
90 
exponentiate(const std::vector<long> & exps,bool onlySameOrd) const91 long PAlgebra::exponentiate(const std::vector<long>& exps,
92                             bool onlySameOrd) const
93 {
94   if (isDryRun())
95     return 1;
96   long t = 1;
97   long n = std::min(exps.size(), gens.size());
98   for (long i = 0; i < n; i++) {
99     if (onlySameOrd && !SameOrd(i))
100       continue;
101     long g = NTL::PowerMod(gens[i], exps[i], m);
102     t = NTL::MulMod(t, g, m);
103   }
104   return t;
105 }
106 
printout(std::ostream & out) const107 void PAlgebra::printout(std::ostream& out) const
108 {
109   out << "m = " << m << ", p = " << p;
110   if (isDryRun()) {
111     out << " (dry run)" << std::endl;
112     return;
113   }
114 
115   out << ", phi(m) = " << phiM << std::endl;
116   out << "  ord(p) = " << ordP << std::endl;
117   out << "  normBnd = " << normBnd << std::endl;
118   out << "  polyNormBnd = " << polyNormBnd << std::endl;
119 
120   std::vector<long> facs;
121   factorize(facs, m);
122   out << "  factors = " << facs << std::endl;
123 
124   for (std::size_t i = 0; i < gens.size(); i++)
125     if (gens[i]) {
126       // FIXME: is it really possible that gens[i] can be 0?
127       // There is very likely some code here and there that
128       // would break if that happens.
129 
130       out << "  generator " << gens[i] << " has order (";
131       if (FrobPerturb(i) == 0)
132         out << "=";
133       else if (FrobPerturb(i) > 0)
134         out << "!";
135       else
136         out << "!!";
137       out << "= Z_m^*) of ";
138       out << OrderOf(i) << std::endl;
139     }
140 
141   if (cube.getSize() < 40) {
142     out << "  T = [ ";
143     for (auto const& t : T)
144       out << t << " ";
145     out << "]" << std::endl;
146   }
147 }
148 
printAll(std::ostream & out) const149 void PAlgebra::printAll(std::ostream& out) const
150 {
151   printout(out);
152   if (cube.getSize() < 40) {
153     out << "  Tidx = [ ";
154     for (const auto& x : Tidx)
155       out << x << " ";
156     out << "]\n";
157     out << "  zmsIdx = [ ";
158     for (const auto& x : zmsIdx)
159       out << x << " ";
160     out << "]\n";
161     out << "  zmsRep = [ ";
162     for (const auto& x : zmsRep)
163       out << x << " ";
164     out << "]\n";
165   }
166 }
167 
// Cotangent of x (in radians), computed as the reciprocal of tan(x).
static double cotan(double x)
{
  const double tangent = tan(x);
  return 1 / tangent;
}
169 
half_FFT(long m)170 half_FFT::half_FFT(long m) : fft(m / 2)
171 {
172   typedef std::complex<double> cmplx_t;
173   typedef long double ldbl;
174 
175   pow.resize(m / 2);
176   for (long i : range(m / 2)) {
177     // pow[i] = 2^{2*pi*I*(i/m)}
178     ldbl angle = -((2.0L * PI) * (ldbl(i) / ldbl(m)));
179     pow[i] = cmplx_t(std::cos(angle), std::sin(angle));
180   }
181 }
182 
quarter_FFT(long m)183 quarter_FFT::quarter_FFT(long m) : fft(m / 4)
184 {
185   typedef std::complex<double> cmplx_t;
186   typedef long double ldbl;
187 
188   pow1.resize(m / 4);
189   pow2.resize(m / 4);
190   for (long i : range(m / 2)) {
191     // pow[i] = 2^{2*pi*I*(i/m)}
192     ldbl angle = -((2.0L * PI) * (ldbl(i) / ldbl(m)));
193     if (i % 2)
194       pow1[i >> 1] = cmplx_t(std::cos(angle), std::sin(angle));
195     else
196       pow2[i >> 1] = cmplx_t(std::cos(angle), std::sin(angle));
197   }
198 }
199 
// Complex product a*b written out component-wise:
// (x + iy)(u + iv) = (xu - yv) + i(xv + yu).
static inline std::complex<double> MUL(std::complex<double> a,
                                       std::complex<double> b)
{
  const double re = a.real() * b.real() - a.imag() * b.imag();
  const double im = a.real() * b.imag() + a.imag() * b.real();
  return {re, im};
}
206 
// Modulus |a| = sqrt(re^2 + im^2), evaluated directly from the components.
static inline double ABS(std::complex<double> a)
{
  const double re = a.real();
  const double im = a.imag();
  const double squared = re * re + im * im;
  return std::sqrt(squared);
}
212 
calcPolyNormBnd(long m)213 double calcPolyNormBnd(long m)
214 {
215   assertTrue(m >= 1, "m >= 1");
216 
217   typedef std::complex<double> cmplx_t;
218   typedef long double ldbl;
219 
220   // first, remove 2's
221   while (m % 2 == 0)
222     m /= 2;
223 
224   if (m == 1) {
225     return 1;
226   }
227 
228   std::vector<long> fac;
229   factorize(fac, m);
230 
231   long radm = 1;
232   for (long p : fac)
233     radm *= p;
234 
235   if (fac.size() == 1) {
236     long u = fac[0];
237     return 2.0L * cotan(PI / (2.0L * u)) / u;
238   }
239 
240   m = radm;
241 
242   long n = phi_N(m);
243 
244   NTL::ZZX PhiPoly = Cyclotomic(m);
245 
246   std::vector<double> a(n);
247   for (long i : range(n))
248     conv(a[i], PhiPoly[i]);
249   // a does not include the leading coefficient 1
250   // NOTE: according to the Arnold and Monogan paper
251   // (Table 6) the least m such that the coefficients of Phi_m
252   // do not fit in 53-bits is m=43,730,115.
253 
254   std::vector<cmplx_t> roots(m);
255   std::vector<cmplx_t> x(n);
256 
257   for (long i : range(m)) {
258     ldbl re = std::cos(2.0L * PI * (ldbl(i) / ldbl(m)));
259     ldbl im = std::sin(2.0L * PI * (ldbl(i) / ldbl(m)));
260     roots[i] = cmplx_t(re, im);
261   }
262 
263   std::vector<long> res_tab(n);
264 
265   long row_num = 0;
266   for (long i : range(1, m)) {
267     if (NTL::GCD(i, m) != 1)
268       continue;
269     x[row_num] = roots[i];
270     res_tab[row_num] = i;
271     row_num++;
272   }
273 
274   std::vector<double> dist_tab_vec(2 * m - 1);
275   std::vector<int> dist_exp_tab_vec(2 * m);
276 
277   double* dist_tab = &dist_tab_vec[m - 1];
278   int* dist_exp_tab = &dist_exp_tab_vec[m - 1];
279 
280   const double sqrt2_inv = 1.0 / std::sqrt(ldbl(2));
281   constexpr long FREXP_ITER = 1600;
282 
283   for (long i : range(1, m)) {
284     dist_tab[i] = std::frexp(double(2.0L * std::sin(PI * (ldbl(i) / ldbl(m)))),
285                              &dist_exp_tab[i]);
286 
287     if (dist_tab[i] < sqrt2_inv) {
288       dist_tab[i] *= 2.0;
289       dist_exp_tab[i]--;
290     }
291 
292     dist_tab[-i] = dist_tab[i];
293     dist_exp_tab[-i] = dist_exp_tab[i];
294   }
295 
296   dist_tab[0] = 1;
297   dist_exp_tab[0] = 0;
298 
299   std::vector<double> global_norm_col(n);
300   for (long i : range(n))
301     global_norm_col[i] = 0;
302   std::mutex global_norm_col_mutex;
303 
304   NTL_EXEC_RANGE(n, first, last)
305 
306   std::vector<double> norm_col(n);
307   for (long i : range(n))
308     norm_col[i] = 0;
309 
310   long j = first;
311 
312   for (; j <= last - 2; j += 2) {
313     // process columns j and j+1 of inverse matrix
314     // NOTE: processing columns two at a time gives an almost 2x speedup
315 
316     long res_j = res_tab[j];
317     long res_j_1 = res_tab[j + 1];
318 
319     double prod = 1;
320     double prod_1 = 1;
321     long e_total = 0;
322     long e_total_1 = 0;
323     int e;
324     int e_1;
325 
326     {
327       long i = 0;
328       while (i <= n - FREXP_ITER) {
329         for (long k = 0; k < FREXP_ITER; k++) {
330           long res_i = res_tab[i + k];
331           prod *= dist_tab[res_i - res_j];
332           prod_1 *= dist_tab[res_i - res_j_1];
333           e_total += dist_exp_tab[res_i - res_j];
334           e_total_1 += dist_exp_tab[res_i - res_j_1];
335         }
336         prod = std::frexp(prod, &e);
337         prod_1 = std::frexp(prod_1, &e_1);
338         e_total += e;
339         e_total_1 += e_1;
340 
341         i += FREXP_ITER;
342       }
343       while (i < n) {
344         long res_i = res_tab[i];
345         prod *= dist_tab[res_i - res_j];
346         prod_1 *= dist_tab[res_i - res_j_1];
347         e_total += dist_exp_tab[res_i - res_j];
348         e_total_1 += dist_exp_tab[res_i - res_j_1];
349         i++;
350       }
351     }
352 
353     prod = std::ldexp(prod, e_total);
354     prod_1 = std::ldexp(prod_1, e_total_1);
355 
356     double inv_prod = 1.0 / prod;
357     double inv_prod_1 = 1.0 / prod_1;
358 
359     cmplx_t xj = x[j];
360     cmplx_t xj_1 = x[j + 1];
361     cmplx_t q = 1;
362     cmplx_t q_1 = 1;
363 
364     norm_col[0] += (inv_prod + inv_prod_1);
365 
366     for (long i : range(1, n)) {
367       q = MUL(q, xj) + a[n - i];
368       q_1 = MUL(q_1, xj_1) + a[n - i];
369       norm_col[i] += (ABS(q) * inv_prod + ABS(q_1) * inv_prod_1);
370     }
371   }
372 
373   if (j == last - 1) {
374     // process column j of inverse matrix
375 
376     long res_j = res_tab[j];
377 
378     double prod = 1;
379     long e_total = 0;
380     int e;
381 
382     {
383       long i = 0;
384       while (i <= n - FREXP_ITER) {
385         for (long k = 0; k < FREXP_ITER; k++) {
386           long res_i = res_tab[i + k];
387           prod *= dist_tab[res_i - res_j];
388           e_total += dist_exp_tab[res_i - res_j];
389         }
390         prod = std::frexp(prod, &e);
391         e_total += e;
392         i += FREXP_ITER;
393       }
394       while (i < n) {
395         long res_i = res_tab[i];
396         prod *= dist_tab[res_i - res_j];
397         e_total += dist_exp_tab[res_i - res_j];
398         i++;
399       }
400     }
401 
402     prod = std::ldexp(prod, e_total);
403 
404     double inv_prod = 1.0 / prod;
405 
406     cmplx_t xj = x[j];
407     cmplx_t q = 1;
408     norm_col[0] += inv_prod;
409     for (long i : range(1, n)) {
410       q = MUL(q, xj) + a[n - i];
411       norm_col[i] += ABS(q) * inv_prod;
412     }
413   }
414 
415   std::lock_guard<std::mutex> guard(global_norm_col_mutex);
416 
417   for (long i : range(n))
418     global_norm_col[i] += norm_col[i];
419 
420   NTL_EXEC_INDEX_END
421 
422   double max_norm = 0;
423   for (long i : range(n)) {
424     if (max_norm < global_norm_col[i])
425       max_norm = global_norm_col[i];
426   }
427 
428   return max_norm;
429 }
430 
// Builds the structure of (Z/mZ)^* relative to the plaintext modulus.
//
// Params:
//   mm    - the cyclotomic index m; must lie in [2, NTL_SP_BOUND).
//   pp    - a prime not dividing mm, or -1 to select the complex (CKKS)
//           plaintext space, in which case mm must be a power of two.
//   _gens - optional generators. They are used verbatim together with _ords
//           when both are non-empty, of equal length, and not in dry-run
//           mode. Otherwise they are passed to findGenerators as candidates.
//   _ords - orders matching _gens (their signs are ignored below).
// Throws InvalidArgument for out-of-range mm, non-prime pp, pp dividing mm,
// or pp == -1 with mm not a power of two. Inconsistent user-supplied
// generators are reported through assertEq.
PAlgebra::PAlgebra(long mm,
                   long pp,
                   const std::vector<long>& _gens,
                   const std::vector<long>& _ords) :
    m(mm), p(pp), cM(1.0) // default value for the ring constant
{
  assertInRange<InvalidArgument>(mm,
                                 2l,
                                 NTL_SP_BOUND,
                                 "mm is not in [2, NTL_SP_BOUND)");
  // From here on, the local pp is the working modulus. For CKKS it becomes
  // m - 1 (i.e. -1 mod m), while the member p keeps the original -1.
  if (pp == -1) // pp==-1 signals using the complex field for plaintext
    pp = m - 1;
  else {
    assertTrue<InvalidArgument>((bool)NTL::ProbPrime(pp),
                                "Modulus pp is not prime (nor -1)");
    assertNeq<InvalidArgument>(mm % pp, 0l, "Modulus pp divides mm");
  }

  // pow2 = log2(m) if m is a power of two, else 0 (CKKS must be a power of 2).
  long k = NTL::NextPowerOfTwo(mm);
  if (static_cast<unsigned long>(mm) == (1UL << k)) // m is a power of two
    pow2 = k;
  else if (p != -1) // is not power of two, set to zero (even if m is even!)
    pow2 = 0;
  else // CKKS requires m to be a power of two.  Throw if not.
    throw InvalidArgument("CKKS scheme only supports m as a power of two.");

  // For dry-run, use a tiny m value for the PAlgebra tables
  if (isDryRun())
    m = (p == 3) ? 4 : 3;

  // Compute the generators for (Z/mZ)^* (defined in NumbTh.cpp)

  std::vector<long> tmpOrds;
  if (_gens.size() > 0 && _gens.size() == _ords.size() && !isDryRun()) {
    // externally supplied generator,orders
    tmpOrds = _ords;
    this->gens = _gens;
    this->ordP = multOrd(pp, mm);
  } else
    // treat externally supplied generators (if any) as candidates
    this->ordP = findGenerators(this->gens, tmpOrds, mm, pp, _gens);

  // Record for each generator gi whether it has the same order in
  // ZM* as in Zm* /(p,g1,...,g_{i-1})

  resize(native, lsize(tmpOrds));
  resize(frob_perturb, lsize(tmpOrds));
  // p_subgp[x] = e if x == p^e (mod m) for some e in the cyclic subgroup
  // generated by p, and -1 if x is not a power of p.
  // NOTE(review): this loop and the PowerMod below reduce modulo the member
  // m, while p_subgp is sized by mm. The two agree except in dry-run mode,
  // where m was shrunk above -- confirm that is intended.
  std::vector<long> p_subgp(mm);
  for (long i : range(mm))
    p_subgp[i] = -1;
  long pmodm = pp % mm;
  p_subgp[1] = 0;
  for (long i = 1, p2i = pmodm; p2i != 1; i++, p2i = NTL::MulMod(p2i, pmodm, m))
    p_subgp[p2i] = i;
  for (long j : range(tmpOrds.size())) {
    tmpOrds[j] = std::abs(tmpOrds[j]);
    // for backward compatibility, a user supplied
    // ords value could be negative, but we ignore that here.
    // For testing and debugging, we may want to not ignore this...

    long i = NTL::PowerMod(this->gens[j], tmpOrds[j], m);

    // native[j]: gens[j]^ord == 1 mod m.
    // frob_perturb[j]: the exponent e with gens[j]^ord == p^e mod m, or -1.
    native[j] = (i == 1);
    frob_perturb[j] = p_subgp[i];
  }

  cube.initSignature(tmpOrds); // set hypercube with these dimensions

  phiM = ordP * getNSlots();

  NTL::Vec<NTL::Pair<long, long>> factors;
  factorize(factors, mm);
  nfactors = factors.length();

  // radm: product of the distinct primes dividing m.
  radm = 1;
  for (long i : range(nfactors))
    radm *= factors[i].a;

  // normBnd: product over distinct primes u | m of 2*cot(pi/(2u))/u.
  normBnd = 1.0;
  for (long i : range(nfactors)) {
    long u = factors[i].a;
    normBnd *= 2.0L * cotan(PI / (2.0L * u)) / u;
  }

  polyNormBnd = calcPolyNormBnd(mm);

  // Allocate space for the various arrays
  // zmsIdx[i] = index of the unit i among the units of Z/mZ (or -1), and
  // zmsRep is the inverse map (index -> unit).
  resize(T, getNSlots());
  Tidx.assign(mm, -1);   // allocate m slots, initialize them to -1
  zmsIdx.assign(mm, -1); // allocate m slots, initialize them to -1
  resize(zmsRep, phiM);
  long i, idx;
  for (i = idx = 0; i < mm; i++) {
    if (NTL::GCD(i, mm) == 1) {
      zmsIdx[i] = idx++;
      zmsRep[zmsIdx[i]] = i;
    }
  }

  // Now fill the Tidx translation table. We identify an element t \in T
  // with its representation t = \prod_{i=0}^n gi^{ei} mod m (where the
  // gi's are the generators in gens[]) , represent t by the vector of
  // exponents *in reverse order* (en,...,e1,e0), and order these vectors
  // in lexicographic order.

  // FIXME: is the comment above about reverse order true?
  // It doesn't seem like it to me, VJS.
  // The comment about reverse order is correct, SH.

  // buffer is initialized to all-zero, which represents 1=\prod_i gi^0
  std::vector<long> buffer(gens.size()); // temporary holds exponents
  i = idx = 0;
  long ctr = 0;
  do {
    ctr++;
    long t = exponentiate(buffer);

    // sanity check for user-supplied gens
    assertEq(NTL::GCD(t, mm), 1l, "Bad user-supplied generator");
    assertEq(Tidx[t], -1l, "Slot at index t has already been assigned");

    T[i] = t;      // The i'th element in T is t
    Tidx[t] = i++; // the index of t in T is i

    // increment buffer by one (in lexicographic order)
  } while (nextExpVector(buffer)); // until we cover all the group

  // sanity check for user-supplied gens
  // (every exponent vector must have produced a distinct slot)
  assertEq(ctr, getNSlots(), "Bad user-supplied generator set");

  PhimX = Cyclotomic(mm); // compute and store Phi_m(X)
  //  pp_factorize(mFactors,mm); // prime-power factorization from NumbTh.cpp

  // FFT helpers: a half-size FFT for even m, a full PGFFT for odd m, and
  // additionally a quarter-size FFT when 4 divides m.
  if (mm % 2 == 0)
    half_fftInfo = std::make_shared<half_FFT>(mm);
  else
    fftInfo = std::make_shared<PGFFT>(mm);

  // fftInfo = std::make_shared<PGFFT>(mm); // Need this for some
  // debugging/timing

  if (mm % 4 == 0)
    quarter_fftInfo = std::make_shared<quarter_FFT>(mm);
}
575 
comparePAlgebra(const PAlgebra & palg,unsigned long m,unsigned long p,UNUSED unsigned long r,const std::vector<long> & gens,const std::vector<long> & ords)576 bool comparePAlgebra(const PAlgebra& palg,
577                      unsigned long m,
578                      unsigned long p,
579                      UNUSED unsigned long r,
580                      const std::vector<long>& gens,
581                      const std::vector<long>& ords)
582 {
583   if (static_cast<unsigned long>(palg.getM()) != m ||
584       static_cast<unsigned long>(palg.getP()) != p ||
585       static_cast<std::size_t>(palg.numOfGens()) != gens.size() ||
586       static_cast<std::size_t>(palg.numOfGens()) != ords.size())
587     return false;
588 
589   for (long i = 0; i < (long)gens.size(); i++) {
590     if (long(palg.ZmStarGen(i)) != gens[i])
591       return false;
592 
593     if ((palg.SameOrd(i) && palg.OrderOf(i) != ords[i]) ||
594         (!palg.SameOrd(i) && palg.OrderOf(i) != -ords[i]))
595       return false;
596   }
597   return true;
598 }
599 
frobeniusPow(long j) const600 long PAlgebra::frobeniusPow(long j) const
601 {
602   return NTL::PowerMod(mcMod(p, m), j, m);
603   // Don't forget to reduce p mod m!!
604 }
605 
genToPow(long i,long j) const606 long PAlgebra::genToPow(long i, long j) const
607 {
608   long sz = gens.size();
609 
610   if (i == sz) {
611     assertTrue(j == 0, "PAlgebra::genToPow: i == sz but j != 0");
612     return 1;
613   }
614 
615   assertTrue(i >= -1 && i < LONG(gens.size()), "PAlgebra::genToPow: bad dim");
616 
617   long res;
618   if (i == -1)
619     res = frobeniusPow(j);
620   else
621     res = NTL::PowerMod(gens[i], j, m);
622 
623   return res;
624 }
625 
626 /***********************************************************************
627 
628   PAlgebraMod stuff....
629 
630 ************************************************************************/
631 
buildPAlgebraMod(const PAlgebra & zMStar,long r)632 PAlgebraModBase* buildPAlgebraMod(const PAlgebra& zMStar, long r)
633 {
634   long p = zMStar.getP();
635 
636   if (p == -1) // complex plaintext space
637     return new PAlgebraModCx(zMStar, r);
638 
639   assertTrue<InvalidArgument>(p >= 2,
640                               "Modulus p is less than 2 (nor -1 for CKKS)");
641   assertTrue<InvalidArgument>(r > 0, "Hensel lifting r is less than 1");
642   if (p == 2 && r == 1)
643     return new PAlgebraModDerived<PA_GF2>(zMStar, r);
644   else
645     return new PAlgebraModDerived<PA_zz_p>(zMStar, r);
646 }
647 
648 template <typename T>
649 void PAlgebraLift(const NTL::ZZX& phimx,
650                   const T& lfactors,
651                   T& factors,
652                   T& crtc,
653                   long r);
654 
655 // Missing NTL functionality
656 
EDF(NTL::vec_zz_pX & v,const NTL::zz_pX & f,long d)657 void EDF(NTL::vec_zz_pX& v, const NTL::zz_pX& f, long d)
658 {
659   EDF(v, f, PowerXMod(NTL::zz_p::modulus(), f), d);
660 }
661 
// Returns X^q mod F, where q = zz_pE::cardinality() is the size of the
// current extension field.
NTL::zz_pEX FrobeniusMap(const NTL::zz_pEXModulus& F)
{
  const NTL::ZZ& q = NTL::zz_pE::cardinality();
  return PowerXMod(q, F);
}
666 
// Builds the mod-p^r tables for the plaintext algebra described by _zMStar.
//
// Steps, as implemented below:
//  1. Factor Phi_m(X) mod p into equal-degree factors (degree ordP).
//  2. Move the smallest factor (per less_than) to index 0, then overwrite
//     every other factor i with the minimal polynomial of X^{1/t} mod F1,
//     where t = zMStar.ith_rep(i).
//  3. If r == 1, compute CRT coefficients directly mod p. Otherwise
//     Hensel-lift the factors and CRT coefficients to p^r with PAlgebraLift.
//  4. Record the factors over ZZ and build the CRT and mask tables.
// Throws InvalidArgument if r < 1. Also asserts that p^r fits in a long.
template <typename type>
PAlgebraModDerived<type>::PAlgebraModDerived(const PAlgebra& _zMStar, long _r) :
    zMStar(_zMStar), r(_r)

{
  long p = zMStar.getP();
  long m = zMStar.getM();

  // For dry-run, use a tiny m value for the PAlgebra tables
  if (isDryRun())
    m = (p == 3) ? 4 : 3;

  assertTrue<InvalidArgument>(r > 0l, "Hensel lifting r is less than 1");

  NTL::ZZ BigPPowR = NTL::power_ZZ(p, r);
  assertTrue((bool)BigPPowR.SinglePrecision(),
             "BigPPowR is not SinglePrecision");
  pPowR = to_long(BigPPowR);

  long nSlots = zMStar.getNSlots();

  // Save the caller's modulus context, then switch to arithmetic mod p.
  // NOTE(review): no explicit restore here; presumably bak restores the
  // saved context when destroyed (NTL *Bak semantics) -- confirm for RBak.
  RBak bak;
  bak.save();
  SetModulus(p);

  // Compute the factors Ft of Phi_m(X) mod p, for all t \in T

  RX phimxmod;

  conv(phimxmod, zMStar.getPhimX()); // Phi_m(X) mod p

  vec_RX localFactors;

  EDF(localFactors, phimxmod, zMStar.getOrdP()); // equal-degree factorization

  // Assumes EDF produced at least one factor (localFactors[0] is accessed).
  RX* first = &localFactors[0];
  RX* last = first + lsize(localFactors);
  RX* smallest =
      std::min_element(first,
                       last,
                       static_cast<bool (*)(const RX&, const RX&)>(less_than));
  swap(*first, *smallest);

  // We make the lexicographically smallest factor have index 0.
  // The remaining factors are ordered according to their representatives.

  RXModulus F1(localFactors[0]);
  for (long i = 1; i < nSlots; i++) {
    long t = zMStar.ith_rep(i);      // Ft is minimal poly of x^{1/t} mod F1
    long tInv = NTL::InvMod(t, m);   // tInv = t^{-1} mod m
    RX X2tInv = PowerXMod(tInv, F1); // X2tInv = X^{1/t} mod F1
    NTL::IrredPolyMod(localFactors[i], X2tInv, F1);
    // IrredPolyMod(X,P,Q) returns in X the minimal polynomial of P mod Q
  }
  /* Debugging sanity-check #1: we should have Ft= GCD(F1(X^t),Phi_m(X))
  for (i=1; i<nSlots; i++) {
    long t = T[i];
    RX X2t = PowerXMod(t,phimxmod);  // X2t = X^t mod Phi_m(X)
    RX Ft = GCD(CompMod(F1,X2t,phimxmod),phimxmod);
    if (Ft != localFactors[i]) {
      cout << "Ft != F1(X^t) mod Phi_m(X), t=" << t << endl;
      exit(0);
    }
  }*******************************************************************/

  if (r == 1) {
    // Already working mod p = p^r: save this context for later use.
    build(PhimXMod, phimxmod);
    factors = localFactors;
    pPowRContext.save();

    // Compute the CRT coefficients for the Ft's
    resize(crtCoeffs, nSlots);
    for (long i = 0; i < nSlots; i++) {
      RX te = phimxmod / factors[i];        // \prod_{j\ne i} Fj
      te %= factors[i];                     // \prod_{j\ne i} Fj mod Fi
      InvMod(crtCoeffs[i], te, factors[i]); // \prod_{j\ne i} Fj^{-1} mod Fi
    }
  } else {
    // PAlgebraLift switches the modulus to p^r, so Phi_m(X) is re-reduced
    // and the context saved here is the mod-p^r one.
    PAlgebraLift(zMStar.getPhimX(), localFactors, factors, crtCoeffs, r);
    RX phimxmod1;
    conv(phimxmod1, zMStar.getPhimX());
    build(PhimXMod, phimxmod1);
    pPowRContext.save();
  }

  // set factorsOverZZ
  resize(factorsOverZZ, nSlots);
  for (long i = 0; i < nSlots; i++)
    conv(factorsOverZZ[i], factors[i]);

  genCrtTable();
  genMaskTable();
}
760 
761 // Assumes current zz_p modulus is p^r
762 // computes S = F^{-1} mod G via Hensel lifting
InvModpr(NTL::zz_pX & S,const NTL::zz_pX & F,const NTL::zz_pX & G,long p,long r)763 void InvModpr(NTL::zz_pX& S,
764               const NTL::zz_pX& F,
765               const NTL::zz_pX& G,
766               long p,
767               long r)
768 {
769   NTL::ZZX ff, gg, ss, tt;
770 
771   ff = to_ZZX(F);
772   gg = to_ZZX(G);
773 
774   NTL::zz_pBak bak;
775   bak.save();
776   NTL::zz_p::init(p);
777 
778   NTL::zz_pX f, g, s, t;
779   f = to_zz_pX(ff);
780   g = to_zz_pX(gg);
781   s = InvMod(f, g);
782   t = (1 - s * f) / g;
783   assertTrue(static_cast<bool>(s * f + t * g == 1l),
784              "Arithmetic error during Hensel lifting");
785   ss = to_ZZX(s);
786   tt = to_ZZX(t);
787 
788   NTL::ZZ pk = NTL::to_ZZ(1);
789 
790   for (long k = 1; k < r; k++) {
791     // lift from p^k to p^{k+1}
792     pk = pk * p;
793 
794     assertTrue((bool)divide(ss * ff + tt * gg - 1, pk),
795                "Arithmetic error during Hensel lifting");
796 
797     NTL::zz_pX d = to_zz_pX((1 - (ss * ff + tt * gg)) / pk);
798     NTL::zz_pX s1, t1;
799     s1 = (s * d) % g;
800     t1 = (d - s1 * f) / g;
801     ss = ss + pk * to_ZZX(s1);
802     tt = tt + pk * to_ZZX(t1);
803   }
804 
805   bak.restore();
806 
807   S = to_zz_pX(ss);
808 
809   assertTrue(static_cast<bool>((S * F) % G == 1),
810              "Hensel lifting failed to find solutions");
811 }
812 
813 // FIXME: Consider changing this function to something non-templated.
814 #pragma GCC diagnostic push
815 #pragma GCC diagnostic ignored "-Wunused-parameter"
816 template <typename T>
PAlgebraLift(const NTL::ZZX & phimx,const T & lfactors,T & factors,T & crtc,long r)817 void PAlgebraLift(const NTL::ZZX& phimx,
818                   const T& lfactors,
819                   T& factors,
820                   T& crtc,
821                   long r)
822 {
823   throw LogicError("Uninstantiated version of PAlgebraLift");
824 }
825 #pragma GCC diagnostic pop
826 
827 // This specialized version of PAlgebraLift does the hensel
828 // lifting needed to finish off the initialization.
829 // It assumes the zz_p modulus is initialized to p
830 // when called, and leaves it set to p^r
831 
832 template <>
PAlgebraLift(const NTL::ZZX & phimx,const NTL::vec_zz_pX & lfactors,NTL::vec_zz_pX & factors,NTL::vec_zz_pX & crtc,long r)833 void PAlgebraLift(const NTL::ZZX& phimx,
834                   const NTL::vec_zz_pX& lfactors,
835                   NTL::vec_zz_pX& factors,
836                   NTL::vec_zz_pX& crtc,
837                   long r)
838 {
839   long p = NTL::zz_p::modulus();
840   long nSlots = lsize(lfactors);
841 
842   NTL::vec_ZZX vzz; // need to go via ZZX
843 
844   // lift the factors of Phi_m(X) from mod-2 to mod-2^r
845   if (lsize(lfactors) > 1)
846     MultiLift(vzz, lfactors, phimx, r); // defined in NTL::ZZXFactoring
847   else {
848     resize(vzz, 1);
849     vzz[0] = phimx;
850   }
851 
852   // Compute the zz_pContext object for mod p^r arithmetic
853   NTL::zz_p::init(NTL::power_long(p, r));
854 
855   NTL::zz_pX phimxmod = to_zz_pX(phimx);
856   resize(factors, nSlots);
857   for (long i = 0; i < nSlots; i++) // Convert from ZZX to zz_pX
858     conv(factors[i], vzz[i]);
859 
860   // Finally compute the CRT coefficients for the factors
861   resize(crtc, nSlots);
862   for (long i = 0; i < nSlots; i++) {
863     NTL::zz_pX& fct = factors[i];
864     NTL::zz_pX te = phimxmod / fct;   // \prod_{j\ne i} Fj
865     te %= fct;                        // \prod_{j\ne i} Fj mod Fi
866     InvModpr(crtc[i], te, fct, p, r); // \prod_{j\ne i} Fj^{-1} mod Fi
867   }
868 }
869 
870 // Returns a vector crt[] such that crt[i] = p mod Ft (with t = T[i])
871 template <typename type>
CRT_decompose(std::vector<RX> & crt,const RX & H) const872 void PAlgebraModDerived<type>::CRT_decompose(std::vector<RX>& crt,
873                                              const RX& H) const
874 {
875   long nSlots = zMStar.getNSlots();
876 
877   if (isDryRun()) {
878     crt.clear();
879     return;
880   }
881   resize(crt, nSlots);
882   for (long i = 0; i < nSlots; i++)
883     rem(crt[i], H, factors[i]); // crt[i] = H % factors[i]
884 }
885 
886 template <typename type>
embedInAllSlots(RX & H,const RX & alpha,const MappingData<type> & mappingData) const887 void PAlgebraModDerived<type>::embedInAllSlots(
888     RX& H,
889     const RX& alpha,
890     const MappingData<type>& mappingData) const
891 {
892   if (isDryRun()) {
893     H = RX::zero();
894     return;
895   }
896   HELIB_TIMER_START;
897   long nSlots = zMStar.getNSlots();
898 
899   std::vector<RX> crt(nSlots); // allocate space for CRT components
900 
901   // The i'th CRT component is (H mod F_t) = alpha(maps[i]) mod F_t,
902   // where with t=T[i].
903 
904   if (IsX(mappingData.G) || deg(alpha) <= 0) {
905     // special case...no need for CompMod, which is
906     // is not optimized for this case
907 
908     for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
909       crt[i] = ConstTerm(alpha);
910   } else {
911     // general case...
912 
913     // FIXME: should update this to use matrix_maps, but this routine
914     // isn't actually used anywhere
915 
916     for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
917       CompMod(crt[i], alpha, mappingData.maps[i], factors[i]);
918   }
919 
920   CRT_reconstruct(H, crt); // interpolate to get H
921   HELIB_TIMER_STOP;
922 }
923 
924 template <typename type>
embedInSlots(RX & H,const std::vector<RX> & alphas,const MappingData<type> & mappingData) const925 void PAlgebraModDerived<type>::embedInSlots(
926     RX& H,
927     const std::vector<RX>& alphas,
928     const MappingData<type>& mappingData) const
929 {
930   if (isDryRun()) {
931     H = RX::zero();
932     return;
933   }
934   HELIB_TIMER_START;
935 
936   long nSlots = zMStar.getNSlots();
937   // assert(lsize(alphas) == nSlots);
938   assertEq(
939       lsize(alphas),
940       nSlots,
941       "Cannot embed in slots: alphas size is different than number of slots");
942 
943   long d = mappingData.degG;
944   for (long i = 0; i < nSlots; i++)
945     assertTrue(deg(alphas[i]) < d,
946                "Bad alpha element at index i: its degree is greater or "
947                "equal than mappingData.degG");
948 
949   std::vector<RX> crt(nSlots); // allocate space for CRT components
950 
951   // The i'th CRT component is (H mod F_t) = alphas[i](maps[i]) mod F_t,
952   // where with t=T[i].
953 
954   if (IsX(mappingData.G)) {
955     // special case...no need for CompMod, which is
956     // is not optimized for this case
957 
958     for (long i = 0; i < nSlots; i++) // crt[i] = alpha(maps[i]) mod Ft
959       crt[i] = ConstTerm(alphas[i]);
960   } else {
961     // general case...still try to avoid CompMod when possible,
962     // which is the common case for encoding masks
963 
964     HELIB_NTIMER_START(CompMod);
965 
966 #if 0
967     for (long i: range(nSlots)) {
968       if (deg(alphas[i]) <= 0)
969         crt[i] = alphas[i];
970       else
971         CompMod(crt[i], alphas[i], mappingData.maps[i], factors[i]);
972     }
973 #else
974     vec_R in, out;
975 
976     for (long i : range(nSlots)) {
977       if (deg(alphas[i]) <= 0)
978         crt[i] = alphas[i];
979       else {
980         VectorCopy(in, alphas[i], d);
981         mul(out, in, mappingData.matrix_maps[i]);
982         conv(crt[i], out);
983       }
984     }
985 #endif
986   }
987 
988   CRT_reconstruct(H, crt); // interpolate to get p
989 
990   HELIB_TIMER_STOP;
991 }
992 
993 template <typename type>
CRT_reconstruct(RX & H,std::vector<RX> & crt) const994 void PAlgebraModDerived<type>::CRT_reconstruct(RX& H,
995                                                std::vector<RX>& crt) const
996 {
997   if (isDryRun()) {
998     H = RX::zero();
999     return;
1000   }
1001   HELIB_TIMER_START;
1002   long nslots = zMStar.getNSlots();
1003 
1004   const std::vector<RX>& ctab = crtTable;
1005 
1006   clear(H);
1007   RX tmp1, tmp2;
1008 
1009   bool easy = true;
1010   for (long i = 0; i < nslots; i++)
1011     if (!IsZero(crt[i]) && !IsOne(crt[i])) {
1012       easy = false;
1013       break;
1014     }
1015 
1016   if (easy) {
1017     // VJS-FIXME: this looks easy, but could
1018     // be slower asymptotically
1019     for (long i = 0; i < nslots; i++)
1020       if (!IsZero(crt[i]))
1021         H += ctab[i];
1022   } else {
1023     std::vector<RX> crt1;
1024     resize(crt1, nslots);
1025     for (long i = 0; i < nslots; i++)
1026       MulMod(crt1[i], crt[i], crtCoeffs[i], factors[i]);
1027 
1028     evalTree(H, crtTree, crt1, 0, nslots);
1029   }
1030   HELIB_TIMER_STOP;
1031 }
1032 
1033 template <typename type>
mapToFt(RX & w,const RX & G,long t,const RX * rF1) const1034 void PAlgebraModDerived<type>::mapToFt(RX& w,
1035                                        const RX& G,
1036                                        long t,
1037                                        const RX* rF1) const
1038 {
1039   if (isDryRun()) {
1040     w = RX::zero();
1041     return;
1042   }
1043   long i = zMStar.indexOfRep(t);
1044   if (i < 0) {
1045     clear(w);
1046     return;
1047   }
1048 
1049   if (rF1 == nullptr) { // Compute the representation "from scratch"
1050     // special case
1051     if (G == factors[i]) {
1052       SetX(w);
1053       return;
1054     }
1055 
1056     // special case
1057     if (deg(G) == 1) {
1058       w = -ConstTerm(G);
1059       return;
1060     }
1061 
1062     // the general case: currently only works when r == 1
1063     assertEq(r, 1l, "Bad Hensel lifting value in general case: r is not 1");
1064 
1065     REBak bak;
1066     bak.save();
1067     RE::init(factors[i]); // work with the extension field GF_p[X]/Ft(X)
1068     REX Ga;
1069     conv(Ga, G); // G as a polynomial over the extension field
1070 
1071     vec_RE roots;
1072     FindRoots(roots, Ga); // Find roots of G in this field
1073     RE* first = &roots[0];
1074     RE* last = first + lsize(roots);
1075     RE* smallest = std::min_element(
1076         first,
1077         last,
1078         static_cast<bool (*)(const RE&, const RE&)>(less_than));
1079     // make a canonical choice
1080     w = rep(*smallest);
1081     return;
1082   }
1083   // if rF1 is set, then use it instead, setting w = rF1(X^t) mod Ft(X)
1084   RXModulus Ft(factors[i]);
1085   //  long tInv = InvMod(t,m);
1086   RX X2t = PowerXMod(t, Ft);  // X2t = X^t mod Ft
1087   w = CompMod(*rF1, X2t, Ft); // w = F1(X2t) mod Ft
1088 
1089   /* Debugging sanity-check: G(w)=0 in the extension field (Z/2Z)[X]/Ft(X)
1090   RE::init(factors[i]);
1091   REX Ga;
1092   conv(Ga, G); // G as a polynomial over the extension field
1093   RE ra;
1094   conv(ra, w);         // w is an element in the extension field
1095   eval(ra,Ga,ra);  // ra = Ga(ra)
1096   if (!IsZero(ra)) {// check that Ga(w)=0 in this extension field
1097     cout << "rF1(X^t) mod Ft(X) != root of G mod Ft, t=" << t << endl;
1098     exit(0);
1099   }*******************************************************************/
1100 }
1101 
// Precompute everything needed to move between slot values (elements of
// R[X]/G(X)) and the CRT components modulo the factors F_i.
//
// Fills in mappingData:
//   - G, degG: the slot polynomial and its degree.
//   - maps[i]: maps[0] comes from mapToF1 ("base-G to base-F1"). For i > 0,
//     maps[i] comes from mapToFt(..., ith_rep(i), &maps[0]), i.e. a root of
//     G modulo F_i.
//   - matrix_maps[i]: a d x ordp matrix. Row j holds the coefficients of
//     maps[i]^j mod factors[i]. embedInSlots uses it in place of CompMod.
//   - contextForG: the RE context for R[X]/G(X).
//   - rmaps[i]: only when deg(G) > 1. The factor Q_i of F_i over
//     (R[X]/G(X))[Y] that decodePlaintext reduces by.
//
// Preconditions checked here:
//   - G must be non-constant and monic (InvalidArgument otherwise).
//   - deg(G) must divide ordP.
//   - When G != factors[0], r must be 1.
// The caller's RE context is saved in `bak`. NTL Bak objects restore the
// saved context when destroyed.
template <typename type>
void PAlgebraModDerived<type>::mapToSlots(MappingData<type>& mappingData,
                                          const RX& G) const
{
  assertTrue<InvalidArgument>(
      deg(G) > 0,
      "Polynomial G is constant (has degree less than one)");
  assertEq(zMStar.getOrdP() % deg(G),
           0l,
           "Degree of polynomial G does not divide zMStar.getOrdP()");
  assertTrue<InvalidArgument>(static_cast<bool>(LeadCoeff(G) == 1l),
                              "Polynomial G is not monic");
  mappingData.G = G;
  mappingData.degG = deg(mappingData.G);
  long d = deg(G);
  long ordp = zMStar.getOrdP();

  long nSlots = zMStar.getNSlots();
  long m = zMStar.getM();

  resize(mappingData.maps, nSlots);

  mapToF1(mappingData.maps[0], mappingData.G); // mapping from base-G to base-F1
  // The remaining maps are derived from maps[0] (the rF1 path of mapToFt).
  for (long i = 1; i < nSlots; i++)
    mapToFt(mappingData.maps[i],
            mappingData.G,
            zMStar.ith_rep(i),
            &(mappingData.maps[0]));

  // create matrices to streamline CompMod operations:
  // row j of matrix_maps[i] = coefficients of maps[i]^j mod factors[i],
  // for j = 0..d-1, stored as length-ordp vectors
  resize(mappingData.matrix_maps, nSlots);
  for (long i : range(nSlots)) {
    mat_R& mat = mappingData.matrix_maps[i];
    mat.SetDims(d, ordp);
    RX pow;
    pow = 1;
    for (long j : range(d)) {
      VectorCopy(mat[j], pow, ordp);
      if (j < d - 1) // the last power is never stored, so skip computing it
        MulMod(pow, pow, mappingData.maps[i], factors[i]);
    }
  }

  // Switch RE to R[X]/G(X) and remember that context for later decoding.
  REBak bak;
  bak.save();
  RE::init(mappingData.G);
  mappingData.contextForG.save();

  // rmaps are only needed (and only built) when G is not linear.
  if (deg(mappingData.G) == 1)
    return;

  resize(mappingData.rmaps, nSlots);

  if (G == factors[0]) {
    // an important special case: each rmaps[i] is the linear polynomial
    // Y - (X^{t^{-1}} mod G), with t = ith_rep(i) and t^{-1} taken mod m

    for (long i = 0; i < nSlots; i++) {
      long t = zMStar.ith_rep(i);
      long tInv = NTL::InvMod(t, m);

      RX ct_rep;
      PowerXMod(ct_rep, tInv, G);

      RE ct;
      conv(ct, ct_rep);

      REX Qi;
      SetCoeff(Qi, 1, 1);
      SetCoeff(Qi, 0, -ct);

      mappingData.rmaps[i] = Qi;
    }
  } else {
    // the general case: currently only works when r == 1

    assertEq(r, 1l, "Bad Hensel lifting value in general case: r is not 1");

    // FRts: the factors of factors[0] over (R[X]/G(X))[Y], computed once at
    // i == 0 and reused for every later i.
    vec_REX FRts;
    for (long i = 0; i < nSlots; i++) {
      // We need to lift Fi from R[Y] to (R[X]/G(X))[Y]
      REX Qi;
      long t, tInv = 0; // t/tInv are only set (and used) when i > 0

      if (i == 0) {
        conv(Qi, factors[i]);
        FRts = EDF(Qi, FrobeniusMap(Qi), deg(Qi) / deg(G));
        // factor Fi over GF(p)[X]/G(X) into factors of degree deg(Fi)/deg(G)
      } else {
        t = zMStar.ith_rep(i);
        tInv = NTL::InvMod(t, m);
      }

      // need to choose the right factor, the one that gives us back X
      long j;
      for (j = 0; j < lsize(FRts); j++) {
        // lift maps[i] to (R[X]/G(X))[Y] and reduce mod j'th factor of Fi

        REX FRtsj;
        if (i == 0)
          FRtsj = FRts[j];
        else {
          REX X2tInv = PowerXMod(tInv, FRts[j]);
          IrredPolyMod(FRtsj, X2tInv, FRts[j]);
        }

        // FRtsj is the jth factor of factors[i] over the extension field.
        // For i > 0, we save some time by computing it from the jth factor
        // of factors[0] via a minimal polynomial computation.

        REX GRti;
        conv(GRti, mappingData.maps[i]);
        GRti %= FRtsj;

        // Only the constant (Y^0) coefficient is compared against X.
        if (IsX(rep(ConstTerm(GRti)))) { // is GRti == X?
          Qi = FRtsj;                    // If so, we found the right factor
          break;
        } // If this does not happen then move to the next factor of Fi
      }

      assertTrue(j < lsize(FRts),
                 "Cannot find the right factor Qi. Loop did not "
                 "terminate before visiting all elements");
      mappingData.rmaps[i] = Qi;
    }
  }
}
1228 
1229 template <typename type>
decodePlaintext(std::vector<RX> & alphas,const RX & ptxt,const MappingData<type> & mappingData) const1230 void PAlgebraModDerived<type>::decodePlaintext(
1231     std::vector<RX>& alphas,
1232     const RX& ptxt,
1233     const MappingData<type>& mappingData) const
1234 {
1235   long nSlots = zMStar.getNSlots();
1236   if (isDryRun()) {
1237     alphas.assign(nSlots, RX::zero());
1238     return;
1239   }
1240 
1241   // First decompose p into CRT components
1242   std::vector<RX> CRTcomps(nSlots); // allocate space for CRT component
1243   CRT_decompose(CRTcomps, ptxt);    // CRTcomps[i] = p mod facors[i]
1244 
1245   if (mappingData.degG == 1) {
1246     alphas = CRTcomps;
1247     return;
1248   }
1249 
1250   resize(alphas, nSlots);
1251 
1252   REBak bak;
1253   bak.save();
1254   mappingData.contextForG.restore();
1255 
1256   for (long i = 0; i < nSlots; i++) {
1257     REX te;
1258     conv(te, CRTcomps[i]); // lift i'th CRT component to mod G(X)
1259     te %= mappingData
1260               .rmaps[i]; // reduce CRTcomps[i](Y) mod Qi(Y), over (Z_2[X]/G(X))
1261 
1262     // the free term (no Y component) should be our answer (as a poly(X))
1263     alphas[i] = rep(ConstTerm(te));
1264   }
1265 }
1266 
1267 template <typename type>
buildLinPolyCoeffs(std::vector<RX> & C,const std::vector<RX> & L,const MappingData<type> & mappingData) const1268 void PAlgebraModDerived<type>::buildLinPolyCoeffs(
1269     std::vector<RX>& C,
1270     const std::vector<RX>& L,
1271     const MappingData<type>& mappingData) const
1272 {
1273   REBak bak;
1274   bak.save();
1275   mappingData.contextForG.restore();
1276 
1277   long d = RE::degree();
1278   long p = zMStar.getP();
1279 
1280   assertEq(lsize(L), d, "Vector L size is different than RE::degree()");
1281 
1282   vec_RE LL;
1283   resize(LL, d);
1284 
1285   for (long i = 0; i < d; i++)
1286     conv(LL[i], L[i]);
1287 
1288   vec_RE CC;
1289   ::helib::buildLinPolyCoeffs(CC, LL, p, r);
1290 
1291   resize(C, d);
1292   for (long i = 0; i < d; i++)
1293     C[i] = rep(CC[i]);
1294 }
1295 
// Code for generating mask tables.
// Currently, this is done when the PAlgebraMod
// object is constructed.

// VJS-FIXME: what were we thinking? These tables
// can be huge.
1302 
1303 template <typename type>
genMaskTable()1304 void PAlgebraModDerived<type>::genMaskTable()
1305 {
1306   // This is only called by the constructor, which has already
1307   // set the zz_p context and the crtTable
1308   resize(maskTable, zMStar.numOfGens());
1309   for (long i = 0; i < (long)zMStar.numOfGens(); i++) {
1310     long ord = zMStar.OrderOf(i);
1311     resize(maskTable[i], ord + 1);
1312     maskTable[i][ord] = 0;
1313     for (long j = ord - 1; j >= 1; j--) {
1314       // initialize mask that is 1 whenever the ith coordinate is at least j
1315       // Note: maskTable[i][0] = constant 1, maskTable[i][ord] = constant 0
1316       maskTable[i][j] = maskTable[i][j + 1];
1317       for (long k = 0; k < (long)zMStar.getNSlots(); k++) {
1318         if (zMStar.coordinate(i, k) == j) {
1319           add(maskTable[i][j], maskTable[i][j], crtTable[k]);
1320         }
1321       }
1322     }
1323     maskTable[i][0] = 1;
1324   }
1325 }
1326 
// Code for generating CRT tables.
// Currently, this is done when the PAlgebraMod
// object is constructed.

// VJS-FIXME: what were we thinking? These tables
// can be huge.
1333 
1334 template <typename type>
genCrtTable()1335 void PAlgebraModDerived<type>::genCrtTable()
1336 {
1337   // This is only called by the constructor, which has already
1338   // set the zz_p context
1339 
1340   long nslots = zMStar.getNSlots();
1341   resize(crtTable, nslots);
1342   for (long i = 0; i < nslots; i++) {
1343     RX allBut_i = PhimXMod / factors[i]; // = \prod_{j \ne i }Fj
1344     allBut_i *= crtCoeffs[i]; // = 1 mod Fi and = 0 mod Fj for j \ne i
1345     crtTable[i] = allBut_i;
1346   }
1347 
1348   buildTree(crtTree, 0, nslots);
1349 }
1350 
1351 template <typename type>
buildTree(std::shared_ptr<TNode<RX>> & res,long offset,long extent) const1352 void PAlgebraModDerived<type>::buildTree(std::shared_ptr<TNode<RX>>& res,
1353                                          long offset,
1354                                          long extent) const
1355 {
1356   if (extent == 1)
1357     res = buildTNode<RX>(nullTNode<RX>(), nullTNode<RX>(), factors[offset]);
1358   else {
1359     long half = extent / 2;
1360     std::shared_ptr<TNode<RX>> left, right;
1361     buildTree(left, offset, half);
1362     buildTree(right, offset + half, extent - half);
1363     RX data = left->data * right->data;
1364     res = buildTNode<RX>(left, right, data);
1365   }
1366 }
1367 
1368 template <typename type>
evalTree(RX & res,std::shared_ptr<TNode<RX>> tree,const std::vector<RX> & crt1,long offset,long extent) const1369 void PAlgebraModDerived<type>::evalTree(RX& res,
1370                                         std::shared_ptr<TNode<RX>> tree,
1371                                         const std::vector<RX>& crt1,
1372                                         long offset,
1373                                         long extent) const
1374 {
1375   if (extent == 1)
1376     res = crt1[offset];
1377   else {
1378     long half = extent / 2;
1379     RX lres, rres;
1380     evalTree(lres, tree->left, crt1, offset, half);
1381     evalTree(rres, tree->right, crt1, offset + half, extent - half);
1382     RX tmp1, tmp2;
1383     mul(tmp1, lres, tree->right->data);
1384     mul(tmp2, rres, tree->left->data);
1385     add(tmp1, tmp1, tmp2);
1386     res = tmp1;
1387   }
1388 }
1389 
// Explicit instantiation
//
// The member function templates of PAlgebraModDerived are defined in this
// source file. The class is therefore instantiated here for PA_GF2 and
// PA_zz_p.

template class PAlgebraModDerived<PA_GF2>;
template class PAlgebraModDerived<PA_zz_p>;
1394 
1395 } // namespace helib
1396