1 /* Copyright (C) 2012-2017 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12 /* matmul.cpp - Data-movement operations on arrays of slots
13 */
14 #include <algorithm>
15 #include <NTL/BasicThreadPool.h>
16 #include "matmul.h"
17
18 #ifdef DEBUG
19 void printCache(const CachedzzxMatrix& cache);
20 void printCache(const CachedDCRTMatrix& cache);
21 #endif
22
23
24 /********************************************************************
25 ************* Helper routines to handle caches *********************/
26
// Try to acquire the cache-building lock for cache type `ty`.
// Returns true iff the caller must build the cache and now holds
// cachelock (the caller is then responsible for releasing it);
// returns false when no build is needed (the lock is already
// released in that case).
//
// Uses double-checked locking: a lock-free test first, then the same
// test repeated under the lock to close the race window.
bool MatMulBase::lockCache(MatrixCacheType ty)
{
  // Check if we need to build cache (1st time, lock-free fast path)
  if (ty==cacheEmpty || hasDCRTcache()
      || (ty==cachezzX && haszzxcache())) // no need to build
    return false;

  cachelock.lock();  // Take the lock

  // Check again if we need to build the cache (another thread may
  // have built it between the test above and taking the lock)
  if (ty==cacheEmpty || hasDCRTcache()
      || (ty==cachezzX && haszzxcache())) { // no need to build
    cachelock.unlock();
    return false; // no need to build
  }

  // We have the lock, can upgrade zzx to dcrt if needed
  if (ty==cacheDCRT && haszzxcache()) {
    upgradeCache();     // upgrade zzx to DCRT
    cachelock.unlock(); // release the lock
    return false;       // already built
  }
  return true; // need to build, and we have the lock
}
51
52 // This function assumes that we have the lock
upgradeCache()53 void MatMulBase::upgradeCache()
54 {
55 std::unique_ptr<CachedDCRTMatrix> dCache(new CachedDCRTMatrix());
56 dCache->SetLength(zzxCache->length());
57 for (long i=0; i<zzxCache->length(); i++) {
58 if ((*zzxCache)[i] != nullptr)
59 (*dCache)[i].reset(new DoubleCRT(*(*zzxCache)[i], ea.getContext()));
60 }
61 dcrtCache.swap(dCache);
62 }
63 /********************************************************************/
64
65
66 /********************************************************************
67 * An implementation class for dense matmul.
68 *
69 * Using such a class (rather than plain functions) is convenient
70 * since you only need to PA_INJECT once (and the injected names do
71 * not pollute the external namespace).
72 * We also use it here to store some pieces of data between recursive
73 * calls, as well as pointers to temporary caches as we build them.
74 *********************************************************************/
template<class type> class matmul_impl {
  PA_INJECT(type)
  const MatrixCacheType buildCache; // which cache (if any) to build
  std::unique_ptr<CachedzzxMatrix> zCache;  // zzX cache under construction
  std::unique_ptr<CachedDCRTMatrix> dCache; // DoubleCRT cache under construction

  Ctxt* res;         // where the result is accumulated (may be null)
  MatMul<type>& mat; // callback object supplying the matrix entries
  std::vector<long> dims; // processing order of the hypercube dimensions
  const EncryptedArrayDerived<type>& ea;

  // helper class to sort dimensions, so that
  // - bad dimensions come before good dimensions (primary sort key)
  // - small dimensions come before large dimensions (secondary sort key)
  // this is a good order to process the dimensions in the recursive
  // mat_mul_dense routine: it ensures that the work done at the leaves
  // of the recursion is minimized, and that the work done at the non-leaves
  // is dominated by the work done at the leaves.
  struct MatMulDimComp {
    const EncryptedArrayDerived<type> *ea;
    MatMulDimComp(const EncryptedArrayDerived<type> *_ea) : ea(_ea) {}

    // "less-than" predicate: dimension i should be processed before j
    bool operator()(long i, long j) {
      return (!ea->nativeDimension(i) && ea->nativeDimension(j)) ||
             ( (ea->nativeDimension(i) == ea->nativeDimension(j)) &&
               (ea->sizeOfDimension(i) < ea->sizeOfDimension(j)) );
    }
  };

public:
  // c:    where to accumulate the result (null when only building a cache)
  // _mat: must actually be a MatMul<type> (dynamic_cast throws otherwise)
  // tag:  which cache (if any) to build during the multiplication
  matmul_impl(Ctxt* c, MatMulBase& _mat, MatrixCacheType tag)
    : buildCache(tag), res(c),
      mat(dynamic_cast< MatMul<type>& >(_mat)),
      ea(_mat.getEA().getDerived(type()))
  {
    if (buildCache==cachezzX)
      zCache.reset(new CachedzzxMatrix(NTL::INIT_SIZE,ea.size()));
    else if (buildCache==cacheDCRT)
      dCache.reset(new CachedDCRTMatrix(NTL::INIT_SIZE,ea.size()));

    // dims stores the order of dimensions
    dims.resize(ea.dimension());
    for (long i = 0; i < ea.dimension(); i++)
      dims[i] = i;
    sort(dims.begin(), dims.end(), MatMulDimComp(&ea));
    // sort the dimensions so that bad ones come before good,
    // and then small ones come before large
  }

  // Get a diagonal encoded as a single constant.
  // idxes[j] is the row index paired with column j (a generalized
  // diagonal {(idxes[j], j)}_j). Returns true iff the diagonal is
  // all-zero; otherwise encodes the entries into the slots of epmat.
  bool processDiagonal(zzX& epmat, const vector<long>& idxes)
  {
    vector<RX> pmat; // the plaintext diagonal
    pmat.resize(ea.size());
    bool zDiag = true; // is this a zero diagonal
    for (long j = 0; j < ea.size(); j++) {
      long i = idxes[j];
      RX val;
      if (mat.get(val, i, j)) // returns true if the entry is zero
        clear(pmat[j]);
      else { // not a zero entry
        pmat[j] = val;
        zDiag = false; // not a zero diagonal
      }
    }
    // Now we have the constants for all the diagonal entries, encode the
    // diagonal as a single polynomial with these constants in the slots
    if (!zDiag) {
      ea.encode(epmat, pmat);
    }
    return zDiag;
  }


  // A recursive matrix-by-vector multiply, used by the dense matrix code.
  // This routine is optimized to use only the rotate1D routine rather
  // than the more expensive linear-array rotations.
  //
  // pdata: the (already rotated) ciphertext, null when only building a cache
  // dim:   index into dims[] of the dimension currently being processed
  // idx:   index of the next generalized diagonal / cache slot
  // idxes: the current permutation of slot indexes
  // Returns the updated value of idx.
  long rec_mul(const Ctxt* pdata, long dim, long idx,
               const vector<long>& idxes)
  {
    if (dim >= ea.dimension()) { // Last dimension (recursion edge condition)
      zzX pt;
      zzX* zxPtr=nullptr;
      DoubleCRT* dxPtr=nullptr;

      // Check if we have the relevant constant in cache
      CachedzzxMatrix* zcp;
      CachedDCRTMatrix* dcp;
      mat.getCache(&zcp, &dcp);
      if (dcp != nullptr) // DoubleCRT cache exists
        dxPtr = (*dcp)[idx].get();
      else if (zcp != nullptr) // zzx cache exists but no DoubleCRT
        zxPtr = (*zcp)[idx].get();
      else if (!processDiagonal(pt, idxes)) // no cache, compute const
        zxPtr = &pt; // if it is not a zero value, point to it

      // if constant is zero, return without doing anything
      if (zxPtr==nullptr && dxPtr==nullptr)
        return idx+1;

      // Constant is non-zero, store it in cache and/or multiply/add it

      if (pdata!=nullptr && res!=nullptr) {
        Ctxt tmp = *pdata;
        if (dxPtr!=nullptr) tmp.multByConstant(*dxPtr); // mult by DCRT
        else tmp.multByConstant(*zxPtr); // mult by zzx
        *res += tmp;
      }

      // NOTE(review): when a cache is being built, the locking logic in
      // mat_mul/MatMulLock appears to guarantee that no pre-existing
      // cache is installed, so zxPtr points at pt here (never null) --
      // confirm before relying on this elsewhere.
      if (buildCache==cachezzX) {
        (*zCache)[idx].reset(new zzX(*zxPtr));
      }
      else if (buildCache==cacheDCRT) {
        (*dCache)[idx].reset(new DoubleCRT(*zxPtr, ea.getContext()));
      }
      return idx+1;
    }

    // not the last dimension, make a recursive call
    long sdim = ea.sizeOfDimension(dims[dim]);

    // compute "in spirit" sum_i (pdata >> i) * i'th-diagonal, but
    // adjust the indexes so that we only need to rotate the ciphertext
    // along the different dimensions separately
    for (long offset = 0; offset < sdim; offset++) {
      vector<long> idxes1;
      ea.EncryptedArrayBase::rotate1D(idxes1, idxes, dims[dim], offset);
      if (pdata!=nullptr && res!=nullptr) {
        Ctxt pdata1 = *pdata;
        ea.rotate1D(pdata1, dims[dim], offset);
        // indexes adjusted, make the recursive call
        idx = rec_mul(&pdata1, dim+1, idx, idxes1);
      }
      else // don't bother with the ciphertext
        idx = rec_mul(pdata, dim+1, idx, idxes1);
    }

    return idx;
  }

  // Multiply a ciphertext vector by a plaintext dense matrix
  // and/or build a cache with the multiplication constants.
  // ctxt may be null when only a cache build was requested.
  // NOTE(review): the name is misspelled ("multilpy") but is kept
  // unchanged since the wrappers in this file call it by this name.
  void multilpy(Ctxt* ctxt)
  {
    RBak bak; bak.save(); ea.getTab().restoreContext();
    // idxes describes a genealized diagonal, {(i,idx[i])}_i
    // initially just the identity, idx[i]==i
    vector<long> idxes(ea.size());
    for (long i = 0; i < ea.size(); i++) idxes[i] = i;

    // call the recursive procedure to do the actual work
    rec_mul(ctxt, 0, 0, idxes);

    if (ctxt!=nullptr && res!=nullptr)
      *ctxt = *res; // copy the result back to ctxt

    // "install" the cache (if needed)
    if (buildCache == cachezzX)
      mat.installzzxcache(zCache);
    else if (buildCache == cacheDCRT)
      mat.installDCRTcache(dCache);
  }
};
238
// Wrapper functions around the implementation class
mat_mul(Ctxt * ctxt,MatMulBase & mat,MatrixCacheType buildCache)240 static void mat_mul(Ctxt* ctxt, MatMulBase& mat, MatrixCacheType buildCache)
241 {
242 MatMulLock locking(mat, buildCache);
243
244 // If locking.getType()!=cacheEmpty then we really do need to
245 // build the cache, and we also have the lock for it.
246
247 if (locking.getType() == cacheEmpty && ctxt==nullptr) // nothing to do
248 return;
249
250 std::unique_ptr<Ctxt> res;
251 if (ctxt!=nullptr) { // we need to do an actual multiplication
252 ctxt->cleanUp(); // not sure, but this may be a good idea
253 res.reset(new Ctxt(ZeroCtxtLike, *ctxt));
254 }
255
256 switch (mat.getEA().getTag()) {
257 case PA_GF2_tag: {
258 matmul_impl<PA_GF2> M(res.get(), mat, locking.getType());
259 M.multilpy(ctxt);
260 break;
261 }
262 case PA_zz_p_tag: {
263 matmul_impl<PA_zz_p> M(res.get(), mat, locking.getType());
264 M.multilpy(ctxt);
265 break;
266 }
267 default:
268 throw std::logic_error("mat_mul: neither PA_GF2 nor PA_zz_p");
269 }
270 }
buildCache4MatMul(MatMulBase & mat,MatrixCacheType buildCache)271 void buildCache4MatMul(MatMulBase& mat, MatrixCacheType buildCache)
272 { mat_mul(nullptr, mat, buildCache); }
273
matMul(Ctxt & ctxt,MatMulBase & mat,MatrixCacheType buildCache)274 void matMul(Ctxt& ctxt, MatMulBase& mat, MatrixCacheType buildCache)
275 { mat_mul(&ctxt, mat, buildCache); }
276
277
278
279 /********************************************************************
280 * An implementation class for sparse diagonal matmul.
281 *
282 * Using such a class (rather than plain functions) is convenient
283 * since you only need to PA_INJECT once (and the injected names do
284 * not pollute the external namespace). We also use it here to store
285 * pointers to temporary caches as we build them.
286 ********************************************************************/
template<class type> class matmul_sparse_impl {
  PA_INJECT(type)
  const MatrixCacheType buildCache; // which cache (if any) to build
  std::unique_ptr<CachedzzxMatrix> zCache;  // zzX cache under construction
  std::unique_ptr<CachedDCRTMatrix> dCache; // DoubleCRT cache under construction

  MatMul<type>& mat; // callback object supplying the matrix entries
  const EncryptedArrayDerived<type>& ea;
public:
  // _mat: must actually be a MatMul<type> (dynamic_cast throws otherwise)
  // tag:  which cache (if any) to build during the multiplication
  matmul_sparse_impl(MatMulBase& _mat, MatrixCacheType tag)
    : buildCache(tag), mat(dynamic_cast< MatMul<type>& >(_mat)),
      ea(_mat.getEA().getDerived(type()))
  {
    if (buildCache==cachezzX)
      zCache.reset(new CachedzzxMatrix(NTL::INIT_SIZE,ea.size()));
    else if (buildCache==cacheDCRT)
      dCache.reset(new CachedDCRTMatrix(NTL::INIT_SIZE,ea.size()));
  }

  // Get diagonal diagIdx encoded as a single constant in cPoly.
  // Entry j of the diagonal is matrix entry (j-diagIdx mod nslots, j).
  // Returns true iff the diagonal is all-zero.
  bool processDiagonal(zzX& cPoly, long diagIdx, long nslots)
  {
    bool zDiag = true; // is this a zero diagonal
    vector<RX> diag(ea.size()); // the plaintext diagonal

    for (long j = 0; j < nslots; j++) { // process entry j
      long ii = mcMod(j-diagIdx, nslots); // j-diagIdx mod nslots
      bool zEntry = mat.get(diag[j], ii, j); // callback
      assert(zEntry || deg(diag[j]) < ea.getDegree());

      if (zEntry) clear(diag[j]);
      else if (!IsZero(diag[j]))
        zDiag = false; // diagonal is non-zero
    }
    // Now we have the constants for all the diagonal entries, encode the
    // diagonal as a single polynomial with these constants in the slots
    if (!zDiag) {
      ea.encode(cPoly, diag);
    }
    return zDiag;
  }

  // Multiply *ctxt (in place) by the matrix and/or build a cache of the
  // diagonal constants. ctxt may be null when only building a cache.
  void multiply(Ctxt* ctxt)
  {
    RBak bak; bak.save(); ea.getTab().restoreContext();

    long nslots = ea.size();
    bool sequential = (ea.dimension()==1) && ea.nativeDimension(0);
    // if just a single native dimension, then rotate adds only little noise

    std::unique_ptr<Ctxt> res, shCtxt;
    if (ctxt!=nullptr) { // we need to do an actual multiplication
      ctxt->cleanUp(); // not sure, but this may be a good idea
      res.reset(new Ctxt(ZeroCtxtLike, *ctxt));
      shCtxt.reset(new Ctxt(*ctxt));
    }

    // Check if we have the relevant constant in cache
    CachedzzxMatrix* zcp;
    CachedDCRTMatrix* dcp;
    mat.getCache(&zcp, &dcp);

    // Process the diagonals one at a time
    long lastRotate = 0; // amount the ctxt itself was rotated (sequential mode)
    for (long i = 0; i < nslots; i++) { // process diagonal i
      zzX cpoly;
      zzX* zxPtr=nullptr;
      DoubleCRT* dxPtr=nullptr;

      if (dcp != nullptr) // DoubleCRT cache exists
        dxPtr = (*dcp)[i].get();
      else if (zcp != nullptr) // zzx cache exists but no DoubleCRT
        zxPtr = (*zcp)[i].get();
      else { // no cache, compute const
        if (!processDiagonal(cpoly,i,nslots)) { // returns true if zero
          zxPtr = &cpoly; // if it is not a zero value, point to it
        }
      }

      // if zero diagonal, nothing to do for this iteration
      if (zxPtr==nullptr && dxPtr==nullptr)
        continue;

      // Non-zero diagonal, store it in cache and/or multiply/add it

      if (ctxt!=nullptr && res!=nullptr) {
        // rotate by i, multiply by the polynomial, then add to the result
        if (i>0) {
          if (sequential) {
            // rotate the ciphertext itself incrementally (less noise)
            ea.rotate(*ctxt, i-lastRotate);
            *shCtxt = *ctxt;
          } else {
            // rotate a fresh copy of the original ciphertext by i
            *shCtxt = *ctxt;
            ea.rotate(*shCtxt, i); // rotate by i
          }
          lastRotate = i;
        } // if i==0 we already have *shCtxt == *ctxt

        if (dxPtr!=nullptr) shCtxt->multByConstant(*dxPtr);
        else shCtxt->multByConstant(*zxPtr);
        *res += *shCtxt;
      }
      // NOTE(review): when a cache is being built, the locking logic
      // appears to guarantee that no pre-existing cache is installed,
      // so zxPtr points at cpoly here (never null) -- confirm.
      if (buildCache==cachezzX) {
        (*zCache)[i].reset(new zzX(*zxPtr));
      }
      else if (buildCache==cacheDCRT) {
        (*dCache)[i].reset(new DoubleCRT(*zxPtr, ea.getContext()));
      }
    }

    if (ctxt!=nullptr && res!=nullptr) // copy result back to ctxt
      *ctxt = *res;

    // "install" the cache (if needed)
    if (buildCache == cachezzX)
      mat.installzzxcache(zCache);
    else if (buildCache == cacheDCRT)
      mat.installDCRTcache(dCache);
  } // end of multiply(...)
};
407
// Wrapper functions around the implementation class
409 static void
mat_mul_sparse(Ctxt * ctxt,MatMulBase & mat,MatrixCacheType buildCache)410 mat_mul_sparse(Ctxt* ctxt, MatMulBase& mat, MatrixCacheType buildCache)
411 {
412 MatMulLock locking(mat, buildCache);
413
414 // If locking.getType()!=cacheEmpty then we really do need to
415 // build the cache, and we also have the lock for it.
416
417 if (locking.getType() == cacheEmpty && ctxt==nullptr) // nothing to do
418 return;
419
420 switch (mat.getEA().getTag()) {
421 case PA_GF2_tag: {
422 matmul_sparse_impl<PA_GF2> M(mat, locking.getType());
423 M.multiply(ctxt);
424 break;
425 }
426 case PA_zz_p_tag: {
427 matmul_sparse_impl<PA_zz_p> M(mat, locking.getType());
428 M.multiply(ctxt);
429 break;
430 }
431 default:
432 throw std::logic_error("mat_mul_sparse: neither PA_GF2 nor PA_zz_p");
433 }
434 }
435 // Same as matMul but optimized for matrices with few non-zero diagonals
matMul_sparse(Ctxt & ctxt,MatMulBase & mat,MatrixCacheType buildCache)436 void matMul_sparse(Ctxt& ctxt, MatMulBase& mat,
437 MatrixCacheType buildCache)
438 { mat_mul_sparse(&ctxt, mat, buildCache); }
439
440 // Build a cache without performing multiplication
buildCache4MatMul_sparse(MatMulBase & mat,MatrixCacheType buildCache)441 void buildCache4MatMul_sparse(MatMulBase& mat, MatrixCacheType buildCache)
442 { mat_mul_sparse(nullptr, mat, buildCache); }
443
444
445 /********************************************************************
446 ********************************************************************/
447 // Applying matmul to plaintext, useful for debugging
448 template<class type> class matmul_pa_impl {
449 public:
PA_INJECT(type)450 PA_INJECT(type)
451
452 static void matmul(NewPlaintextArray& pa, MatMul<type>& mat)
453 {
454 const EncryptedArrayDerived<type>& ea = mat.getEA().getDerived(type());
455 PA_BOILER
456
457 vector<RX> res;
458 res.resize(n);
459 for (long j = 0; j < n; j++) {
460 RX acc, val, tmp;
461 acc = 0;
462 for (long i = 0; i < n; i++) {
463 if (!mat.get(val, i, j)) {
464 NTL::mul(tmp, data[i], val);
465 NTL::add(acc, acc, tmp);
466 }
467 }
468 rem(acc, acc, G);
469 res[j] = acc;
470 }
471
472 data = res;
473 }
474 };
475 // A wrapper around the implementation class
matMul(NewPlaintextArray & pa,MatMulBase & mat)476 void matMul(NewPlaintextArray& pa, MatMulBase& mat)
477 {
478 switch (mat.getEA().getTag()) {
479 case PA_GF2_tag: {
480 matmul_pa_impl<PA_GF2>::matmul(pa, dynamic_cast< MatMul<PA_GF2>& >(mat));
481 return;
482 }
483 case PA_zz_p_tag: {
484 matmul_pa_impl<PA_zz_p>::matmul(pa,dynamic_cast<MatMul<PA_zz_p>&>(mat));
485 return;
486 }
487 default:
488 throw std::logic_error("mat_mul: neither PA_GF2 nor PA_zz_p");
489 }
490 }
491
492
493 #ifdef DEBUG
printCache(const CachedzzxMatrix & cache)494 void printCache(const CachedzzxMatrix& cache)
495 {
496 std::cerr << " zzxCache=[";
497 for (long i=0; i<cache.length(); i++) {
498 if (cache[i]==nullptr)
499 std::cerr << "null ";
500 else {
501 std::cerr << (*(cache[i])) << " ";
502 }
503 }
504 std:cerr << "]\n";
505 }
506
printCache(const CachedDCRTMatrix & cache)507 void printCache(const CachedDCRTMatrix& cache)
508 {
509 std::cerr << "dcrtCache=[";
510 for (long i=0; i<cache.length(); i++) {
511 if (cache[i]==nullptr)
512 std::cerr << "null ";
513 else {
514 ZZX poly;
515 cache[i]->toPoly(poly);
516 std::cerr << poly << " ";
517 }
518 }
519 std:cerr << "]\n";
520 }
521 #endif
522