1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #ifndef HELIB_MATMUL_H
13 #define HELIB_MATMUL_H
14 
#include <helib/EncryptedArray.h>
#include <functional>
#include <utility>
17 
18 namespace helib {
19 
20 class MatMulFullExec;
21 
// Abstract base class for representing a linear transformation on a full
// vector.
24 class MatMulFull
25 {
26 public:
~MatMulFull()27   virtual ~MatMulFull() {}
28   virtual const EncryptedArray& getEA() const = 0;
29   typedef MatMulFullExec ExecType;
30 };
31 
32 // Concrete derived class that defines the matrix entries.
33 template <typename type>
34 class MatMulFull_derived : public MatMulFull
35 {
36 public:
37   PA_INJECT(type)
38 
39   // Get (i, j) entry of matrix.
40   // Should return true when the entry is a zero.
41   virtual bool get(RX& out, long i, long j) const = 0;
42 };
43 
44 //====================================
45 
46 class BlockMatMulFullExec;
47 
// Abstract base class for representing a block linear transformation on a full
// vector.
50 class BlockMatMulFull
51 {
52 public:
~BlockMatMulFull()53   virtual ~BlockMatMulFull() {}
54   virtual const EncryptedArray& getEA() const = 0;
55   typedef BlockMatMulFullExec ExecType;
56 };
57 
58 // Concrete derived class that defines the matrix entries.
59 template <typename type>
60 class BlockMatMulFull_derived : public BlockMatMulFull
61 {
62 public:
63   PA_INJECT(type)
64 
65   // Get (i, j) entry of matrix.
66   // Each entry is a d x d matrix over the base ring.
67   // Should return true when the entry is a zero.
68   virtual bool get(mat_R& out, long i, long j) const = 0;
69 };
70 
71 //====================================
72 
73 class MatMul1DExec;
74 
75 // Abstract base class for representing a 1D linear transformation.
76 class MatMul1D
77 {
78 public:
~MatMul1D()79   virtual ~MatMul1D() {}
80   virtual const EncryptedArray& getEA() const = 0;
81   virtual long getDim() const = 0;
82   typedef MatMul1DExec ExecType;
83 };
84 
85 // An intermediate class that is mainly intended for internal use.
86 template <typename type>
87 class MatMul1D_partial : public MatMul1D
88 {
89 public:
90   PA_INJECT(type)
91 
92   // Get the i'th diagonal, encoded as a single constant.
93   // MatMul1D_derived (below) supplies a default implementation,
94   // which can be overridden in special circumstances.
95   virtual void processDiagonal(RX& poly,
96                                long i,
97                                const EncryptedArrayDerived<type>& ea) const = 0;
98 };
99 
100 // Concrete derived class that defines the matrix entries.
101 template <typename type>
102 class MatMul1D_derived : public MatMul1D_partial<type>
103 {
104 public:
105   PA_INJECT(type)
106 
107   // Should return true if their are multiple (different) transforms
108   // among the various components.
109   virtual bool multipleTransforms() const = 0;
110 
111   // Get coordinate (i, j) of the kth component.
112   // Should return true when the entry is a zero.
113   virtual bool get(RX& out, long i, long j, long k) const = 0;
114 
115   void processDiagonal(RX& poly,
116                        long i,
117                        const EncryptedArrayDerived<type>& ea) const override;
118 };
119 
120 template <>
121 class MatMul1D_derived<PA_cx> : public MatMul1D
122 {
123 public:
124   // Get coordinate (i, j)
125   virtual std::complex<double> get(long i, long j) const = 0;
126 
127   void processDiagonal(std::vector<std::complex<double>>& diag,
128                        long i,
129                        const EncryptedArrayCx& ea) const;
130 
131   // final: ensures that dim==0 is the only possible dimension
getDim()132   virtual long getDim() const final { return 0; }
133 };
134 
135 typedef MatMul1D_derived<PA_cx> MatMul1D_CKKS;
136 
137 // more convenient user interfaces
138 // VJS-FIXME: document some of this stuff
139 
140 class MatMul_CKKS : public MatMul1D_CKKS
141 {
142 public:
143   typedef std::function<double(long, long)> get_fun_type;
144 
145 private:
146   const EncryptedArray& ea;
147 
148   get_fun_type get_fun;
149   // get_fun(i,j) returns matrix entry (i,j)
150   // see get_fun_type definitions below
151 
152 public:
MatMul_CKKS(const EncryptedArray & _ea,get_fun_type _get_fun)153   MatMul_CKKS(const EncryptedArray& _ea, get_fun_type _get_fun) :
154       ea(_ea), get_fun(_get_fun)
155   {}
156 
MatMul_CKKS(const Context & context,get_fun_type _get_fun)157   MatMul_CKKS(const Context& context, get_fun_type _get_fun) :
158       ea(context.getDefaultEA()), get_fun(_get_fun)
159   {}
160 
getEA()161   virtual const EncryptedArray& getEA() const override { return ea; }
162 
get(long i,long j)163   virtual std::complex<double> get(long i, long j) const override
164   {
165     return get_fun(i, j);
166   }
167 };
168 
169 class MatMul_CKKS_Complex : public MatMul1D_CKKS
170 {
171 public:
172   typedef std::function<std::complex<double>(long, long)> get_fun_type;
173 
174 private:
175   const EncryptedArray& ea;
176 
177   get_fun_type get_fun;
178   // get_fun(i,j) returns matrix entry (i,j)
179   // see get_fun_type definitions below
180 
181 public:
MatMul_CKKS_Complex(const EncryptedArray & _ea,get_fun_type _get_fun)182   MatMul_CKKS_Complex(const EncryptedArray& _ea, get_fun_type _get_fun) :
183       ea(_ea), get_fun(_get_fun)
184   {}
185 
MatMul_CKKS_Complex(const Context & context,get_fun_type _get_fun)186   MatMul_CKKS_Complex(const Context& context, get_fun_type _get_fun) :
187       ea(context.getDefaultEA()), get_fun(_get_fun)
188   {}
189 
getEA()190   virtual const EncryptedArray& getEA() const override { return ea; }
191 
get(long i,long j)192   virtual std::complex<double> get(long i, long j) const override
193   {
194     return get_fun(i, j);
195   }
196 };
197 
198 //====================================
199 
200 class BlockMatMul1DExec;
201 
202 // Abstract base class for representing a block 1D linear transformation.
203 class BlockMatMul1D
204 {
205 public:
~BlockMatMul1D()206   virtual ~BlockMatMul1D() {}
207   virtual const EncryptedArray& getEA() const = 0;
208   virtual long getDim() const = 0;
209   typedef BlockMatMul1DExec ExecType;
210 };
211 
212 // An intermediate class that is mainly intended for internal use.
213 template <typename type>
214 class BlockMatMul1D_partial : public BlockMatMul1D
215 {
216 public:
217   PA_INJECT(type)
218 
219   // Get the i'th diagonal, encoded as a std::vector of d constants,
220   // where d is the order of p.
221   // BlockMatMul1D_derived (below) supplies a default implementation,
222   // which can be overridden in special circumstances.
223   virtual bool processDiagonal(std::vector<RX>& poly,
224                                long i,
225                                const EncryptedArrayDerived<type>& ea) const = 0;
226 };
227 
228 // Concrete derived class that defines the matrix entries.
229 template <typename type>
230 class BlockMatMul1D_derived : public BlockMatMul1D_partial<type>
231 {
232 public:
233   PA_INJECT(type)
234 
235   // Should return true if their are multiple (different) transforms
236   // among the various components.
237   virtual bool multipleTransforms() const = 0;
238 
239   // Get coordinate (i, j) of the kth component.
240   // Each entry is a d x d matrix over the base ring.
241   // Should return true when the entry is a zero.
242   virtual bool get(mat_R& out, long i, long j, long k) const = 0;
243 
244   bool processDiagonal(std::vector<RX>& poly,
245                        long i,
246                        const EncryptedArrayDerived<type>& ea) const override;
247 };
248 
249 //====================================
250 
251 struct ConstMultiplier;
252 // Defined in matmul.cpp.
253 // Holds a constant by which a ciphertext can be multiplied.
254 // Internally, it is represented as either zzX or a DoubleCRT.
255 // The former occupies less space, but the latter makes for
256 // much faster multiplication.
257 
258 struct ConstMultiplierCache
259 {
260   std::vector<std::shared_ptr<ConstMultiplier>> multiplier;
261 
262   // Upgrade zzX constants to DoubleCRT constants.
263   void upgrade(const Context& context);
264 };
265 
266 //====================================
267 
// Abstract base class for multiplying an encrypted vector by a plaintext
// matrix.
270 class MatMulExecBase
271 {
272 public:
~MatMulExecBase()273   virtual ~MatMulExecBase() {}
274 
275   virtual const EncryptedArray& getEA() const = 0;
276 
277   // Upgrade zzX constants to DoubleCRT constants.
278   virtual void upgrade() = 0;
279 
280   // If ctxt encrypts a row std::vector v, then this replaces ctxt
281   // by an encryption of the row std::vector v*mat, where mat is
282   // a matrix provided to the constructor of one of the
283   // concrete subclasses MatMul1DExec, BlockMatMul1DExec,
284   // MatMulFullExec, BlockMatMulFullExec, defined below.
285   virtual void mul(Ctxt& ctxt) const = 0;
286 };
287 
288 //====================================
289 
// Class used to multiply an encrypted row vector by a 1D linear
// transformation.
292 class MatMul1DExec : public MatMulExecBase
293 {
294 public:
295   const EncryptedArray& ea;
296 
297   long dim;
298   long D;
299   bool native;
300   bool minimal;
301   long g;
302 
303   ConstMultiplierCache cache;
304   ConstMultiplierCache cache1; // only for non-native dimension
305 
306   // The constructor encodes all the constants for a given
307   // matrix in zzX format.
308   // The mat argument defines the entries of the matrix.
309   // Use the upgrade method (below) to convert to DoubleCRT format.
310   // If the minimal flag is set to true, a strategy that relies
311   // on a minimal number of key switching matrices will be used;
312   // this is intended for use in conjunction with the
313   // addMinimal{1D,Frb}Matrices routines declared in helib.h.
314   // If the minimal flag is false, it is best to use the
315   // addSome{1D,Frb}Matrices routines declared in helib.h.
316   explicit MatMul1DExec(const MatMul1D& mat, bool minimal = false);
317 
318   // VJS-FIXME: it seems that the minimal flag is currently
319   // redundant, as the decision is essentially based on
320   // ctxt.getPubKey().getKSStrategy(dim0). Need to look into this
321   // and re-assess.
322 
323   // Replaces an encryption of row std::vector v by encryption of v*mat
324   void mul(Ctxt& ctxt) const override;
325 
326   // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()327   void upgrade() override
328   {
329     cache.upgrade(ea.getContext());
330     cache1.upgrade(ea.getContext());
331   }
332 
getEA()333   const EncryptedArray& getEA() const override { return ea; }
334 };
335 
336 // A more convenient and naturally-named interface for CKKS
337 // VJS-FIXME: document some of this stuff
338 
339 class EncodedMatMul_CKKS : public MatMul1DExec
340 {
341 public:
EncodedMatMul_CKKS(const MatMul1D_CKKS & mat)342   EncodedMatMul_CKKS(const MatMul1D_CKKS& mat) : MatMul1DExec(mat) {}
343 };
344 
345 //====================================
346 
// Class used to multiply an encrypted row vector by a block 1D linear
// transformation.
349 class BlockMatMul1DExec : public MatMulExecBase
350 {
351 public:
352   const EncryptedArray& ea;
353 
354   long dim;
355   long D;
356   long d;
357   bool native;
358   long strategy;
359 
360   ConstMultiplierCache cache;
361   ConstMultiplierCache cache1; // only for non-native dimension
362 
363   // The constructor encodes all the constants for a given
364   // matrix in zzX format.
365   // The mat argument defines the entries of the matrix.
366   // Use the upgrade method (below) to convert to DoubleCRT format.
367   // If the minimal flag is set to true, a strategy that relies
368   // on a minimal number of key switching matrices will be used;
369   // this is intended for use in conjunction with the
370   // addMinimal{1D,Frb}Matrices routines declared in helib.h.
371   // If the minimal flag is false, it is best to use the
372   // addSome{1D,Frb}Matrices routines declared in helib.h.
373   explicit BlockMatMul1DExec(const BlockMatMul1D& mat, bool minimal = false);
374 
375   // Replaces an encryption of row std::vector v by encryption of v*mat
376   void mul(Ctxt& ctxt) const override;
377 
378   // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()379   void upgrade() override
380   {
381     cache.upgrade(ea.getContext());
382     cache1.upgrade(ea.getContext());
383   }
384 
getEA()385   const EncryptedArray& getEA() const override { return ea; }
386 };
387 
388 //====================================
389 
// Class used to multiply an encrypted row vector by a full linear
// transformation.
392 class MatMulFullExec : public MatMulExecBase
393 {
394 public:
395   const EncryptedArray& ea;
396   bool minimal;
397   std::vector<long> dims;
398   std::vector<MatMul1DExec> transforms;
399 
400   // The constructor encodes all the constants for a given
401   // matrix in zzX format.
402   // The mat argument defines the entries of the matrix.
403   // Use the upgrade method (below) to convert to DoubleCRT format.
404   // If the minimal flag is set to true, a strategy that relies
405   // on a minimal number of key switching matrices will be used;
406   // this is intended for use in conjunction with the
407   // addMinimal{1D,Frb}Matrices routines declared in helib.h.
408   // If the minimal flag is false, it is best to use the
409   // addSome{1D,Frb}Matrices routines declared in helib.h.
410   explicit MatMulFullExec(const MatMulFull& mat, bool minimal = false);
411 
412   // Replaces an encryption of row std::vector v by encryption of v*mat
413   void mul(Ctxt& ctxt) const override;
414 
415   // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()416   void upgrade() override
417   {
418     for (auto& t : transforms)
419       t.upgrade();
420   }
421 
getEA()422   const EncryptedArray& getEA() const override { return ea; }
423 
424   // This really should be private.
425   long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
426 };
427 
428 //====================================
429 
// Class used to multiply an encrypted row vector by a full block linear
// transformation.
432 class BlockMatMulFullExec : public MatMulExecBase
433 {
434 public:
435   const EncryptedArray& ea;
436   bool minimal;
437   std::vector<long> dims;
438   std::vector<BlockMatMul1DExec> transforms;
439 
440   // The constructor encodes all the constants for a given
441   // matrix in zzX format.
442   // The mat argument defines the entries of the matrix.
443   // Use the upgrade method (below) to convert to DoubleCRT format.
444   // If the minimal flag is set to true, a strategy that relies
445   // on a minimal number of key switching matrices will be used;
446   // this is intended for use in conjunction with the
447   // addMinimal{1D,Frb}Matrices routines declared in helib.h.
448   // If the minimal flag is false, it is best to use the
449   // addSome{1D,Frb}Matrices routines declared in helib.h.
450   explicit BlockMatMulFullExec(const BlockMatMulFull& mat,
451                                bool minimal = false);
452 
453   // Replaces an encryption of row std::vector v by encryption of v*mat
454   void mul(Ctxt& ctxt) const override;
455 
456   // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()457   void upgrade() override
458   {
459     for (auto& t : transforms)
460       t.upgrade();
461   }
462 
getEA()463   const EncryptedArray& getEA() const override { return ea; }
464 
465   // This really should be private.
466   long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
467 };
468 
469 //===================================
470 
471 // ctxt = \sum_{i=0}^{d-1} \sigma^i(ctxt),
472 //   where d = order of p mod m, and \sigma is the Frobenius map
473 
474 void traceMap(Ctxt& ctxt);
475 
476 //====================================
477 
478 // These routines apply linear transformation to plaintext arrays.
479 // Mainly for testing purposes.
480 void mul(PlaintextArray& pa, const MatMul1D& mat);
481 void mul(PlaintextArray& pa, const BlockMatMul1D& mat);
482 void mul(PlaintextArray& pa, const MatMulFull& mat);
483 void mul(PlaintextArray& pa, const BlockMatMulFull& mat);
484 
485 // VJS-FIXME: these should be documented
486 
mul(PtxtArray & a,const MatMul1D & mat)487 inline void mul(PtxtArray& a, const MatMul1D& mat)
488 {
489   assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
490   mul(a.pa, mat);
491 }
492 
mul(PtxtArray & a,const BlockMatMul1D & mat)493 inline void mul(PtxtArray& a, const BlockMatMul1D& mat)
494 {
495   assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
496   mul(a.pa, mat);
497 }
498 
mul(PtxtArray & a,const MatMulFull & mat)499 inline void mul(PtxtArray& a, const MatMulFull& mat)
500 {
501   assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
502   mul(a.pa, mat);
503 }
504 
mul(PtxtArray & a,const BlockMatMulFull & mat)505 inline void mul(PtxtArray& a, const BlockMatMulFull& mat)
506 {
507   assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
508   mul(a.pa, mat);
509 }
510 
// More interface conveniences, both for PtxtArray and Ctxt.
512 
513 inline PtxtArray& operator*=(PtxtArray& a, const MatMul1D& mat)
514 {
515   mul(a, mat);
516   return a;
517 }
518 
519 inline PtxtArray& operator*=(PtxtArray& a, const BlockMatMul1D& mat)
520 {
521   mul(a, mat);
522   return a;
523 }
524 
525 inline PtxtArray& operator*=(PtxtArray& a, const MatMulFull& mat)
526 {
527   mul(a, mat);
528   return a;
529 }
530 
531 inline PtxtArray& operator*=(PtxtArray& a, const BlockMatMulFull& mat)
532 {
533   mul(a, mat);
534   return a;
535 }
536 
// For Ctxt objects, these operators do no pre-computation: every call builds
// a fresh executor, which re-encodes the matrix constants.
538 
539 inline Ctxt& operator*=(Ctxt& a, const MatMul1D& mat)
540 {
541   MatMul1DExec mat_exec(mat);
542   mat_exec.mul(a);
543   return a;
544 }
545 
546 inline Ctxt& operator*=(Ctxt& a, const BlockMatMul1D& mat)
547 {
548   BlockMatMul1DExec mat_exec(mat);
549   mat_exec.mul(a);
550   return a;
551 }
552 
553 inline Ctxt& operator*=(Ctxt& a, const MatMulFull& mat)
554 {
555   MatMulFullExec mat_exec(mat);
556   mat_exec.mul(a);
557   return a;
558 }
559 
560 inline Ctxt& operator*=(Ctxt& a, const BlockMatMulFull& mat)
561 {
562   BlockMatMulFullExec mat_exec(mat);
563   mat_exec.mul(a);
564   return a;
565 }
566 
// For Ctxt objects, this operator allows pre-computation: the executor and
// its encoded constants can be built once and reused across calls.
568 
569 inline Ctxt& operator*=(Ctxt& a, const MatMulExecBase& mat)
570 {
571   mat.mul(a);
572   return a;
573 }
574 
575 // These are used mainly for performance evaluation.
576 
577 extern int fhe_test_force_bsgs;
578 // Controls whether or not we use BSGS multiplication.
579 // 1 to force on, -1 to force off, 0 for default behaviour.
580 
581 extern int fhe_test_force_hoist;
// Controls whether or not we use hoisting.
583 // -1 to force off, 0 for default behaviour.
584 
585 } // namespace helib
586 
587 #endif // ifndef HELIB_MATMUL_H
588