/* Copyright (C) 2012-2017 IBM Corp.
 * This program is Licensed under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *   http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. See accompanying LICENSE file.
 */
/* matmul.cpp - Data-movement operations on arrays of slots
 */
#include <algorithm>
#include <NTL/BasicThreadPool.h>
#include "matmul.h"

#ifdef DEBUG
void printCache(const CachedzzxMatrix& cache);
void printCache(const CachedDCRTMatrix& cache);
#endif


/********************************************************************
 ************* Helper routines to handle caches *********************/

bool MatMulBase::lockCache(MatrixCacheType ty)
{
  // Check if we need to build the cache (1st time, without holding the lock)
  if (ty==cacheEmpty || hasDCRTcache()
      || (ty==cachezzX && haszzxcache())) // no need to build
    return false;

  cachelock.lock();  // Take the lock

  // Check again: another thread may have built the cache by now
  if (ty==cacheEmpty || hasDCRTcache()
      || (ty==cachezzX && haszzxcache())) { // no need to build
    cachelock.unlock();
    return false; // no need to build
  }

  // We have the lock, can upgrade zzX to DCRT if needed
  if (ty==cacheDCRT && haszzxcache()) {
    upgradeCache();     // upgrade zzX to DCRT
    cachelock.unlock(); // release the lock
    return false; // already built
  }
  return true; // need to build, and we have the lock
}

// This function assumes that we have the lock
void MatMulBase::upgradeCache()
{
  std::unique_ptr<CachedDCRTMatrix> dCache(new CachedDCRTMatrix());
  dCache->SetLength(zzxCache->length());
  for (long i=0; i<zzxCache->length(); i++) {
    if ((*zzxCache)[i] != nullptr)
      (*dCache)[i].reset(new DoubleCRT(*(*zzxCache)[i], ea.getContext()));
  }
  dcrtCache.swap(dCache);
}
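
// Note: callers below do not use lockCache/upgradeCache directly; the
// RAII wrapper MatMulLock (see matmul.h) calls lockCache() on
// construction, reports via getType() which cache (if any) still needs
// to be built, and releases the lock when it goes out of scope.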
/********************************************************************/


/********************************************************************
 * An implementation class for dense matmul.
 *
 * Using such a class (rather than plain functions) is convenient
 * since you only need to PA_INJECT once (and the injected names do
 * not pollute the external namespace).
 * We also use it here to store some pieces of data between recursive
 * calls, as well as pointers to temporary caches as we build them.
 *********************************************************************/
template<class type> class matmul_impl {
  PA_INJECT(type)
  const MatrixCacheType buildCache;
  std::unique_ptr<CachedzzxMatrix> zCache;
  std::unique_ptr<CachedDCRTMatrix> dCache;

  Ctxt* res;
  MatMul<type>& mat;
  std::vector<long> dims;
  const EncryptedArrayDerived<type>& ea;

  // helper class to sort dimensions, so that
  //    - bad dimensions come before good dimensions (primary sort key)
  //    - small dimensions come before large dimensions (secondary sort key)
  // this is a good order to process the dimensions in the recursive
  // mat_mul_dense routine: it ensures that the work done at the leaves
  // of the recursion is minimized, and that the work done at the non-leaves
  // is dominated by the work done at the leaves.
  struct MatMulDimComp {
    const EncryptedArrayDerived<type> *ea;
    MatMulDimComp(const EncryptedArrayDerived<type> *_ea) : ea(_ea) {}

    bool operator()(long i, long j) {
      return (!ea->nativeDimension(i) && ea->nativeDimension(j)) ||
             (  (ea->nativeDimension(i) == ea->nativeDimension(j)) &&
                (ea->sizeOfDimension(i) < ea->sizeOfDimension(j))  );
    }
  };
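
  // For example (hypothetical sizes): given a good dimension of size 4
  // and bad dimensions of sizes 8 and 2, the sorted order is the bad
  // dimensions first by increasing size (2, then 8), then the good one (4).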

public:
  matmul_impl(Ctxt* c, MatMulBase& _mat, MatrixCacheType tag)
    : buildCache(tag), res(c),
      mat(dynamic_cast< MatMul<type>& >(_mat)),
      ea(_mat.getEA().getDerived(type()))
  {
    if (buildCache==cachezzX)
      zCache.reset(new CachedzzxMatrix(NTL::INIT_SIZE,ea.size()));
    else if (buildCache==cacheDCRT)
      dCache.reset(new CachedDCRTMatrix(NTL::INIT_SIZE,ea.size()));

    // dims stores the order of dimensions
    dims.resize(ea.dimension());
    for (long i = 0; i < ea.dimension(); i++)
      dims[i] = i;
    sort(dims.begin(), dims.end(), MatMulDimComp(&ea));
    // sort the dimensions so that bad ones come before good,
    // and then small ones come before large
  }

  // Get a diagonal encoded as a single constant
  bool processDiagonal(zzX& epmat, const vector<long>& idxes)
  {
    vector<RX> pmat;  // the plaintext diagonal
    pmat.resize(ea.size());
    bool zDiag = true; // is this a zero diagonal
    for (long j = 0; j < ea.size(); j++) {
      long i = idxes[j];
      RX val;
      if (mat.get(val, i, j)) // returns true if the entry is zero
        clear(pmat[j]);
      else {           // not a zero entry
        pmat[j] = val;
        zDiag = false; // not a zero diagonal
      }
    }
    // Now we have the constants for all the diagonal entries, encode the
    // diagonal as a single polynomial with these constants in the slots
    if (!zDiag) {
      ea.encode(epmat, pmat);
    }
    return zDiag;
  }
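
  // For example, with ea.size()==4 and idxes=={1,2,3,0}, processDiagonal
  // encodes the generalized diagonal {(1,0),(2,1),(3,2),(0,3)}: slot j of
  // the encoded constant holds the matrix entry (idxes[j], j).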


  // A recursive matrix-by-vector multiply, used by the dense matrix code.
  // This routine is optimized to use only the rotate1D routine rather
  // than the more expensive linear-array rotations.
  long rec_mul(const Ctxt* pdata, long dim, long idx,
               const vector<long>& idxes)
  {
    if (dim >= ea.dimension()) { // Last dimension (recursion edge condition)
      zzX pt;
      zzX* zxPtr=nullptr;
      DoubleCRT* dxPtr=nullptr;

      // Check if we have the relevant constant in cache
      CachedzzxMatrix* zcp;
      CachedDCRTMatrix* dcp;
      mat.getCache(&zcp, &dcp);
      if (dcp != nullptr)         // DoubleCRT cache exists
        dxPtr = (*dcp)[idx].get();
      else if (zcp != nullptr)    // zzX cache exists but no DoubleCRT
        zxPtr = (*zcp)[idx].get();
      else if (!processDiagonal(pt, idxes)) // no cache, compute the constant
        zxPtr = &pt; // if it is not a zero value, point to it

      // if the constant is zero, return without doing anything
      if (zxPtr==nullptr && dxPtr==nullptr)
        return idx+1;

      // Constant is non-zero, store it in cache and/or multiply/add it

      if (pdata!=nullptr && res!=nullptr) {
        Ctxt tmp = *pdata;
        if (dxPtr!=nullptr) tmp.multByConstant(*dxPtr); // mult by DCRT
        else                tmp.multByConstant(*zxPtr); // mult by zzX
        *res += tmp;
      }

      if (buildCache==cachezzX) {
        (*zCache)[idx].reset(new zzX(*zxPtr));
      }
      else if (buildCache==cacheDCRT) {
        (*dCache)[idx].reset(new DoubleCRT(*zxPtr, ea.getContext()));
      }
      return idx+1;
    }

    // not the last dimension, make a recursive call
    long sdim = ea.sizeOfDimension(dims[dim]);

    // compute "in spirit" sum_i (pdata >> i) * i'th-diagonal, but
    // adjust the indexes so that we only need to rotate the ciphertext
    // along the different dimensions separately
    for (long offset = 0; offset < sdim; offset++) {
      vector<long> idxes1;
      ea.EncryptedArrayBase::rotate1D(idxes1, idxes, dims[dim], offset);
      if (pdata!=nullptr && res!=nullptr) {
        Ctxt pdata1 = *pdata;
        ea.rotate1D(pdata1, dims[dim], offset);
        // indexes adjusted, make the recursive call
        idx = rec_mul(&pdata1, dim+1, idx, idxes1);
      }
      else // don't bother with the ciphertext
        idx = rec_mul(pdata, dim+1, idx, idxes1);
    }

    return idx;
  }
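
  // Note on the dimension order: the number of rotate1D calls at recursion
  // depth k is prod_{i<=k} sizeOfDimension(dims[i]), so the deepest level
  // dominates with ~ea.size() rotations. Sorting bad dimensions first
  // pushes the good (native, cheaper) dimensions to the deepest levels,
  // and sorting small before large keeps the counts at shallow levels low.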

  // Multiply a ciphertext vector by a plaintext dense matrix
  // and/or build a cache with the multiplication constants
  void multiply(Ctxt* ctxt)
  {
    RBak bak; bak.save(); ea.getTab().restoreContext();
    // idxes describes a generalized diagonal, {(idxes[j],j)}_j
    // initially just the identity, idxes[j]==j
    vector<long> idxes(ea.size());
    for (long i = 0; i < ea.size(); i++) idxes[i] = i;

    // call the recursive procedure to do the actual work
    rec_mul(ctxt, 0, 0, idxes);

    if (ctxt!=nullptr && res!=nullptr)
      *ctxt = *res; // copy the result back to ctxt

    // "install" the cache (if needed)
    if (buildCache == cachezzX)
      mat.installzzxcache(zCache);
    else if (buildCache == cacheDCRT)
      mat.installDCRTcache(dCache);
  }
};

// Wrapper functions around the implementation class
static void mat_mul(Ctxt* ctxt, MatMulBase& mat, MatrixCacheType buildCache)
{
  MatMulLock locking(mat, buildCache);

  // If locking.getType()!=cacheEmpty then we really do need to
  // build the cache, and we also have the lock for it.

  if (locking.getType() == cacheEmpty && ctxt==nullptr) // nothing to do
    return;

  std::unique_ptr<Ctxt> res;
  if (ctxt!=nullptr) { // we need to do an actual multiplication
    ctxt->cleanUp(); // not sure, but this may be a good idea
    res.reset(new Ctxt(ZeroCtxtLike, *ctxt));
  }

  switch (mat.getEA().getTag()) {
    case PA_GF2_tag: {
      matmul_impl<PA_GF2> M(res.get(), mat, locking.getType());
      M.multiply(ctxt);
      break;
    }
    case PA_zz_p_tag: {
      matmul_impl<PA_zz_p> M(res.get(), mat, locking.getType());
      M.multiply(ctxt);
      break;
    }
    default:
      throw std::logic_error("mat_mul: neither PA_GF2 nor PA_zz_p");
  }
}
void buildCache4MatMul(MatMulBase& mat, MatrixCacheType buildCache)
{ mat_mul(nullptr, mat, buildCache); }

void matMul(Ctxt& ctxt, MatMulBase& mat, MatrixCacheType buildCache)
{ mat_mul(&ctxt, mat, buildCache); }
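
// Example usage (a sketch; `MyDenseMatrix` is an assumed user-defined
// subclass of MatMul<PA_GF2> or MatMul<PA_zz_p> implementing get(), and
// `ctxt` is a ciphertext encrypting a vector of slots):
//
//   MyDenseMatrix mat(ea);               // wraps an EncryptedArray
//   buildCache4MatMul(mat, cacheDCRT);   // optional: precompute constants
//   matMul(ctxt, mat, cacheEmpty);       // ctxt <- mat * ctxt, in place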



/********************************************************************
 * An implementation class for sparse diagonal matmul.
 *
 * Using such a class (rather than plain functions) is convenient
 * since you only need to PA_INJECT once (and the injected names do
 * not pollute the external namespace). We also use it here to store
 * pointers to temporary caches as we build them.
 ********************************************************************/
template<class type> class matmul_sparse_impl {
  PA_INJECT(type)
  const MatrixCacheType buildCache;
  std::unique_ptr<CachedzzxMatrix> zCache;
  std::unique_ptr<CachedDCRTMatrix> dCache;

  MatMul<type>& mat;
  const EncryptedArrayDerived<type>& ea;
public:
  matmul_sparse_impl(MatMulBase& _mat, MatrixCacheType tag)
    : buildCache(tag), mat(dynamic_cast< MatMul<type>& >(_mat)),
      ea(_mat.getEA().getDerived(type()))
  {
    if (buildCache==cachezzX)
      zCache.reset(new CachedzzxMatrix(NTL::INIT_SIZE,ea.size()));
    else if (buildCache==cacheDCRT)
      dCache.reset(new CachedDCRTMatrix(NTL::INIT_SIZE,ea.size()));
  }

  // Get a diagonal encoded as a single constant
  bool processDiagonal(zzX& cPoly, long diagIdx, long nslots)
  {
    bool zDiag = true; // is this a zero diagonal
    vector<RX> diag(ea.size()); // the plaintext diagonal

    for (long j = 0; j < nslots; j++) { // process entry j
      long ii = mcMod(j-diagIdx, nslots);    // j-diagIdx mod nslots
      bool zEntry = mat.get(diag[j], ii, j); // callback
      assert(zEntry || deg(diag[j]) < ea.getDegree());

      if (zEntry) clear(diag[j]);
      else if (!IsZero(diag[j]))
        zDiag = false; // diagonal is non-zero
    }
    // Now we have the constants for all the diagonal entries, encode the
    // diagonal as a single polynomial with these constants in the slots
    if (!zDiag) {
      ea.encode(cPoly, diag);
    }
    return zDiag;
  }
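
  // For example, with nslots==4, diagonal diagIdx==1 consists of the
  // entries {(3,0),(0,1),(1,2),(2,3)}: slot j of the encoded constant
  // holds the matrix entry ((j-1) mod 4, j). After multiply() below
  // rotates the ciphertext by 1, slot j holds plaintext slot j-1, so
  // the constant multiplication contributes exactly the diagIdx==1
  // terms of each output slot j.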

  void multiply(Ctxt* ctxt)
  {
    RBak bak; bak.save(); ea.getTab().restoreContext();

    long nslots = ea.size();
    bool sequential = (ea.dimension()==1) && ea.nativeDimension(0);
    // if there is just a single native dimension, rotating adds only a
    // little noise, so we can rotate the ciphertext itself incrementally

    std::unique_ptr<Ctxt> res, shCtxt;
    if (ctxt!=nullptr) { // we need to do an actual multiplication
      ctxt->cleanUp(); // not sure, but this may be a good idea
      res.reset(new Ctxt(ZeroCtxtLike, *ctxt));
      shCtxt.reset(new Ctxt(*ctxt));
    }

    // Check if we have the relevant constants in cache
    CachedzzxMatrix* zcp;
    CachedDCRTMatrix* dcp;
    mat.getCache(&zcp, &dcp);

    // Process the diagonals one at a time
    long lastRotate = 0;
    for (long i = 0; i < nslots; i++) {  // process diagonal i
      zzX cpoly;
      zzX* zxPtr=nullptr;
      DoubleCRT* dxPtr=nullptr;

      if (dcp != nullptr)         // DoubleCRT cache exists
        dxPtr = (*dcp)[i].get();
      else if (zcp != nullptr)    // zzX cache exists but no DoubleCRT
        zxPtr = (*zcp)[i].get();
      else { // no cache, compute the constant
        if (!processDiagonal(cpoly,i,nslots)) { // returns true if zero
          zxPtr = &cpoly;   // if it is not a zero value, point to it
        }
      }

      // if this is a zero diagonal, nothing to do for this iteration
      if (zxPtr==nullptr && dxPtr==nullptr)
        continue;

      // Non-zero diagonal, store it in cache and/or multiply/add it

      if (ctxt!=nullptr && res!=nullptr) {
        // rotate by i, multiply by the polynomial, then add to the result
        if (i>0) {
          if (sequential) {
            ea.rotate(*ctxt, i-lastRotate);
            *shCtxt = *ctxt;
          } else {
            *shCtxt = *ctxt;
            ea.rotate(*shCtxt, i); // rotate by i
          }
          lastRotate = i;
        } // if i==0 we already have *shCtxt == *ctxt

        if (dxPtr!=nullptr) shCtxt->multByConstant(*dxPtr);
        else                shCtxt->multByConstant(*zxPtr);
        *res += *shCtxt;
      }
      if (buildCache==cachezzX) {
        (*zCache)[i].reset(new zzX(*zxPtr));
      }
      else if (buildCache==cacheDCRT) {
        (*dCache)[i].reset(new DoubleCRT(*zxPtr, ea.getContext()));
      }
    }

    if (ctxt!=nullptr && res!=nullptr) // copy the result back to ctxt
      *ctxt = *res;

    // "install" the cache (if needed)
    if (buildCache == cachezzX)
      mat.installzzxcache(zCache);
    else if (buildCache == cacheDCRT)
      mat.installDCRTcache(dCache);
  } // end of multiply(...)
};

// Wrapper functions around the implementation class
static void
mat_mul_sparse(Ctxt* ctxt, MatMulBase& mat, MatrixCacheType buildCache)
{
  MatMulLock locking(mat, buildCache);

  // If locking.getType()!=cacheEmpty then we really do need to
  // build the cache, and we also have the lock for it.

  if (locking.getType() == cacheEmpty && ctxt==nullptr) // nothing to do
    return;

  switch (mat.getEA().getTag()) {
    case PA_GF2_tag: {
      matmul_sparse_impl<PA_GF2> M(mat, locking.getType());
      M.multiply(ctxt);
      break;
    }
    case PA_zz_p_tag: {
      matmul_sparse_impl<PA_zz_p> M(mat, locking.getType());
      M.multiply(ctxt);
      break;
    }
    default:
      throw std::logic_error("mat_mul_sparse: neither PA_GF2 nor PA_zz_p");
  }
}
// Same as matMul but optimized for matrices with few non-zero diagonals
void matMul_sparse(Ctxt& ctxt, MatMulBase& mat,
                   MatrixCacheType buildCache)
{ mat_mul_sparse(&ctxt, mat, buildCache); }

// Build a cache without performing multiplication
void buildCache4MatMul_sparse(MatMulBase& mat, MatrixCacheType buildCache)
{ mat_mul_sparse(nullptr, mat, buildCache); }
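
// Example usage (a sketch; `MySparseMatrix` is an assumed user-defined
// MatMul subclass whose get() reports most entries as zero, so most of
// the nslots diagonals are skipped):
//
//   MySparseMatrix mat(ea);
//   buildCache4MatMul_sparse(mat, cachezzX); // optional: cache the constants
//   matMul_sparse(ctxt, mat, cacheEmpty);    // ctxt <- mat * ctxt, in place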


/********************************************************************
 ********************************************************************/
// Applying matmul to plaintext, useful for debugging
template<class type> class matmul_pa_impl {
public:
  PA_INJECT(type)

  static void matmul(NewPlaintextArray& pa, MatMul<type>& mat)
  {
    const EncryptedArrayDerived<type>& ea = mat.getEA().getDerived(type());
    PA_BOILER

    vector<RX> res;
    res.resize(n);
    for (long j = 0; j < n; j++) {
      RX acc, val, tmp;
      acc = 0;
      for (long i = 0; i < n; i++) {
        if (!mat.get(val, i, j)) {
          NTL::mul(tmp, data[i], val);
          NTL::add(acc, acc, tmp);
        }
      }
      rem(acc, acc, G);
      res[j] = acc;
    }

    data = res;
  }
};
// A wrapper around the implementation class
void matMul(NewPlaintextArray& pa, MatMulBase& mat)
{
  switch (mat.getEA().getTag()) {
    case PA_GF2_tag: {
      matmul_pa_impl<PA_GF2>::matmul(pa, dynamic_cast< MatMul<PA_GF2>& >(mat));
      return;
    }
    case PA_zz_p_tag: {
      matmul_pa_impl<PA_zz_p>::matmul(pa,dynamic_cast<MatMul<PA_zz_p>&>(mat));
      return;
    }
    default:
      throw std::logic_error("matMul: neither PA_GF2 nor PA_zz_p");
  }
}
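
// A typical debugging pattern (a sketch, assuming the usual HElib test
// scaffolding for keys and a NewPlaintextArray `pa`): apply the same
// matrix homomorphically and in the clear, then compare:
//
//   Ctxt ctxt(publicKey);
//   ea.encrypt(ctxt, publicKey, pa);
//   matMul(ctxt, mat, cacheEmpty); // encrypted computation
//   matMul(pa, mat);               // plaintext reference (this wrapper)
//   NewPlaintextArray pa2(ea);
//   ea.decrypt(ctxt, secretKey, pa2);
//   assert(equals(ea, pa, pa2));   // results should match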


#ifdef DEBUG
void printCache(const CachedzzxMatrix& cache)
{
  std::cerr << " zzxCache=[";
  for (long i=0; i<cache.length(); i++) {
    if (cache[i]==nullptr)
      std::cerr << "null ";
    else {
      std::cerr << (*(cache[i])) << " ";
    }
  }
  std::cerr << "]\n";
}

void printCache(const CachedDCRTMatrix& cache)
{
  std::cerr << "dcrtCache=[";
  for (long i=0; i<cache.length(); i++) {
    if (cache[i]==nullptr)
      std::cerr << "null ";
    else {
      ZZX poly;
      cache[i]->toPoly(poly);
      std::cerr << poly << " ";
    }
  }
  std::cerr << "]\n";
}
#endif