1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #ifndef HELIB_PALGEBRA_H
13 #define HELIB_PALGEBRA_H
14 /**
15  * @file PAlgebra.h
 * @brief Declarations of the PAlgebra and PAlgebraMod classes and their helpers
17  */
#include <complex>
#include <exception>
#include <memory>
#include <utility>
#include <vector>

#include <helib/NumbTh.h>
#include <helib/zzX.h>
#include <helib/hypercube.h>
#include <helib/PGFFT.h>
#include <helib/ClonedPtr.h>
#include <helib/apiAttributes.h>
29 
30 namespace helib {
31 
32 struct half_FFT
33 {
34   PGFFT fft;
35   std::vector<std::complex<double>> pow;
36 
37   half_FFT(long m);
38 };
39 
40 struct quarter_FFT
41 {
42   PGFFT fft;
43   std::vector<std::complex<double>> pow1, pow2;
44 
45   quarter_FFT(long m);
46 };
47 
48 /**
49  * @class PAlgebra
50  * @brief The structure of (Z/mZ)* /(p)
51  *
52  * A PAlgebra object is determined by an integer m and a prime p, where p does
53  * not divide m. It holds information describing the structure of (Z/mZ)^*,
 * which is isomorphic to the Galois group over A = Z[X]/Phi_m(X).
55  *
56  * We represent (Z/mZ)^* as (Z/mZ)^* = (p) x (g1,g2,...) x (h1,h2,...)
57  * where the group generated by g1,g2,... consists of the elements that
58  * have the same order in (Z/mZ)^* as in (Z/mZ)^* /(p,g_1,...,g_{i-1}), and
59  * h1,h2,... generate the remaining quotient group (Z/mZ)^* /(p,g1,g2,...).
60  *
61  * We let T subset (Z/mZ)^* be a set of representatives for the quotient
62  * group (Z/mZ)^* /(p), defined as T={ prod_i gi^{ei} * prod_j hj^{ej} }
63  * where the ei's range over 0,1,...,ord(gi)-1 and the ej's range over
64  * 0,1,...ord(hj)-1 (these last orders are in (Z/mZ)^* /(p,g1,g2,...)).
65  *
66  * Phi_m(X) is factored as Phi_m(X)= prod_{t in T} F_t(X) mod p,
67  * where the F_t's are irreducible modulo p. An arbitrary factor
68  * is chosen as F_1, then for each t in T we associate with the index t the
69  * factor F_t(X) = GCD(F_1(X^t), Phi_m(X)).
70  *
71  * Note that fixing a representation of the field R=(Z/pZ)[X]/F_1(X)
72  * and letting z be a root of F_1 in R (which
73  * is a primitive m-th root of unity in R), we get that F_t is the minimal
74  * polynomial of z^{1/t}.
75  */
76 class PAlgebra
77 {
78   long m; // the integer m defines (Z/mZ)^*, Phi_m(X), etc.
79   long p; // the prime base of the plaintext space
80 
81   long phiM;      // phi(m)
82   long ordP;      // the order of p in (Z/mZ)^*
83   long nfactors;  // number of distinct prime factors of m
84   long radm;      // rad(m) = prod of distinct primes dividing m
85   double normBnd; // max-norm-on-pwfl-basis <= normBnd * max-norm-canon-embed
86   double polyNormBnd; // max-norm-on-poly-basis <= polyNormBnd *
87                       // max-norm-canon-embed
88 
89   long pow2; // if m = 2^k, then pow2 == k; otherwise, pow2 == 0
90 
91   std::vector<long> gens; // Our generators for (Z/mZ)^* (other than p)
92 
93   //  native[i] is true iff gens[i] has the same order in the quotient
94   //  group as its order in Zm*.
95   NTL::Vec<bool> native;
96 
97   // frob_perturb[i] = j if gens[i] raised to its order equals p^j,
98   // otherwise -1
99   NTL::Vec<long> frob_perturb;
100 
101   CubeSignature cube; // the hypercube structure of Zm* /(p)
102 
103   NTL::ZZX PhimX; // Holds the integer polynomial Phi_m(X)
104 
105   double cM; // the "ring constant" c_m for Z[X]/Phi_m(X)
106   // NOTE: cM is related to the ratio between the l_infinity norm of
107   // a "random" ring element in different bases. For example, think of
108   // choosing the power-basis coefficients of x uniformly at random in
109   // [+-a/2] (for some parameter a), then the powerful basis norm of x
110   // should be bounded whp by cM*a.
111   //
112   // More precisely, for an element x whose coefficients are chosen
113   // uniformly in [+-a/2] (in either the powerful or the power basis)
114   // we have a high-probability bound |x|_canonical < A*a for some
115   // A = O(sqrt(phi(m)). Also for "random enough" x we have some bound
116   //       |x|_powerful < |x|_canonical * B
117   // where we "hope" that B = O(1/sqrt(phi(m)). The cM value is
118   // supposed to be cM=A*B.
119   //
120   // The value cM is only used for bootstrapping, see more comments
121   // for the method RecryptData::setAE in recryption.cpp. Also see
122   // Appendix A of https://ia.cr/2014/873 (updated version from 2019)
123 
124   std::vector<long> T;    // The representatives for the quotient group Zm* /(p)
125   std::vector<long> Tidx; // i=Tidx[t] is the index i s.t. T[i]=t.
126                           // Tidx[t]==-1 if t notin T
127 
128   std::vector<long> zmsIdx; // if t is the i'th element in Zm* then zmsIdx[t]=i
129                             // zmsIdx[t]==-1 if t notin Zm*
130 
131   std::vector<long> zmsRep; // inverse of zmsIdx
132 
133   std::shared_ptr<PGFFT> fftInfo; // info for computing m-point complex FFT's
134                                   // shard_ptr allows delayed initialization
135                                   // and lightweight copying
136 
137   std::shared_ptr<half_FFT> half_fftInfo;
138   // an optimization for FFT's with even m
139 
140   std::shared_ptr<quarter_FFT> quarter_fftInfo;
141   // an optimization for FFT's with m = 0 (mod 4)
142 
143 public:
144   PAlgebra& operator=(const PAlgebra&) = delete;
145 
146   PAlgebra(long mm,
147            long pp = 2,
148            const std::vector<long>& _gens = std::vector<long>(),
149            const std::vector<long>& _ords = std::vector<long>()); // constructor
150 
151   bool operator==(const PAlgebra& other) const;
152   bool operator!=(const PAlgebra& other) const { return !(*this == other); }
153   // comparison
154 
155   /* I/O methods */
156 
157   //! Prints the structure in a readable form
158   void printout(std::ostream& out = std::cout) const;
159   void printAll(std::ostream& out = std::cout) const; // print even more
160 
161   /* Access methods */
162 
163   //! Returns m
getM()164   long getM() const { return m; }
165 
166   //! Returns p
getP()167   long getP() const { return p; }
168 
169   //! Returns phi(m)
getPhiM()170   long getPhiM() const { return phiM; }
171 
172   //! The order of p in (Z/mZ)^*
getOrdP()173   long getOrdP() const { return ordP; }
174 
175   //! The number of distinct prime factors of m
getNFactors()176   long getNFactors() const { return nfactors; }
177 
178   //! getRadM() = prod of distinct prime factors of m
getRadM()179   long getRadM() const { return radm; }
180 
181   //! max-norm-on-pwfl-basis <= normBnd * max-norm-canon-embed
getNormBnd()182   double getNormBnd() const { return normBnd; }
183 
184   //! max-norm-on-pwfl-basis <= polyNormBnd * max-norm-canon-embed
getPolyNormBnd()185   double getPolyNormBnd() const { return polyNormBnd; }
186 
187   //! The number of plaintext slots = phi(m)/ord(p)
getNSlots()188   long getNSlots() const { return cube.getSize(); }
189 
190   //! if m = 2^k, then pow2 == k; otherwise, pow2 == 0
getPow2()191   long getPow2() const { return pow2; }
192 
193   //! The cyclotomix polynomial Phi_m(X)
getPhimX()194   const NTL::ZZX& getPhimX() const { return PhimX; }
195 
196   //! The "ring constant" cM
set_cM(double c)197   void set_cM(double c) { cM = c; }
get_cM()198   double get_cM() const { return cM; }
199 
200   //! The prime-power factorization of m
201   //  const std::vector<long> getMfactors() const { return mFactors; }
202 
203   //! The number of generators in (Z/mZ)^* /(p)
numOfGens()204   long numOfGens() const { return gens.size(); }
205 
206   //! the i'th generator in (Z/mZ)^* /(p) (if any)
ZmStarGen(long i)207   long ZmStarGen(long i) const { return (i < long(gens.size())) ? gens[i] : 0; }
208 
209   //! the i'th generator to the power j mod m
210   // VJS: I'm moving away from all of this unsigned stuff...
211   // Also, note that j really may be negative
212   // NOTE: i == -1 means Frobenius
213   long genToPow(long i, long j) const;
214 
215   // p to the power j mod m
216   long frobeniusPow(long j) const;
217 
218   //! The order of i'th generator (if any)
OrderOf(long i)219   long OrderOf(long i) const { return cube.getDim(i); }
220 
221   //! The product prod_{j=i}^{n-1} OrderOf(i)
ProdOrdsFrom(long i)222   long ProdOrdsFrom(long i) const { return cube.getProd(i); }
223 
224   //! Is ord(i'th generator) the same as its order in (Z/mZ)^*?
SameOrd(long i)225   bool SameOrd(long i) const { return native[i]; }
226 
227   // FrobPerturb[i] = j if gens[i] raised to its order equals p^j,
228   // where j in [0..ordP), otherwise -1
FrobPerturb(long i)229   long FrobPerturb(long i) const { return frob_perturb[i]; }
230 
231   //! @name Translation between index, representatives, and exponents
232 
233   //! Returns the i'th element in T
ith_rep(long i)234   long ith_rep(long i) const { return (i < getNSlots()) ? T[i] : 0; }
235 
236   //! Returns the index of t in T
indexOfRep(long t)237   long indexOfRep(long t) const { return (t > 0 && t < m) ? Tidx[t] : -1; }
238 
239   //! Is t in T?
isRep(long t)240   bool isRep(long t) const { return (t > 0 && t < m && Tidx[t] > -1); }
241 
242   //! Returns the index of t in (Z/mZ)*
indexInZmstar(long t)243   long indexInZmstar(long t) const { return (t > 0 && t < m) ? zmsIdx[t] : -1; }
244 
245   //! Returns the index of t in (Z/mZ)* -- no range checking
indexInZmstar_unchecked(long t)246   long indexInZmstar_unchecked(long t) const { return zmsIdx[t]; }
247 
248   //! Returns rep whose index is i
repInZmstar_unchecked(long idx)249   long repInZmstar_unchecked(long idx) const { return zmsRep[idx]; }
250 
inZmStar(long t)251   bool inZmStar(long t) const { return (t > 0 && t < m && zmsIdx[t] > -1); }
252 
253   //! @brief Returns prod_i gi^{exps[i]} mod m. If onlySameOrd=true,
254   //! use only generators that have the same order as in (Z/mZ)^*.
255   long exponentiate(const std::vector<long>& exps,
256                     bool onlySameOrd = false) const;
257 
258   //! @brief Returns coordinate of index k along the i'th dimension.
coordinate(long i,long k)259   long coordinate(long i, long k) const { return cube.getCoord(k, i); }
260 
261   //! Break an index into the hypercube to index of the dimension-dim
262   //! subcube and index inside that subcube.
breakIndexByDim(long idx,long dim)263   std::pair<long, long> breakIndexByDim(long idx, long dim) const
264   {
265     return cube.breakIndexByDim(idx, dim);
266   }
267   //! The inverse of breakIndexByDim
assembleIndexByDim(std::pair<long,long> idx,long dim)268   long assembleIndexByDim(std::pair<long, long> idx, long dim) const
269   {
270     return cube.assembleIndexByDim(idx, dim);
271   }
272 
273   //! @brief adds offset to index k in the i'th dimension
addCoord(long i,long k,long offset)274   long addCoord(long i, long k, long offset) const
275   {
276     return cube.addCoord(k, i, offset);
277   }
278 
279   /* Miscellaneous */
280 
281   //! exps is an array of exponents (the dLog of some t in T), this function
282   //! increment exps lexicographic order, return false if it cannot be
283   //! incremented (because it is at its maximum value)
nextExpVector(std::vector<long> & exps)284   bool nextExpVector(std::vector<long>& exps) const
285   {
286     return cube.incrementCoords(exps);
287   }
288 
289   //! The largest FFT we need to handle degree-m polynomials
fftSizeNeeded()290   long fftSizeNeeded() const { return NTL::NextPowerOfTwo(getM()) + 1; }
291   // TODO: should have a special case when m is power of two
292 
getFFTInfo()293   const PGFFT& getFFTInfo() const { return *fftInfo; }
getHalfFFTInfo()294   const half_FFT& getHalfFFTInfo() const { return *half_fftInfo; }
getQuarterFFTInfo()295   const quarter_FFT& getQuarterFFTInfo() const { return *quarter_fftInfo; }
296 };
297 
//! Identifies which plaintext-algebra representation is in use:
//! arithmetic built on NTL::GF2, on NTL::zz_p, or on std::complex<double>
//! (the approximate-numbers scheme).
enum PA_tag
{
  PA_GF2_tag = 0,
  PA_zz_p_tag = 1,
  PA_cx_tag = 2
};
304 
305 /**
306 @class: PAlgebraMod
307 @brief The structure of Z[X]/(Phi_m(X), p)
308 
309 An object of type PAlgebraMod stores information about a PAlgebra object
310 zMStar, and an integer r. It also provides support for encoding and decoding
311 plaintext slots.
312 
The PAlgebra object zMStar defines (Z/mZ)^* /(p), and the PAlgebraMod object
stores various tables related to the polynomial ring Z/(p^r)[X].  To do this
315 most efficiently, if p == 2 and r == 1, then these polynomials are represented
316 as GF2X's, and otherwise as zz_pX's. Thus, the types of these objects are not
317 determined until run time. As such, we need to use a class hierarchy, as
318 follows.
319 
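\li PAlgebraModBase is an abstract base class (its methods, other than the
  destructor, are pure virtual).

\li PAlgebraModDerived<type> is a derived template class, where
  type is either PA_GF2 or PA_zz_p (a separate specialization for PA_cx
  supports the approximate-numbers scheme).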
324 
\li The class PAlgebraMod is a simple wrapper around a smart pointer to a
  PAlgebraModBase object: copying a PAlgebraMod object results in a "deep copy"
  of the underlying object of the derived class. It provides direct access to
  the virtual methods of PAlgebraModBase, along with a "downcast" operator to
  get a reference to the object as a derived type, and also == and != operators.
330 **/
331 
332 //! \cond FALSE (make doxygen ignore these classes)
//! Stand-in for an NTL "backup" object where there is no modulus state to
//! save (used as RBak by the GF2X and complex representations).
//! Both operations do nothing.
class DummyBak
{
public:
  //! No-op.
  void save() {}

  //! No-op.
  void restore() const {}
};
341 
//! Stand-in for an NTL modulus context where no modulus needs installing
//! (used as RContext by the GF2X and complex representations).
//! Every operation is a no-op.
class DummyContext
{
public:
  DummyContext() {}

  //! Accepts and ignores a modulus, mirroring the NTL context constructors.
  DummyContext(long) {}

  void save() {}
  void restore() const {}
};
352 
//! Empty placeholder used as the RXModulus type in the CKKS (PA_cx) case.
class DummyModulus
{
};
356 
357 // some stuff to help with template code
//! Primary template: carries no modulus support for an arbitrary ring type.
//! Specializations below provide a static init(long p).
template <typename Ring>
struct GenericModulus
{
};
361 
362 template <>
363 struct GenericModulus<NTL::zz_p>
364 {
365   static void init(long p) { NTL::zz_p::init(p); }
366 };
367 
368 template <>
369 struct GenericModulus<NTL::GF2>
370 {
371   static void init(long p)
372   {
373     assertEq<InvalidArgument>(p, 2l, "Cannot init NTL::GF2 with p not 2");
374   }
375 };
376 
377 class PA_GF2
378 {
379   // typedefs for algebraic structures built up from GF2
380 
381 public:
382   static const PA_tag tag = PA_GF2_tag;
383   typedef NTL::GF2 R;
384   typedef NTL::GF2X RX;
385   typedef NTL::vec_GF2X vec_RX;
386   typedef NTL::GF2XModulus RXModulus;
387   typedef DummyBak RBak;
388   typedef DummyContext RContext;
389   typedef NTL::GF2E RE;
390   typedef NTL::vec_GF2E vec_RE;
391   typedef NTL::mat_GF2E mat_RE;
392   typedef NTL::GF2EX REX;
393   typedef NTL::GF2EBak REBak;
394   typedef NTL::vec_GF2EX vec_REX;
395   typedef NTL::GF2EContext REContext;
396   typedef NTL::mat_GF2 mat_R;
397   typedef NTL::vec_GF2 vec_R;
398 };
399 
400 class PA_zz_p
401 {
402   // typedefs for algebraic structures built up from zz_p
403 
404 public:
405   static const PA_tag tag = PA_zz_p_tag;
406   typedef NTL::zz_p R;
407   typedef NTL::zz_pX RX;
408   typedef NTL::vec_zz_pX vec_RX;
409   typedef NTL::zz_pXModulus RXModulus;
410   typedef NTL::zz_pBak RBak;
411   typedef NTL::zz_pContext RContext;
412   typedef NTL::zz_pE RE;
413   typedef NTL::vec_zz_pE vec_RE;
414   typedef NTL::mat_zz_pE mat_RE;
415   typedef NTL::zz_pEX REX;
416   typedef NTL::zz_pEBak REBak;
417   typedef NTL::vec_zz_pEX vec_REX;
418   typedef NTL::zz_pEContext REContext;
419   typedef NTL::mat_zz_p mat_R;
420   typedef NTL::vec_zz_p vec_R;
421 };
422 
423 class PA_cx
424 {
425   // typedefs for algebraic structures built up from complex<double>
426 
427 public:
428   static const PA_tag tag = PA_cx_tag;
429   typedef double R;
430   typedef std::complex<double> RX;
431   typedef NTL::Vec<RX> vec_RX;
432   typedef DummyModulus RXModulus;
433   typedef DummyBak RBak;
434   typedef DummyContext RContext;
435 
436   // the other typedef's should not ever be used...they
437   // are all defined as void, so that PA_INJECT still works
438   typedef void RE;
439   typedef void vec_RE;
440   typedef void mat_RE;
441   typedef void REX;
442   typedef void REBak;
443   typedef void vec_REX;
444   typedef void REContext;
445   typedef void mat_R;
446   typedef void vec_R;
447 };
448 
449 //! \endcond
450 
451 //! Virtual base class for PAlgebraMod
452 class PAlgebraModBase
453 {
454 
455 public:
456   virtual ~PAlgebraModBase() {}
457   virtual PAlgebraModBase* clone() const = 0;
458 
459   //! Returns the type tag: PA_GF2_tag or PA_zz_p_tag
460   virtual PA_tag getTag() const = 0;
461 
462   //! Returns reference to underlying PAlgebra object
463   virtual const PAlgebra& getZMStar() const = 0;
464 
465   //! Returns reference to the factorization of Phi_m(X) mod p^r, but as ZZX's
466   virtual const std::vector<NTL::ZZX>& getFactorsOverZZ() const = 0;
467 
468   //! The value r
469   virtual long getR() const = 0;
470 
471   //! The value p^r
472   virtual long getPPowR() const = 0;
473 
474   //! Restores the NTL context for p^r
475   virtual void restoreContext() const = 0;
476 
477   virtual zzX getMask_zzX(long i, long j) const = 0;
478 };
479 
#ifndef DOXYGEN_IGNORE
// PA_INJECT(typ) brings the tag constant and every type alias of a PA_*
// type bundle (PA_GF2, PA_zz_p or PA_cx) into the enclosing class, so that
// template code can say RX, REX, ... instead of typename typ::RX, ...
#define PA_INJECT(typ)                                                         \
  static const PA_tag tag = typ::tag;                                          \
  using R = typename typ::R;                                                   \
  using RX = typename typ::RX;                                                 \
  using vec_RX = typename typ::vec_RX;                                         \
  using RXModulus = typename typ::RXModulus;                                   \
  using RBak = typename typ::RBak;                                             \
  using RContext = typename typ::RContext;                                     \
  using RE = typename typ::RE;                                                 \
  using vec_RE = typename typ::vec_RE;                                         \
  using mat_RE = typename typ::mat_RE;                                         \
  using REX = typename typ::REX;                                               \
  using REBak = typename typ::REBak;                                           \
  using vec_REX = typename typ::vec_REX;                                       \
  using REContext = typename typ::REContext;                                   \
  using mat_R = typename typ::mat_R;                                           \
  using vec_R = typename typ::vec_R;

#endif
500 
// Forward declaration: MappingData below befriends PAlgebraModDerived<type>
// before the template is defined.
template <typename type>
class PAlgebraModDerived;
504 
505 //! Auxiliary structure to support encoding/decoding slots.
506 template <typename type>
507 class MappingData
508 {
509 
510 public:
511   PA_INJECT(type)
512 
513   friend class PAlgebraModDerived<type>;
514 
515 private:
516   RX G;      // the polynomial defining the field extension
517   long degG; // the degree of the polynomial
518 
519   REContext contextForG;
520 
521   /* the remaining fields are visible only to PAlgebraModDerived */
522 
523   std::vector<RX> maps;
524   std::vector<mat_R> matrix_maps;
525   std::vector<REX> rmaps;
526 
527 public:
528   const RX& getG() const { return G; }
529   long getDegG() const { return degG; }
530   void restoreContextForG() const { contextForG.restore(); }
531 
532   // copy and assignment
533 };
534 
535 //! \cond FALSE (make doxygen ignore these classes)
//! Binary tree node: a payload of type T plus shared pointers to its left
//! and right children (an empty pointer means "no child").
template <typename T>
class TNode
{
public:
  std::shared_ptr<TNode<T>> left, right;
  T data;

  //! The child pointers are sink parameters: they are taken by value and
  //! moved into place, avoiding an extra atomic reference-count
  //! increment/decrement per child.
  TNode(std::shared_ptr<TNode<T>> _left,
        std::shared_ptr<TNode<T>> _right,
        const T& _data) :
      left(std::move(_left)), right(std::move(_right)), data(_data)
  {}
};
549 
550 template <typename T>
551 std::shared_ptr<TNode<T>> buildTNode(std::shared_ptr<TNode<T>> left,
552                                      std::shared_ptr<TNode<T>> right,
553                                      const T& data)
554 {
555   return std::shared_ptr<TNode<T>>(new TNode<T>(left, right, data));
556 }
557 
558 template <typename T>
559 std::shared_ptr<TNode<T>> nullTNode()
560 {
561   return std::shared_ptr<TNode<T>>();
562 }
563 //! \endcond
564 
565 //! A concrete instantiation of the virtual class
566 template <typename type>
567 class PAlgebraModDerived : public PAlgebraModBase
568 {
569 public:
570   PA_INJECT(type)
571 
572 private:
573   const PAlgebra& zMStar;
574   long r;
575   long pPowR;
576   RContext pPowRContext;
577 
578   RXModulus PhimXMod;
579 
580   vec_RX factors;
581   std::vector<NTL::ZZX> factorsOverZZ;
582   vec_RX crtCoeffs;
583   std::vector<std::vector<RX>> maskTable;
584   std::vector<RX> crtTable;
585   std::shared_ptr<TNode<RX>> crtTree;
586 
587   void genMaskTable();
588   void genCrtTable();
589 
590 public:
591   PAlgebraModDerived& operator=(const PAlgebraModDerived&) = delete;
592 
593   PAlgebraModDerived(const PAlgebra& zMStar, long r);
594 
595   PAlgebraModDerived(const PAlgebraModDerived& other) // copy constructor
596       :
597       zMStar(other.zMStar),
598       r(other.r),
599       pPowR(other.pPowR),
600       pPowRContext(other.pPowRContext)
601   {
602     RBak bak;
603     bak.save();
604     restoreContext();
605     PhimXMod = other.PhimXMod;
606     factors = other.factors;
607     maskTable = other.maskTable;
608     crtTable = other.crtTable;
609     crtTree = other.crtTree;
610   }
611 
612   //! Returns a pointer to a "clone"
613   virtual PAlgebraModBase* clone() const override
614   {
615     return new PAlgebraModDerived(*this);
616   }
617 
618   //! Returns the type tag: PA_GF2_tag or PA_zz_p_tag
619   virtual PA_tag getTag() const override { return tag; }
620 
621   //! Returns reference to underlying PAlgebra object
622   virtual const PAlgebra& getZMStar() const override { return zMStar; }
623 
624   //! Returns reference to the factorization of Phi_m(X) mod p^r, but as ZZX's
625   virtual const std::vector<NTL::ZZX>& getFactorsOverZZ() const override
626   {
627     return factorsOverZZ;
628   }
629 
630   //! The value r
631   virtual long getR() const override { return r; }
632 
633   //! The value p^r
634   virtual long getPPowR() const override { return pPowR; }
635 
636   //! Restores the NTL context for p^r
637   virtual void restoreContext() const override { pPowRContext.restore(); }
638 
639   /* In all of the following functions, it is expected that the caller
640      has already restored the relevant modulus (p^r), which
641      can be done by invoking the method restoreContext()
642    */
643 
644   //! Returns reference to an RXModulus representing Phi_m(X) (mod p^r)
645   const RXModulus& getPhimXMod() const { return PhimXMod; }
646 
647   //! Returns reference to the factors of Phim_m(X) modulo p^r
648   const vec_RX& getFactors() const { return factors; }
649 
650   //! @brief Returns the CRT coefficients:
651   //! element i contains (prod_{j!=i} F_j)^{-1} mod F_i,
652   //! where F_0 F_1 ... is the factorization of Phi_m(X) mod p^r
653   const vec_RX& getCrtCoeffs() const { return crtCoeffs; }
654 
655   /**
656      @brief Returns ref to maskTable, which is used to implement rotations
657      (in the EncryptedArray module).
658 
659      `maskTable[i][j]` is a polynomial representation of a mask that is 1 in
660      all slots whose i'th coordinate is at least j, and 0 elsewhere. We have:
661      \verbatim
662        maskTable.size() == zMStar.numOfGens()     // # of generators
663        for i = 0..maskTable.size()-1:
664          maskTable[i].size() == zMStar.OrderOf(i)+1 // order of generator i
665      \endverbatim
666   **/
667   // logically, but not really, const
668   const std::vector<std::vector<RX>>& getMaskTable() const { return maskTable; }
669 
670   zzX getMask_zzX(long i, long j) const override
671   {
672     RBak bak;
673     bak.save();
674     restoreContext();
675     return balanced_zzX(maskTable.at(i).at(j));
676   }
677 
678   ///@{
679   //! @name Embedding in the plaintext slots and decoding back
680   //! In all the functions below, G must be irreducible mod p,
681   //! and the order of G must divide the order of p modulo m
682   //! (as returned by zMStar.getOrdP()).
683   //! In addition, when r > 1, G must be the monomial X (RX(1, 1))
684 
685   //! @brief Returns a std::vector crt[] such that crt[i] = H mod Ft (with t =
686   //! T[i])
687   void CRT_decompose(std::vector<RX>& crt, const RX& H) const;
688 
689   //! @brief Returns H in R[X]/Phi_m(X) s.t. for every i<nSlots and t=T[i],
690   //! we have H == crt[i] (mod Ft)
691   void CRT_reconstruct(RX& H, std::vector<RX>& crt) const;
692 
693   //! @brief Compute the maps for all the slots.
694   //! In the current implementation, we if r > 1, then
695   //! we must have either deg(G) == 1 or G == factors[0]
696   void mapToSlots(MappingData<type>& mappingData, const RX& G) const;
697 
698   //! @brief Returns H in R[X]/Phi_m(X) s.t. for every t in T, the element
699   //! Ht = (H mod Ft) in R[X]/Ft(X) represents the same element as alpha
700   //! in R[X]/G(X).
701   //!
702   //! Must have deg(alpha)<deg(G). The mappingData argument should contain
703   //! the output of mapToSlots(G).
704   void embedInAllSlots(RX& H,
705                        const RX& alpha,
706                        const MappingData<type>& mappingData) const;
707 
708   //! @brief Returns H in R[X]/Phi_m(X) s.t. for every t in T, the element
709   //! Ht = (H mod Ft) in R[X]/Ft(X) represents the same element as alphas[i]
710   //! in R[X]/G(X).
711   //!
712   //! Must have deg(alpha[i])<deg(G). The mappingData argument should contain
713   //! the output of mapToSlots(G).
714   void embedInSlots(RX& H,
715                     const std::vector<RX>& alphas,
716                     const MappingData<type>& mappingData) const;
717 
718   //! @brief Return an array such that alphas[i] in R[X]/G(X) represent the
719   //! same element as rt = (H mod Ft) in R[X]/Ft(X) where t=T[i].
720   //!
721   //! The mappingData argument should contain the output of mapToSlots(G).
722   void decodePlaintext(std::vector<RX>& alphas,
723                        const RX& ptxt,
724                        const MappingData<type>& mappingData) const;
725 
726   //! @brief Returns a coefficient std::vector C for the linearized polynomial
727   //! representing M.
728   //!
729   //! For h in Z/(p^r)[X] of degree < d,
730   //! \f[ M(h(X) mod G) = sum_{i=0}^{d-1} (C[j] mod G) * (h(X^{p^j}) mod G).\f]
731   //! G is assumed to be defined in mappingData, with d = deg(G).
732   //! L describes a linear map M by describing its action on the standard
733   //! power basis: M(x^j mod G) = (L[j] mod G), for j = 0..d-1.
734   void buildLinPolyCoeffs(std::vector<RX>& C,
735                           const std::vector<RX>& L,
736                           const MappingData<type>& mappingData) const;
737   ///@}
738 private:
739   /* internal functions, not for public consumption */
740 
741   static void SetModulus(long p)
742   {
743     RContext context(p);
744     context.restore();
745   }
746 
747   //! w in R[X]/F1(X) represents the same as X in R[X]/G(X)
748   void mapToF1(RX& w, const RX& G) const { mapToFt(w, G, 1); }
749 
750   //! Same as above, but embeds relative to Ft rather than F1. The
751   //! optional rF1 contains the output of mapToF1, to speed this operation.
752   void mapToFt(RX& w, const RX& G, long t, const RX* rF1 = nullptr) const;
753 
754   void buildTree(std::shared_ptr<TNode<RX>>& res,
755                  long offset,
756                  long extent) const;
757 
758   void evalTree(RX& res,
759                 std::shared_ptr<TNode<RX>> tree,
760                 const std::vector<RX>& crt1,
761                 long offset,
762                 long extent) const;
763 };
764 
765 //! A different derived class to be used for the approximate-numbers scheme
766 //! This is mostly a dummy class, but needed since the context always has a
767 //! PAlgebraMod data member.
768 template <>
769 class PAlgebraModDerived<PA_cx> : public PAlgebraModBase
770 {
771   const PAlgebra& zMStar;
772   long r; // counts bits of precision
773 
774 public:
775   PAlgebraModDerived(const PAlgebra& palg, long _r) : zMStar(palg), r(_r)
776   {
777     assertInRange<InvalidArgument>(r,
778                                    1l,
779                                    (long)NTL_SP_NBITS,
780                                    "Invalid bit precision r");
781   }
782 
783   PAlgebraModBase* clone() const override
784   {
785     return new PAlgebraModDerived(*this);
786   }
787   PA_tag getTag() const override { return PA_cx_tag; }
788 
789   const PAlgebra& getZMStar() const override { return zMStar; }
790   long getR() const override { return r; }
791   long getPPowR() const override { return 1L << r; }
792   void restoreContext() const override {}
793 
794   // These function make no sense for PAlgebraModCx
795   const std::vector<NTL::ZZX>& getFactorsOverZZ() const override
796   {
797     throw LogicError("PAlgebraModCx::getFactorsOverZZ undefined");
798   }
799   zzX getMask_zzX(UNUSED long i, UNUSED long j) const override
800   {
801     throw LogicError("PAlgebraModCx::getMask_zzX undefined");
802   }
803 };
804 
805 typedef PAlgebraModDerived<PA_cx> PAlgebraModCx;
806 
807 //! Builds a table, of type PA_GF2 if p == 2 and r == 1, and PA_zz_p otherwise
808 PAlgebraModBase* buildPAlgebraMod(const PAlgebra& zMStar, long r);
809 
810 // A simple wrapper for a pointer to an object of type PAlgebraModBase.
811 //
812 // Direct access to the virtual methods of PAlgebraModBase is provided,
813 // along with a "downcast" operator to get a reference to the object
814 // as a derived type, and == and != operators.
815 class PAlgebraMod
816 {
817 
818 private:
819   ClonedPtr<PAlgebraModBase> rep;
820 
821 public:
822   // copy constructor: default
823   // assignment: deleted
824   // destructor: default
825   // NOTE: the use of ClonedPtr ensures that the default copy constructor
826   // and destructor will work correctly.
827 
828   PAlgebraMod& operator=(const PAlgebraMod&) = delete;
829 
830   explicit PAlgebraMod(const PAlgebra& zMStar, long r) :
831       rep(buildPAlgebraMod(zMStar, r))
832   {}
833   // constructor
834 
835   //! Downcast operator
836   //! example: const PAlgebraModDerived<PA_GF2>& rep =
837   //! alMod.getDerived(PA_GF2());
838   template <typename type>
839   const PAlgebraModDerived<type>& getDerived(type) const
840   {
841     return dynamic_cast<const PAlgebraModDerived<type>&>(*rep);
842   }
843   const PAlgebraModCx& getCx() const
844   {
845     return dynamic_cast<const PAlgebraModCx&>(*rep);
846   }
847 
848   bool operator==(const PAlgebraMod& other) const
849   {
850     return getZMStar() == other.getZMStar() && getR() == other.getR();
851   }
852   // comparison
853 
854   bool operator!=(const PAlgebraMod& other) const { return !(*this == other); }
855   // comparison
856 
857   /* direct access to the PAlgebraModBase methods */
858 
859   //! Returns the type tag: PA_GF2_tag or PA_zz_p_tag
860   PA_tag getTag() const { return rep->getTag(); }
861   //! Returns reference to underlying PAlgebra object
862   const PAlgebra& getZMStar() const { return rep->getZMStar(); }
863   //! Returns reference to the factorization of Phi_m(X) mod p^r, but as ZZX's
864   const std::vector<NTL::ZZX>& getFactorsOverZZ() const
865   {
866     return rep->getFactorsOverZZ();
867   }
868   //! The value r
869   long getR() const { return rep->getR(); }
870   //! The value p^r
871   long getPPowR() const { return rep->getPPowR(); }
872   //! Restores the NTL context for p^r
873   void restoreContext() const { rep->restoreContext(); }
874 
875   zzX getMask_zzX(long i, long j) const { return rep->getMask_zzX(i, j); }
876 };
877 
878 //! returns true if the palg parameters match the rest, false otherwise
879 bool comparePAlgebra(const PAlgebra& palg,
880                      unsigned long m,
881                      unsigned long p,
882                      unsigned long r,
883                      const std::vector<long>& gens,
884                      const std::vector<long>& ords);
885 
// Internal helper, not part of the public API. Presumably computes the
// polyNormBnd value stored in PAlgebra for the given m -- confirm against
// the definition.
double calcPolyNormBnd(long m);
888 
889 } // namespace helib
890 
891 #endif // #ifndef HELIB_PALGEBRA_H
892