1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12 #ifndef HELIB_MATMUL_H
13 #define HELIB_MATMUL_H
14
#include <helib/EncryptedArray.h>
#include <functional>
#include <utility>
17
18 namespace helib {
19
20 class MatMulFullExec;
21
22 // Abstract base class for representing a linear transformation on a full
23 // std::vector.
24 class MatMulFull
25 {
26 public:
~MatMulFull()27 virtual ~MatMulFull() {}
28 virtual const EncryptedArray& getEA() const = 0;
29 typedef MatMulFullExec ExecType;
30 };
31
32 // Concrete derived class that defines the matrix entries.
33 template <typename type>
34 class MatMulFull_derived : public MatMulFull
35 {
36 public:
37 PA_INJECT(type)
38
39 // Get (i, j) entry of matrix.
40 // Should return true when the entry is a zero.
41 virtual bool get(RX& out, long i, long j) const = 0;
42 };
43
44 //====================================
45
46 class BlockMatMulFullExec;
47
48 // Abstract base class for representing a block linear transformation on a full
49 // std::vector.
50 class BlockMatMulFull
51 {
52 public:
~BlockMatMulFull()53 virtual ~BlockMatMulFull() {}
54 virtual const EncryptedArray& getEA() const = 0;
55 typedef BlockMatMulFullExec ExecType;
56 };
57
58 // Concrete derived class that defines the matrix entries.
59 template <typename type>
60 class BlockMatMulFull_derived : public BlockMatMulFull
61 {
62 public:
63 PA_INJECT(type)
64
65 // Get (i, j) entry of matrix.
66 // Each entry is a d x d matrix over the base ring.
67 // Should return true when the entry is a zero.
68 virtual bool get(mat_R& out, long i, long j) const = 0;
69 };
70
71 //====================================
72
73 class MatMul1DExec;
74
75 // Abstract base class for representing a 1D linear transformation.
76 class MatMul1D
77 {
78 public:
~MatMul1D()79 virtual ~MatMul1D() {}
80 virtual const EncryptedArray& getEA() const = 0;
81 virtual long getDim() const = 0;
82 typedef MatMul1DExec ExecType;
83 };
84
85 // An intermediate class that is mainly intended for internal use.
86 template <typename type>
87 class MatMul1D_partial : public MatMul1D
88 {
89 public:
90 PA_INJECT(type)
91
92 // Get the i'th diagonal, encoded as a single constant.
93 // MatMul1D_derived (below) supplies a default implementation,
94 // which can be overridden in special circumstances.
95 virtual void processDiagonal(RX& poly,
96 long i,
97 const EncryptedArrayDerived<type>& ea) const = 0;
98 };
99
100 // Concrete derived class that defines the matrix entries.
101 template <typename type>
102 class MatMul1D_derived : public MatMul1D_partial<type>
103 {
104 public:
105 PA_INJECT(type)
106
107 // Should return true if their are multiple (different) transforms
108 // among the various components.
109 virtual bool multipleTransforms() const = 0;
110
111 // Get coordinate (i, j) of the kth component.
112 // Should return true when the entry is a zero.
113 virtual bool get(RX& out, long i, long j, long k) const = 0;
114
115 void processDiagonal(RX& poly,
116 long i,
117 const EncryptedArrayDerived<type>& ea) const override;
118 };
119
120 template <>
121 class MatMul1D_derived<PA_cx> : public MatMul1D
122 {
123 public:
124 // Get coordinate (i, j)
125 virtual std::complex<double> get(long i, long j) const = 0;
126
127 void processDiagonal(std::vector<std::complex<double>>& diag,
128 long i,
129 const EncryptedArrayCx& ea) const;
130
131 // final: ensures that dim==0 is the only possible dimension
getDim()132 virtual long getDim() const final { return 0; }
133 };
134
135 typedef MatMul1D_derived<PA_cx> MatMul1D_CKKS;
136
137 // more convenient user interfaces
138 // VJS-FIXME: document some of this stuff
139
140 class MatMul_CKKS : public MatMul1D_CKKS
141 {
142 public:
143 typedef std::function<double(long, long)> get_fun_type;
144
145 private:
146 const EncryptedArray& ea;
147
148 get_fun_type get_fun;
149 // get_fun(i,j) returns matrix entry (i,j)
150 // see get_fun_type definitions below
151
152 public:
MatMul_CKKS(const EncryptedArray & _ea,get_fun_type _get_fun)153 MatMul_CKKS(const EncryptedArray& _ea, get_fun_type _get_fun) :
154 ea(_ea), get_fun(_get_fun)
155 {}
156
MatMul_CKKS(const Context & context,get_fun_type _get_fun)157 MatMul_CKKS(const Context& context, get_fun_type _get_fun) :
158 ea(context.getDefaultEA()), get_fun(_get_fun)
159 {}
160
getEA()161 virtual const EncryptedArray& getEA() const override { return ea; }
162
get(long i,long j)163 virtual std::complex<double> get(long i, long j) const override
164 {
165 return get_fun(i, j);
166 }
167 };
168
169 class MatMul_CKKS_Complex : public MatMul1D_CKKS
170 {
171 public:
172 typedef std::function<std::complex<double>(long, long)> get_fun_type;
173
174 private:
175 const EncryptedArray& ea;
176
177 get_fun_type get_fun;
178 // get_fun(i,j) returns matrix entry (i,j)
179 // see get_fun_type definitions below
180
181 public:
MatMul_CKKS_Complex(const EncryptedArray & _ea,get_fun_type _get_fun)182 MatMul_CKKS_Complex(const EncryptedArray& _ea, get_fun_type _get_fun) :
183 ea(_ea), get_fun(_get_fun)
184 {}
185
MatMul_CKKS_Complex(const Context & context,get_fun_type _get_fun)186 MatMul_CKKS_Complex(const Context& context, get_fun_type _get_fun) :
187 ea(context.getDefaultEA()), get_fun(_get_fun)
188 {}
189
getEA()190 virtual const EncryptedArray& getEA() const override { return ea; }
191
get(long i,long j)192 virtual std::complex<double> get(long i, long j) const override
193 {
194 return get_fun(i, j);
195 }
196 };
197
198 //====================================
199
200 class BlockMatMul1DExec;
201
202 // Abstract base class for representing a block 1D linear transformation.
203 class BlockMatMul1D
204 {
205 public:
~BlockMatMul1D()206 virtual ~BlockMatMul1D() {}
207 virtual const EncryptedArray& getEA() const = 0;
208 virtual long getDim() const = 0;
209 typedef BlockMatMul1DExec ExecType;
210 };
211
212 // An intermediate class that is mainly intended for internal use.
213 template <typename type>
214 class BlockMatMul1D_partial : public BlockMatMul1D
215 {
216 public:
217 PA_INJECT(type)
218
219 // Get the i'th diagonal, encoded as a std::vector of d constants,
220 // where d is the order of p.
221 // BlockMatMul1D_derived (below) supplies a default implementation,
222 // which can be overridden in special circumstances.
223 virtual bool processDiagonal(std::vector<RX>& poly,
224 long i,
225 const EncryptedArrayDerived<type>& ea) const = 0;
226 };
227
228 // Concrete derived class that defines the matrix entries.
229 template <typename type>
230 class BlockMatMul1D_derived : public BlockMatMul1D_partial<type>
231 {
232 public:
233 PA_INJECT(type)
234
235 // Should return true if their are multiple (different) transforms
236 // among the various components.
237 virtual bool multipleTransforms() const = 0;
238
239 // Get coordinate (i, j) of the kth component.
240 // Each entry is a d x d matrix over the base ring.
241 // Should return true when the entry is a zero.
242 virtual bool get(mat_R& out, long i, long j, long k) const = 0;
243
244 bool processDiagonal(std::vector<RX>& poly,
245 long i,
246 const EncryptedArrayDerived<type>& ea) const override;
247 };
248
249 //====================================
250
251 struct ConstMultiplier;
252 // Defined in matmul.cpp.
253 // Holds a constant by which a ciphertext can be multiplied.
254 // Internally, it is represented as either zzX or a DoubleCRT.
255 // The former occupies less space, but the latter makes for
256 // much faster multiplication.
257
258 struct ConstMultiplierCache
259 {
260 std::vector<std::shared_ptr<ConstMultiplier>> multiplier;
261
262 // Upgrade zzX constants to DoubleCRT constants.
263 void upgrade(const Context& context);
264 };
265
266 //====================================
267
268 // Abstract base case for multiplying an encrypted std::vector by a plaintext
269 // matrix.
270 class MatMulExecBase
271 {
272 public:
~MatMulExecBase()273 virtual ~MatMulExecBase() {}
274
275 virtual const EncryptedArray& getEA() const = 0;
276
277 // Upgrade zzX constants to DoubleCRT constants.
278 virtual void upgrade() = 0;
279
280 // If ctxt encrypts a row std::vector v, then this replaces ctxt
281 // by an encryption of the row std::vector v*mat, where mat is
282 // a matrix provided to the constructor of one of the
283 // concrete subclasses MatMul1DExec, BlockMatMul1DExec,
284 // MatMulFullExec, BlockMatMulFullExec, defined below.
285 virtual void mul(Ctxt& ctxt) const = 0;
286 };
287
288 //====================================
289
290 // Class used to multiply an encrypted row std::vector by a 1D linear
291 // transformation.
292 class MatMul1DExec : public MatMulExecBase
293 {
294 public:
295 const EncryptedArray& ea;
296
297 long dim;
298 long D;
299 bool native;
300 bool minimal;
301 long g;
302
303 ConstMultiplierCache cache;
304 ConstMultiplierCache cache1; // only for non-native dimension
305
306 // The constructor encodes all the constants for a given
307 // matrix in zzX format.
308 // The mat argument defines the entries of the matrix.
309 // Use the upgrade method (below) to convert to DoubleCRT format.
310 // If the minimal flag is set to true, a strategy that relies
311 // on a minimal number of key switching matrices will be used;
312 // this is intended for use in conjunction with the
313 // addMinimal{1D,Frb}Matrices routines declared in helib.h.
314 // If the minimal flag is false, it is best to use the
315 // addSome{1D,Frb}Matrices routines declared in helib.h.
316 explicit MatMul1DExec(const MatMul1D& mat, bool minimal = false);
317
318 // VJS-FIXME: it seems that the minimal flag is currently
319 // redundant, as the decision is essentially based on
320 // ctxt.getPubKey().getKSStrategy(dim0). Need to look into this
321 // and re-assess.
322
323 // Replaces an encryption of row std::vector v by encryption of v*mat
324 void mul(Ctxt& ctxt) const override;
325
326 // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()327 void upgrade() override
328 {
329 cache.upgrade(ea.getContext());
330 cache1.upgrade(ea.getContext());
331 }
332
getEA()333 const EncryptedArray& getEA() const override { return ea; }
334 };
335
336 // A more convenient and naturally-named interface for CKKS
337 // VJS-FIXME: document some of this stuff
338
339 class EncodedMatMul_CKKS : public MatMul1DExec
340 {
341 public:
EncodedMatMul_CKKS(const MatMul1D_CKKS & mat)342 EncodedMatMul_CKKS(const MatMul1D_CKKS& mat) : MatMul1DExec(mat) {}
343 };
344
345 //====================================
346
347 // Class used to multiply an encrypted row std::vector by a block 1D linear
348 // transformation.
349 class BlockMatMul1DExec : public MatMulExecBase
350 {
351 public:
352 const EncryptedArray& ea;
353
354 long dim;
355 long D;
356 long d;
357 bool native;
358 long strategy;
359
360 ConstMultiplierCache cache;
361 ConstMultiplierCache cache1; // only for non-native dimension
362
363 // The constructor encodes all the constants for a given
364 // matrix in zzX format.
365 // The mat argument defines the entries of the matrix.
366 // Use the upgrade method (below) to convert to DoubleCRT format.
367 // If the minimal flag is set to true, a strategy that relies
368 // on a minimal number of key switching matrices will be used;
369 // this is intended for use in conjunction with the
370 // addMinimal{1D,Frb}Matrices routines declared in helib.h.
371 // If the minimal flag is false, it is best to use the
372 // addSome{1D,Frb}Matrices routines declared in helib.h.
373 explicit BlockMatMul1DExec(const BlockMatMul1D& mat, bool minimal = false);
374
375 // Replaces an encryption of row std::vector v by encryption of v*mat
376 void mul(Ctxt& ctxt) const override;
377
378 // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()379 void upgrade() override
380 {
381 cache.upgrade(ea.getContext());
382 cache1.upgrade(ea.getContext());
383 }
384
getEA()385 const EncryptedArray& getEA() const override { return ea; }
386 };
387
388 //====================================
389
390 // Class used to multiply an encrypted row std::vector by a full linear
391 // transformation.
392 class MatMulFullExec : public MatMulExecBase
393 {
394 public:
395 const EncryptedArray& ea;
396 bool minimal;
397 std::vector<long> dims;
398 std::vector<MatMul1DExec> transforms;
399
400 // The constructor encodes all the constants for a given
401 // matrix in zzX format.
402 // The mat argument defines the entries of the matrix.
403 // Use the upgrade method (below) to convert to DoubleCRT format.
404 // If the minimal flag is set to true, a strategy that relies
405 // on a minimal number of key switching matrices will be used;
406 // this is intended for use in conjunction with the
407 // addMinimal{1D,Frb}Matrices routines declared in helib.h.
408 // If the minimal flag is false, it is best to use the
409 // addSome{1D,Frb}Matrices routines declared in helib.h.
410 explicit MatMulFullExec(const MatMulFull& mat, bool minimal = false);
411
412 // Replaces an encryption of row std::vector v by encryption of v*mat
413 void mul(Ctxt& ctxt) const override;
414
415 // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()416 void upgrade() override
417 {
418 for (auto& t : transforms)
419 t.upgrade();
420 }
421
getEA()422 const EncryptedArray& getEA() const override { return ea; }
423
424 // This really should be private.
425 long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
426 };
427
428 //====================================
429
430 // Class used to multiply an encrypted row std::vector by a full block linear
431 // transformation.
432 class BlockMatMulFullExec : public MatMulExecBase
433 {
434 public:
435 const EncryptedArray& ea;
436 bool minimal;
437 std::vector<long> dims;
438 std::vector<BlockMatMul1DExec> transforms;
439
440 // The constructor encodes all the constants for a given
441 // matrix in zzX format.
442 // The mat argument defines the entries of the matrix.
443 // Use the upgrade method (below) to convert to DoubleCRT format.
444 // If the minimal flag is set to true, a strategy that relies
445 // on a minimal number of key switching matrices will be used;
446 // this is intended for use in conjunction with the
447 // addMinimal{1D,Frb}Matrices routines declared in helib.h.
448 // If the minimal flag is false, it is best to use the
449 // addSome{1D,Frb}Matrices routines declared in helib.h.
450 explicit BlockMatMulFullExec(const BlockMatMulFull& mat,
451 bool minimal = false);
452
453 // Replaces an encryption of row std::vector v by encryption of v*mat
454 void mul(Ctxt& ctxt) const override;
455
456 // Upgrades encoded constants from zzX to DoubleCRT.
upgrade()457 void upgrade() override
458 {
459 for (auto& t : transforms)
460 t.upgrade();
461 }
462
getEA()463 const EncryptedArray& getEA() const override { return ea; }
464
465 // This really should be private.
466 long rec_mul(Ctxt& acc, const Ctxt& ctxt, long dim, long idx) const;
467 };
468
469 //===================================
470
471 // ctxt = \sum_{i=0}^{d-1} \sigma^i(ctxt),
472 // where d = order of p mod m, and \sigma is the Frobenius map
473
474 void traceMap(Ctxt& ctxt);
475
476 //====================================
477
478 // These routines apply linear transformation to plaintext arrays.
479 // Mainly for testing purposes.
480 void mul(PlaintextArray& pa, const MatMul1D& mat);
481 void mul(PlaintextArray& pa, const BlockMatMul1D& mat);
482 void mul(PlaintextArray& pa, const MatMulFull& mat);
483 void mul(PlaintextArray& pa, const BlockMatMulFull& mat);
484
485 // VJS-FIXME: these should be documented
486
mul(PtxtArray & a,const MatMul1D & mat)487 inline void mul(PtxtArray& a, const MatMul1D& mat)
488 {
489 assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
490 mul(a.pa, mat);
491 }
492
mul(PtxtArray & a,const BlockMatMul1D & mat)493 inline void mul(PtxtArray& a, const BlockMatMul1D& mat)
494 {
495 assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
496 mul(a.pa, mat);
497 }
498
mul(PtxtArray & a,const MatMulFull & mat)499 inline void mul(PtxtArray& a, const MatMulFull& mat)
500 {
501 assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
502 mul(a.pa, mat);
503 }
504
mul(PtxtArray & a,const BlockMatMulFull & mat)505 inline void mul(PtxtArray& a, const BlockMatMulFull& mat)
506 {
507 assertTrue(&a.ea == &mat.getEA(), "PtxtArray: inconsistent operation");
508 mul(a.pa, mat);
509 }
510
511 // more interface conviences, both for PtxtArray and Ctxt
512
513 inline PtxtArray& operator*=(PtxtArray& a, const MatMul1D& mat)
514 {
515 mul(a, mat);
516 return a;
517 }
518
519 inline PtxtArray& operator*=(PtxtArray& a, const BlockMatMul1D& mat)
520 {
521 mul(a, mat);
522 return a;
523 }
524
525 inline PtxtArray& operator*=(PtxtArray& a, const MatMulFull& mat)
526 {
527 mul(a, mat);
528 return a;
529 }
530
531 inline PtxtArray& operator*=(PtxtArray& a, const BlockMatMulFull& mat)
532 {
533 mul(a, mat);
534 return a;
535 }
536
537 // For ctxt's, these functions don't do any pre-computation
538
539 inline Ctxt& operator*=(Ctxt& a, const MatMul1D& mat)
540 {
541 MatMul1DExec mat_exec(mat);
542 mat_exec.mul(a);
543 return a;
544 }
545
546 inline Ctxt& operator*=(Ctxt& a, const BlockMatMul1D& mat)
547 {
548 BlockMatMul1DExec mat_exec(mat);
549 mat_exec.mul(a);
550 return a;
551 }
552
553 inline Ctxt& operator*=(Ctxt& a, const MatMulFull& mat)
554 {
555 MatMulFullExec mat_exec(mat);
556 mat_exec.mul(a);
557 return a;
558 }
559
560 inline Ctxt& operator*=(Ctxt& a, const BlockMatMulFull& mat)
561 {
562 BlockMatMulFullExec mat_exec(mat);
563 mat_exec.mul(a);
564 return a;
565 }
566
567 // For ctxt's, these functions do allow pre-computation
568
569 inline Ctxt& operator*=(Ctxt& a, const MatMulExecBase& mat)
570 {
571 mat.mul(a);
572 return a;
573 }
574
// Test knobs, used mainly for performance evaluation.

// Controls whether BSGS multiplication is used:
// 1 forces it on, -1 forces it off, 0 selects the default behaviour.
extern int fhe_test_force_bsgs;

// Controls whether hoisting is used:
// -1 forces it off, 0 selects the default behaviour.
extern int fhe_test_force_hoist;
584
585 } // namespace helib
586
587 #endif // ifndef HELIB_MATMUL_H
588