1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 #include <NTL/ZZ.h>
13 #include <algorithm>
14 #include <complex>
15 
16 #include <helib/norms.h>
17 #include <helib/helib.h>
18 #include <helib/debugging.h>
19 #include <helib/ArgMap.h>
20 
21 NTL_CLIENT
22 using namespace helib;
23 
// Global flag, set from the "verbose" command-line argument in main(), that
// enables extra diagnostic printouts (per-step checks, timers) in the tests.
bool verbose=false;
25 
26 // Compute the L-infinity distance between two vectors
calcMaxDiff(const vector<cx_double> & v1,const vector<cx_double> & v2)27 double calcMaxDiff(const vector<cx_double>& v1,
28                    const vector<cx_double>& v2){
29 
30   if(lsize(v1)!=lsize(v2))
31     NTL::Error("Vector sizes differ.\nFAILED\n");
32 
33   double maxDiff = 0.0;
34   for (long i=0; i<lsize(v1); i++) {
35     double diffAbs = std::abs(v1[i]-v2[i]);
36     if (diffAbs > maxDiff)
37       maxDiff = diffAbs;
38   }
39 
40   return maxDiff;
41 }
42 // Compute the max relative difference between two vectors
calcMaxRelDiff(const vector<cx_double> & v1,const vector<cx_double> & v2)43 double calcMaxRelDiff(const vector<cx_double>& v1,
44                    const vector<cx_double>& v2)
45 {
46     if(lsize(v1)!=lsize(v2))
47         NTL::Error("Vector sizes differ.\nFAILED\n");
48 
49     // Compute the largest-magnitude value in the vector
50     double maxAbs = 0.0;
51     for (auto& x : v1) {
52         if (std::abs(x) > maxAbs)
53             maxAbs = std::abs(x);
54     }
55     if (maxAbs<1e-10)
56         maxAbs = 1e-10;
57 
58     double maxDiff = 0.0;
59     for (long i=0; i<lsize(v1); i++) {
60         double relDiff = std::abs(v1[i]-v2[i]) / maxAbs;
61         if (relDiff > maxDiff)
62             maxDiff = relDiff;
63     }
64 
65     return maxDiff;
66 }
67 
cx_equals(const vector<cx_double> & v1,const vector<cx_double> & v2,double epsilon)68 inline bool cx_equals(const vector<cx_double>& v1,
69                       const vector<cx_double>& v2,
70                       double epsilon)
71 {
72   return (calcMaxRelDiff(v1,v2) < epsilon);
73 }
74 
// Forward declarations of the individual test routines, which are defined
// after main(). Each takes the public/secret key pair, the complex
// EncryptedArray view of the context, and the accepted accuracy epsilon.
void testBasicArith(const PubKey& publicKey,
                    const SecKey& secretKey,
                    const EncryptedArrayCx& ea, double epsilon);
void testComplexArith(const PubKey& publicKey,
                      const SecKey& secretKey,
                      const EncryptedArrayCx& ea, double epsilon);
void testRotsNShifts(const PubKey& publicKey,
                     const SecKey& secretKey,
                     const EncryptedArrayCx& ea, double epsilon);
84 
debugCompare(const EncryptedArrayCx & ea,const SecKey & sk,vector<cx_double> & p,const Ctxt & c,double epsilon)85 void debugCompare(const EncryptedArrayCx& ea, const SecKey& sk,
86         vector<cx_double>& p, const Ctxt& c, double epsilon)
87 {
88   vector<cx_double> pp;
89   ea.decrypt(c, sk, pp);
90   std::cout << "    relative-error="<<calcMaxRelDiff(p,pp)
91             << ", absolute-error="<<calcMaxRelDiff(p,pp)<<endl;
92 //  if (!cx_equals(pp, p, epsilon)) {
93 //    std::cout << "oops:\n"; std::cout << p << "\n";
94 //    std::cout << pp << "\n";
95 //    exit(0);
96 //  }
97 }
98 
99 
negateVec(vector<cx_double> & p1)100 void negateVec(vector<cx_double>& p1)
101 {
102   for (auto& x: p1) x = -x;
103 }
add(vector<cx_double> & to,const vector<cx_double> & from)104 void add(vector<cx_double>& to, const vector<cx_double>& from)
105 {
106   if (to.size() < from.size())
107     to.resize(from.size(), 0);
108   for (long i=0; i<from.size(); i++) to[i] += from[i];
109 }
sub(vector<cx_double> & to,const vector<cx_double> & from)110 void sub(vector<cx_double>& to, const vector<cx_double>& from)
111 {
112   if (to.size() < from.size())
113     to.resize(from.size(), 0);
114   for (long i=0; i<from.size(); i++) to[i] -= from[i];
115 }
mul(vector<cx_double> & to,const vector<cx_double> & from)116 void mul(vector<cx_double>& to, const vector<cx_double>& from)
117 {
118   if (to.size() < from.size())
119     to.resize(from.size(), 0);
120   for (long i=0; i<from.size(); i++) to[i] *= from[i];
121 }
rotate(vector<cx_double> & p,long amt)122 void rotate(vector<cx_double>& p, long amt)
123 {
124   long sz = p.size();
125   vector<cx_double> tmp(sz);
126   for (long i=0; i<sz; i++)
127     tmp[((i+amt)%sz +sz)%sz] = p[i];
128   p = tmp;
129 }
130 
131 /************** Each round consists of the following:
132 1. c1.multiplyBy(c0)
133 2. c0 += random constant
134 3. c2 *= random constant
135 4. tmp = c1
136 5. ea.rotate(tmp, random amount in [-nSlots/2, nSlots/2])
137 6. c2 += tmp
138 7. ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
139 8. c1.negate()
140 9. c3.multiplyBy(c2)
141 10. c0 -= c3
142 **************/
testGeneralOps(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon,long nRounds)143 void testGeneralOps(const PubKey& publicKey, const SecKey& secretKey,
144                     const EncryptedArrayCx& ea, double epsilon,
145                     long nRounds)
146 {
147   long nslots = ea.size();
148   char buffer[32];
149 
150   vector<cx_double> p0, p1, p2, p3;
151   ea.random(p0);
152   ea.random(p1);
153   ea.random(p2);
154   ea.random(p3);
155 
156   Ctxt c0(publicKey), c1(publicKey), c2(publicKey), c3(publicKey);
157   ea.encrypt(c0, publicKey, p0, /*size=*/1.0);
158   ea.encrypt(c1, publicKey, p1, /*size=*/1.0);
159   ea.encrypt(c2, publicKey, p2, /*size=*/1.0);
160   ea.encrypt(c3, publicKey, p3, /*size=*/1.0);
161 
162   resetAllTimers();
163   HELIB_NTIMER_START(Circuit);
164 
165   for (long i = 0; i < nRounds; i++) {
166 
167     if (verbose) std::cout << "*** round " << i << "..."<<endl;
168 
169      long shamt = RandomBnd(2*(nslots/2) + 1) - (nslots/2);
170                   // random number in [-nslots/2..nslots/2]
171      long rotamt = RandomBnd(2*nslots - 1) - (nslots - 1);
172                   // random number in [-(nslots-1)..nslots-1]
173 
174      // two random constants
175      vector<cx_double> const1, const2;
176      ea.random(const1);
177      ea.random(const2);
178 
179      ZZX const1_poly, const2_poly;
180      ea.encode(const1_poly, const1, /*size=*/1.0);
181      ea.encode(const2_poly, const2, /*size=*/1.0);
182 
183      mul(p1, p0);     // c1.multiplyBy(c0)
184      c1.multiplyBy(c0);
185      if (verbose) {
186        CheckCtxt(c1, "c1*=c0");
187        debugCompare(ea, secretKey, p1, c1, epsilon);
188      }
189 
190      add(p0, const1); // c0 += random constant
191      c0.addConstant(const1_poly);
192      if (verbose) {
193        CheckCtxt(c0, "c0+=k1");
194        debugCompare(ea, secretKey, p0, c0, epsilon);
195      }
196      mul(p2, const2); // c2 *= random constant
197      c2.multByConstant(const2_poly);
198      if (verbose) {
199        CheckCtxt(c2, "c2*=k2");
200        debugCompare(ea, secretKey, p2, c2, epsilon);
201      }
202      vector<cx_double> tmp_p(p1); // tmp = c1
203      Ctxt tmp(c1);
204      sprintf(buffer, "tmp=c1>>=%d", (int)shamt);
205      rotate(tmp_p, shamt); // ea.shift(tmp, random amount in [-nSlots/2,nSlots/2])
206      ea.rotate(tmp, shamt);
207      if (verbose) {
208        CheckCtxt(tmp, buffer);
209        debugCompare(ea, secretKey, tmp_p, tmp, epsilon);
210      }
211      add(p2, tmp_p);  // c2 += tmp
212      c2 += tmp;
213      if (verbose) {
214        CheckCtxt(c2, "c2+=tmp");
215        debugCompare(ea, secretKey, p2, c2, epsilon);
216      }
217      sprintf(buffer, "c2>>>=%d", (int)rotamt);
218      rotate(p2, rotamt); // ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
219      ea.rotate(c2, rotamt);
220      if (verbose) {
221        CheckCtxt(c2, buffer);
222        debugCompare(ea, secretKey, p2, c2, epsilon);
223      }
224      negateVec(p1); // c1.negate()
225      c1.negate();
226      if (verbose) {
227        CheckCtxt(c1, "c1=-c1");
228        debugCompare(ea, secretKey, p1, c1, epsilon);
229      }
230      mul(p3, p2); // c3.multiplyBy(c2)
231      c3.multiplyBy(c2);
232      if (verbose) {
233        CheckCtxt(c3, "c3*=c2");
234        debugCompare(ea, secretKey, p3, c3, epsilon);
235      }
236      sub(p0, p3); // c0 -= c3
237      c0 -= c3;
238      if (verbose) {
239        CheckCtxt(c0, "c0=-c3");
240        debugCompare(ea, secretKey, p0, c0, epsilon);
241      }
242   }
243 
244   c0.cleanUp();
245   c1.cleanUp();
246   c2.cleanUp();
247   c3.cleanUp();
248 
249   HELIB_NTIMER_STOP(Circuit);
250 
251   vector<cx_double> pp0, pp1, pp2, pp3;
252 
253   ea.decrypt(c0, secretKey, pp0);
254   ea.decrypt(c1, secretKey, pp1);
255   ea.decrypt(c2, secretKey, pp2);
256   ea.decrypt(c3, secretKey, pp3);
257 
258   std::cout << "Test "<<nRounds<<" rounds of mixed operations, ";
259   if (cx_equals(pp0, p0,conv<double>(epsilon*c0.getPtxtMag()))
260       && cx_equals(pp1, p1,conv<double>(epsilon*c1.getPtxtMag()))
261       && cx_equals(pp2, p2,conv<double>(epsilon*c2.getPtxtMag()))
262       && cx_equals(pp3, p3,conv<double>(epsilon*c3.getPtxtMag())))
263     std::cout << "PASS\n\n";
264   else {
265     std::cout << "FAIL:\n";
266     std::cout << "  max(p0)="<<largestCoeff(p0)
267               << ", max(pp0)="<<largestCoeff(pp0)
268               << ", maxDiff="<<calcMaxDiff(p0,pp0) << endl;
269     std::cout << "  max(p1)="<<largestCoeff(p1)
270               << ", max(pp1)="<<largestCoeff(pp1)
271               << ", maxDiff="<<calcMaxDiff(p1,pp1) << endl;
272     std::cout << "  max(p2)="<<largestCoeff(p2)
273               << ", max(pp2)="<<largestCoeff(pp2)
274               << ", maxDiff="<<calcMaxDiff(p2,pp2) << endl;
275     std::cout << "  max(p3)="<<largestCoeff(p3)
276               << ", max(pp3)="<<largestCoeff(pp3)
277               << ", maxDiff="<<calcMaxDiff(p3,pp3) << endl<<endl;
278   }
279 
280   if (verbose) {
281     std::cout << endl;
282     printAllTimers();
283     std::cout << endl;
284   }
285   resetAllTimers();
286    }
287 
main(int argc,char * argv[])288 int main(int argc, char *argv[])
289 {
290 
291   // Commandline setup
292 
293   ArgMap amap;
294 
295   long m=16;
296   long r=8;
297   long L=0;
298   double epsilon=0.01; // Accepted accuracy
299   long R=1;
300   long seed=0;
301   bool debug = false;
302 
303   amap.arg("m", m, "Cyclotomic index");
304   amap.note("e.g., m=1024, m=2047");
305   amap.arg("r", r, "Bits of precision");
306   amap.arg("R", R, "number of rounds");
307   amap.arg("L", L, "Number of bits in modulus", "heuristic");
308   amap.arg("ep", epsilon, "Accepted accuracy");
309   amap.arg("seed", seed, "PRG seed");
310   amap.arg("verbose", verbose, "more printouts");
311   amap.arg("debug", debug, "for debugging");
312 
313   amap.parse(argc, argv);
314 
315   if (seed)
316     NTL::SetSeed(ZZ(seed));
317 
318   if (R<=0) R=1;
319   if (R<=2)
320     L = 100*R;
321   else
322     L = 220*(R-1);
323 
324   if (verbose) {
325     cout << "** m="<<m<<", #rounds="<<R<<", |q|="<<L
326          << ", epsilon="<<epsilon<<endl;
327   }
328   epsilon /= R;
329   try{
330 
331     // FHE setup keys, context, SKMs, etc
332 
333     Context context(m, /*p=*/-1, r);
334     context.scale=4;
335     buildModChain(context, L, /*c=*/2);
336 
337     SecKey secretKey(context);
338     secretKey.GenSecKey(); // A +-1/0 secret key
339     addSome1DMatrices(secretKey); // compute key-switching matrices
340 
341     const PubKey publicKey = secretKey;
342     const EncryptedArrayCx& ea = context.ea->getCx();
343 
344     if (verbose) {
345       ea.getPAlgebra().printout();
346       cout << "r = " << context.alMod.getR() << endl;
347       cout << "ctxtPrimes="<<context.ctxtPrimes
348            << ", specialPrimes="<<context.specialPrimes<<endl<<endl;
349     }
350     if (debug) {
351         dbgKey = & secretKey;
352         dbgEa = context.ea;
353     }
354 #ifdef HELIB_DEBUG
355           dbgKey = & secretKey;
356           dbgEa = context.ea;
357 #endif //HELIB_DEBUG
358 
359     // Run the tests.
360     testBasicArith(publicKey, secretKey, ea, epsilon);
361     testComplexArith(publicKey, secretKey, ea, epsilon);
362     testRotsNShifts(publicKey, secretKey, ea, epsilon);
363     testGeneralOps(publicKey, secretKey, ea, epsilon*R, R);
364   }
365   catch (exception& e) {
366     cerr << e.what() << endl;
367     cerr << "***Major FAIL***" << endl;
368   }
369 
370   return 0;
371 }
372 
373 
testBasicArith(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)374 void testBasicArith(const PubKey& publicKey,
375                     const SecKey& secretKey,
376                     const EncryptedArrayCx& ea, double epsilon)
377 {
378   if (verbose)  cout << "Test Arithmetic ";
379   // Test objects
380 
381   Ctxt c1(publicKey), c2(publicKey), c3(publicKey);
382 
383   vector<cx_double> vd;
384   vector<cx_double> vd1, vd2, vd3;
385   ea.random(vd1);
386   ea.random(vd2);
387 
388   // test encoding of shorter vectors
389   vd1.resize(vd1.size()-2);
390   ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
391   vd1.resize(vd1.size()+2, 0.0);
392 
393   ea.encrypt(c2, publicKey, vd2, /*size=*/1.0);
394 
395   // Test - Multiplication
396   c1 *= c2;
397   for (long i=0; i<lsize(vd1); i++) vd1[i] *= vd2[i];
398 
399   ZZX poly;
400   ea.random(vd3);
401   ea.encode(poly, vd3, /*size=*/1.0);
402   c1.addConstant(poly); // vd1*vd2 + vd3
403   for (long i=0; i<lsize(vd1); i++) vd1[i] += vd3[i];
404 
405   // Test encoding, encryption of a single number
406   double xx = NTL::RandomLen_long(16)/double(1L<<16); // random in [0,1]
407   ea.encryptOneNum(c2, publicKey, xx);
408   c1 += c2;
409   for (auto& x : vd1) x += xx;
410 
411   // Test - Multiply by a mask
412   vector<long> mask(lsize(vd1), 1);
413   for (long i=0; i*(i+1)<lsize(mask); i++) {
414     mask[i*i] = 0;
415     mask[i*(i+1)] = -1;
416   }
417 
418   ea.encode(poly,mask, /*size=*/1.0);
419   c1.multByConstant(poly); // mask*(vd1*vd2 + vd3)
420   for (long i=0; i<lsize(vd1); i++) vd1[i] *= mask[i];
421 
422   // Test - Addition
423   ea.random(vd3);
424   ea.encrypt(c3, publicKey, vd3, /*size=*/1.0);
425   c1 += c3;
426   for (long i=0; i<lsize(vd1); i++) vd1[i] += vd3[i];
427 
428   c1.negate();
429   c1.addConstant(to_ZZ(1));
430   for (long i=0; i<lsize(vd1); i++) vd1[i] = 1.0 - vd1[i];
431 
432   // Diff between approxNums HE scheme and plaintext floating
433   ea.decrypt(c1, secretKey, vd);
434 #ifdef HELIB_DEBUG
435   printVec(cout<<"res=", vd, 10)<<endl;
436   printVec(cout<<"vec=", vd1, 10)<<endl;
437 #endif
438   if (verbose)
439     cout << "(max |res-vec|_{infty}="<< calcMaxDiff(vd, vd1) << "): ";
440 
441   if (cx_equals(vd, vd1, conv<double>(epsilon*c1.getPtxtMag())))
442     cout << "GOOD\n";
443   else {
444     cout << "BAD:\n";
445     std::cout << "  max(vd)="<<largestCoeff(vd)
446               << ", max(vd1)="<<largestCoeff(vd1)
447               << ", maxDiff="<<calcMaxDiff(vd,vd1) << endl<<endl;
448   }
449 }
450 
451 
testComplexArith(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)452 void testComplexArith(const PubKey& publicKey,
453                       const SecKey& secretKey,
454                       const EncryptedArrayCx& ea, double epsilon)
455 {
456 
457   // Test complex conjugate
458   Ctxt c1(publicKey), c2(publicKey);
459 
460   vector<cx_double> vd;
461   vector<cx_double> vd1, vd2;
462   ea.random(vd1);
463   ea.random(vd2);
464 
465   ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
466   ea.encrypt(c2, publicKey, vd2, /*size=*/1.0);
467 
468   if (verbose)
469     cout << "Test Conjugate: ";
470   for_each(vd1.begin(), vd1.end(), [](cx_double& d){d=std::conj(d);});
471   c1.complexConj();
472   ea.decrypt(c1, secretKey, vd);
473 #ifdef HELIB_DEBUG
474   printVec(cout<<"vd1=", vd1, 10)<<endl;
475   printVec(cout<<"res=", vd, 10)<<endl;
476 #endif
477   if (cx_equals(vd, vd1, conv<double>(epsilon*c1.getPtxtMag())))
478     cout << "GOOD\n";
479   else {
480     cout << "BAD:\n";
481     std::cout << "  max(vd)="<<largestCoeff(vd)
482               << ", max(vd1)="<<largestCoeff(vd1)
483               << ", maxDiff="<<calcMaxDiff(vd,vd1) << endl<<endl;
484   }
485 
486   // Test that real and imaginary parts are actually extracted.
487   Ctxt realCtxt(c2), imCtxt(c2);
488   vector<cx_double> realParts(vd2), real_dec;
489   vector<cx_double> imParts(vd2), im_dec;
490 
491   if (verbose)
492     cout << "Test Real and Im parts: ";
493   for_each(realParts.begin(), realParts.end(), [](cx_double& d){d=std::real(d);});
494   for_each(imParts.begin(), imParts.end(), [](cx_double& d){d=std::imag(d);});
495 
496   ea.extractRealPart(realCtxt);
497   ea.decrypt(realCtxt, secretKey, real_dec);
498 
499   ea.extractImPart(imCtxt);
500   ea.decrypt(imCtxt, secretKey, im_dec);
501 
502 #ifdef HELIB_DEBUG
503   printVec(cout<<"vd2=", vd2, 10)<<endl;
504   printVec(cout<<"real=", realParts, 10)<<endl;
505   printVec(cout<<"res=", real_dec, 10)<<endl;
506   printVec(cout<<"im=", imParts, 10)<<endl;
507   printVec(cout<<"res=", im_dec, 10)<<endl;
508 #endif
509   if (cx_equals(realParts,real_dec,conv<double>(epsilon*realCtxt.getPtxtMag()))
510       && cx_equals(imParts, im_dec, conv<double>(epsilon*imCtxt.getPtxtMag())))
511     cout << "GOOD\n";
512   else {
513     cout << "BAD:\n";
514     std::cout << "  max(re)="<<largestCoeff(realParts)
515               << ", max(re1)="<<largestCoeff(real_dec)
516               << ", maxDiff="<<calcMaxDiff(realParts,real_dec) << endl;
517     std::cout << "  max(im)="<<largestCoeff(imParts)
518               << ", max(im1)="<<largestCoeff(im_dec)
519               << ", maxDiff="<<calcMaxDiff(imParts,im_dec) << endl<<endl;
520   }
521 }
522 
testRotsNShifts(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)523 void testRotsNShifts(const PubKey& publicKey,
524                      const SecKey& secretKey,
525                      const EncryptedArrayCx& ea, double epsilon)
526 {
527 
528   std::srand(std::time(0)); // set seed, current time.
529   int nplaces = rand() % static_cast<int>(ea.size()/2.0) + 1;
530 
531   if (verbose)
532     cout << "Test Rotation of " << nplaces << ": ";
533 
534   Ctxt c1(publicKey);
535   vector<cx_double> vd1;
536   vector<cx_double> vd_dec;
537   ea.random(vd1);
538   ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
539 
540 #ifdef HELIB_DEBUG
541   printVec(cout<< "vd1=", vd1, 10)<<endl;
542 #endif
543   std::rotate(vd1.begin(), vd1.end()-nplaces, vd1.end());
544   ea.rotate(c1, nplaces);
545   c1.reLinearize();
546   ea.decrypt(c1, secretKey, vd_dec);
547 #ifdef HELIB_DEBUG
548   printVec(cout<< "vd1(rot)=", vd1, 10)<<endl;
549   printVec(cout<<"res: ", vd_dec, 10)<<endl;
550 #endif
551 
552   if (cx_equals(vd1, vd_dec, conv<double>(epsilon*c1.getPtxtMag())))
553     cout << "GOOD\n";
554   else {
555     cout << "BAD:\n";
556     std::cout << "  max(vd)="<<largestCoeff(vd_dec)
557               << ", max(vd1)="<<largestCoeff(vd1)
558               << ", maxDiff="<<calcMaxDiff(vd_dec,vd1) << endl<<endl;
559   }
560 }
561