1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
#include <NTL/ZZ.h>
#include <algorithm>
#include <complex>
#include <cstdio>
#include <cstdlib>
#include <ctime>

#include <helib/norms.h>
#include <helib/helib.h>
#include <helib/debugging.h>
#include <helib/ArgMap.h>
20
21 NTL_CLIENT
22 using namespace helib;
23
// Global switch for extra diagnostic printouts; main() binds it to the
// "verbose" command-line argument.
bool verbose = false;
25
26 // Compute the L-infinity distance between two vectors
calcMaxDiff(const vector<cx_double> & v1,const vector<cx_double> & v2)27 double calcMaxDiff(const vector<cx_double>& v1,
28 const vector<cx_double>& v2){
29
30 if(lsize(v1)!=lsize(v2))
31 NTL::Error("Vector sizes differ.\nFAILED\n");
32
33 double maxDiff = 0.0;
34 for (long i=0; i<lsize(v1); i++) {
35 double diffAbs = std::abs(v1[i]-v2[i]);
36 if (diffAbs > maxDiff)
37 maxDiff = diffAbs;
38 }
39
40 return maxDiff;
41 }
42 // Compute the max relative difference between two vectors
calcMaxRelDiff(const vector<cx_double> & v1,const vector<cx_double> & v2)43 double calcMaxRelDiff(const vector<cx_double>& v1,
44 const vector<cx_double>& v2)
45 {
46 if(lsize(v1)!=lsize(v2))
47 NTL::Error("Vector sizes differ.\nFAILED\n");
48
49 // Compute the largest-magnitude value in the vector
50 double maxAbs = 0.0;
51 for (auto& x : v1) {
52 if (std::abs(x) > maxAbs)
53 maxAbs = std::abs(x);
54 }
55 if (maxAbs<1e-10)
56 maxAbs = 1e-10;
57
58 double maxDiff = 0.0;
59 for (long i=0; i<lsize(v1); i++) {
60 double relDiff = std::abs(v1[i]-v2[i]) / maxAbs;
61 if (relDiff > maxDiff)
62 maxDiff = relDiff;
63 }
64
65 return maxDiff;
66 }
67
cx_equals(const vector<cx_double> & v1,const vector<cx_double> & v2,double epsilon)68 inline bool cx_equals(const vector<cx_double>& v1,
69 const vector<cx_double>& v2,
70 double epsilon)
71 {
72 return (calcMaxRelDiff(v1,v2) < epsilon);
73 }
74
75 void testBasicArith(const PubKey& publicKey,
76 const SecKey& secretKey,
77 const EncryptedArrayCx& ea, double epsilon);
78 void testComplexArith(const PubKey& publicKey,
79 const SecKey& secretKey,
80 const EncryptedArrayCx& ea, double epsilon);
81 void testRotsNShifts(const PubKey& publicKey,
82 const SecKey& secretKey,
83 const EncryptedArrayCx& ea, double epsilon);
84
debugCompare(const EncryptedArrayCx & ea,const SecKey & sk,vector<cx_double> & p,const Ctxt & c,double epsilon)85 void debugCompare(const EncryptedArrayCx& ea, const SecKey& sk,
86 vector<cx_double>& p, const Ctxt& c, double epsilon)
87 {
88 vector<cx_double> pp;
89 ea.decrypt(c, sk, pp);
90 std::cout << " relative-error="<<calcMaxRelDiff(p,pp)
91 << ", absolute-error="<<calcMaxRelDiff(p,pp)<<endl;
92 // if (!cx_equals(pp, p, epsilon)) {
93 // std::cout << "oops:\n"; std::cout << p << "\n";
94 // std::cout << pp << "\n";
95 // exit(0);
96 // }
97 }
98
99
negateVec(vector<cx_double> & p1)100 void negateVec(vector<cx_double>& p1)
101 {
102 for (auto& x: p1) x = -x;
103 }
add(vector<cx_double> & to,const vector<cx_double> & from)104 void add(vector<cx_double>& to, const vector<cx_double>& from)
105 {
106 if (to.size() < from.size())
107 to.resize(from.size(), 0);
108 for (long i=0; i<from.size(); i++) to[i] += from[i];
109 }
sub(vector<cx_double> & to,const vector<cx_double> & from)110 void sub(vector<cx_double>& to, const vector<cx_double>& from)
111 {
112 if (to.size() < from.size())
113 to.resize(from.size(), 0);
114 for (long i=0; i<from.size(); i++) to[i] -= from[i];
115 }
mul(vector<cx_double> & to,const vector<cx_double> & from)116 void mul(vector<cx_double>& to, const vector<cx_double>& from)
117 {
118 if (to.size() < from.size())
119 to.resize(from.size(), 0);
120 for (long i=0; i<from.size(); i++) to[i] *= from[i];
121 }
rotate(vector<cx_double> & p,long amt)122 void rotate(vector<cx_double>& p, long amt)
123 {
124 long sz = p.size();
125 vector<cx_double> tmp(sz);
126 for (long i=0; i<sz; i++)
127 tmp[((i+amt)%sz +sz)%sz] = p[i];
128 p = tmp;
129 }
130
131 /************** Each round consists of the following:
132 1. c1.multiplyBy(c0)
133 2. c0 += random constant
134 3. c2 *= random constant
135 4. tmp = c1
136 5. ea.rotate(tmp, random amount in [-nSlots/2, nSlots/2])
137 6. c2 += tmp
138 7. ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
139 8. c1.negate()
140 9. c3.multiplyBy(c2)
141 10. c0 -= c3
142 **************/
testGeneralOps(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon,long nRounds)143 void testGeneralOps(const PubKey& publicKey, const SecKey& secretKey,
144 const EncryptedArrayCx& ea, double epsilon,
145 long nRounds)
146 {
147 long nslots = ea.size();
148 char buffer[32];
149
150 vector<cx_double> p0, p1, p2, p3;
151 ea.random(p0);
152 ea.random(p1);
153 ea.random(p2);
154 ea.random(p3);
155
156 Ctxt c0(publicKey), c1(publicKey), c2(publicKey), c3(publicKey);
157 ea.encrypt(c0, publicKey, p0, /*size=*/1.0);
158 ea.encrypt(c1, publicKey, p1, /*size=*/1.0);
159 ea.encrypt(c2, publicKey, p2, /*size=*/1.0);
160 ea.encrypt(c3, publicKey, p3, /*size=*/1.0);
161
162 resetAllTimers();
163 HELIB_NTIMER_START(Circuit);
164
165 for (long i = 0; i < nRounds; i++) {
166
167 if (verbose) std::cout << "*** round " << i << "..."<<endl;
168
169 long shamt = RandomBnd(2*(nslots/2) + 1) - (nslots/2);
170 // random number in [-nslots/2..nslots/2]
171 long rotamt = RandomBnd(2*nslots - 1) - (nslots - 1);
172 // random number in [-(nslots-1)..nslots-1]
173
174 // two random constants
175 vector<cx_double> const1, const2;
176 ea.random(const1);
177 ea.random(const2);
178
179 ZZX const1_poly, const2_poly;
180 ea.encode(const1_poly, const1, /*size=*/1.0);
181 ea.encode(const2_poly, const2, /*size=*/1.0);
182
183 mul(p1, p0); // c1.multiplyBy(c0)
184 c1.multiplyBy(c0);
185 if (verbose) {
186 CheckCtxt(c1, "c1*=c0");
187 debugCompare(ea, secretKey, p1, c1, epsilon);
188 }
189
190 add(p0, const1); // c0 += random constant
191 c0.addConstant(const1_poly);
192 if (verbose) {
193 CheckCtxt(c0, "c0+=k1");
194 debugCompare(ea, secretKey, p0, c0, epsilon);
195 }
196 mul(p2, const2); // c2 *= random constant
197 c2.multByConstant(const2_poly);
198 if (verbose) {
199 CheckCtxt(c2, "c2*=k2");
200 debugCompare(ea, secretKey, p2, c2, epsilon);
201 }
202 vector<cx_double> tmp_p(p1); // tmp = c1
203 Ctxt tmp(c1);
204 sprintf(buffer, "tmp=c1>>=%d", (int)shamt);
205 rotate(tmp_p, shamt); // ea.shift(tmp, random amount in [-nSlots/2,nSlots/2])
206 ea.rotate(tmp, shamt);
207 if (verbose) {
208 CheckCtxt(tmp, buffer);
209 debugCompare(ea, secretKey, tmp_p, tmp, epsilon);
210 }
211 add(p2, tmp_p); // c2 += tmp
212 c2 += tmp;
213 if (verbose) {
214 CheckCtxt(c2, "c2+=tmp");
215 debugCompare(ea, secretKey, p2, c2, epsilon);
216 }
217 sprintf(buffer, "c2>>>=%d", (int)rotamt);
218 rotate(p2, rotamt); // ea.rotate(c2, random amount in [1-nSlots, nSlots-1])
219 ea.rotate(c2, rotamt);
220 if (verbose) {
221 CheckCtxt(c2, buffer);
222 debugCompare(ea, secretKey, p2, c2, epsilon);
223 }
224 negateVec(p1); // c1.negate()
225 c1.negate();
226 if (verbose) {
227 CheckCtxt(c1, "c1=-c1");
228 debugCompare(ea, secretKey, p1, c1, epsilon);
229 }
230 mul(p3, p2); // c3.multiplyBy(c2)
231 c3.multiplyBy(c2);
232 if (verbose) {
233 CheckCtxt(c3, "c3*=c2");
234 debugCompare(ea, secretKey, p3, c3, epsilon);
235 }
236 sub(p0, p3); // c0 -= c3
237 c0 -= c3;
238 if (verbose) {
239 CheckCtxt(c0, "c0=-c3");
240 debugCompare(ea, secretKey, p0, c0, epsilon);
241 }
242 }
243
244 c0.cleanUp();
245 c1.cleanUp();
246 c2.cleanUp();
247 c3.cleanUp();
248
249 HELIB_NTIMER_STOP(Circuit);
250
251 vector<cx_double> pp0, pp1, pp2, pp3;
252
253 ea.decrypt(c0, secretKey, pp0);
254 ea.decrypt(c1, secretKey, pp1);
255 ea.decrypt(c2, secretKey, pp2);
256 ea.decrypt(c3, secretKey, pp3);
257
258 std::cout << "Test "<<nRounds<<" rounds of mixed operations, ";
259 if (cx_equals(pp0, p0,conv<double>(epsilon*c0.getPtxtMag()))
260 && cx_equals(pp1, p1,conv<double>(epsilon*c1.getPtxtMag()))
261 && cx_equals(pp2, p2,conv<double>(epsilon*c2.getPtxtMag()))
262 && cx_equals(pp3, p3,conv<double>(epsilon*c3.getPtxtMag())))
263 std::cout << "PASS\n\n";
264 else {
265 std::cout << "FAIL:\n";
266 std::cout << " max(p0)="<<largestCoeff(p0)
267 << ", max(pp0)="<<largestCoeff(pp0)
268 << ", maxDiff="<<calcMaxDiff(p0,pp0) << endl;
269 std::cout << " max(p1)="<<largestCoeff(p1)
270 << ", max(pp1)="<<largestCoeff(pp1)
271 << ", maxDiff="<<calcMaxDiff(p1,pp1) << endl;
272 std::cout << " max(p2)="<<largestCoeff(p2)
273 << ", max(pp2)="<<largestCoeff(pp2)
274 << ", maxDiff="<<calcMaxDiff(p2,pp2) << endl;
275 std::cout << " max(p3)="<<largestCoeff(p3)
276 << ", max(pp3)="<<largestCoeff(pp3)
277 << ", maxDiff="<<calcMaxDiff(p3,pp3) << endl<<endl;
278 }
279
280 if (verbose) {
281 std::cout << endl;
282 printAllTimers();
283 std::cout << endl;
284 }
285 resetAllTimers();
286 }
287
main(int argc,char * argv[])288 int main(int argc, char *argv[])
289 {
290
291 // Commandline setup
292
293 ArgMap amap;
294
295 long m=16;
296 long r=8;
297 long L=0;
298 double epsilon=0.01; // Accepted accuracy
299 long R=1;
300 long seed=0;
301 bool debug = false;
302
303 amap.arg("m", m, "Cyclotomic index");
304 amap.note("e.g., m=1024, m=2047");
305 amap.arg("r", r, "Bits of precision");
306 amap.arg("R", R, "number of rounds");
307 amap.arg("L", L, "Number of bits in modulus", "heuristic");
308 amap.arg("ep", epsilon, "Accepted accuracy");
309 amap.arg("seed", seed, "PRG seed");
310 amap.arg("verbose", verbose, "more printouts");
311 amap.arg("debug", debug, "for debugging");
312
313 amap.parse(argc, argv);
314
315 if (seed)
316 NTL::SetSeed(ZZ(seed));
317
318 if (R<=0) R=1;
319 if (R<=2)
320 L = 100*R;
321 else
322 L = 220*(R-1);
323
324 if (verbose) {
325 cout << "** m="<<m<<", #rounds="<<R<<", |q|="<<L
326 << ", epsilon="<<epsilon<<endl;
327 }
328 epsilon /= R;
329 try{
330
331 // FHE setup keys, context, SKMs, etc
332
333 Context context(m, /*p=*/-1, r);
334 context.scale=4;
335 buildModChain(context, L, /*c=*/2);
336
337 SecKey secretKey(context);
338 secretKey.GenSecKey(); // A +-1/0 secret key
339 addSome1DMatrices(secretKey); // compute key-switching matrices
340
341 const PubKey publicKey = secretKey;
342 const EncryptedArrayCx& ea = context.ea->getCx();
343
344 if (verbose) {
345 ea.getPAlgebra().printout();
346 cout << "r = " << context.alMod.getR() << endl;
347 cout << "ctxtPrimes="<<context.ctxtPrimes
348 << ", specialPrimes="<<context.specialPrimes<<endl<<endl;
349 }
350 if (debug) {
351 dbgKey = & secretKey;
352 dbgEa = context.ea;
353 }
354 #ifdef HELIB_DEBUG
355 dbgKey = & secretKey;
356 dbgEa = context.ea;
357 #endif //HELIB_DEBUG
358
359 // Run the tests.
360 testBasicArith(publicKey, secretKey, ea, epsilon);
361 testComplexArith(publicKey, secretKey, ea, epsilon);
362 testRotsNShifts(publicKey, secretKey, ea, epsilon);
363 testGeneralOps(publicKey, secretKey, ea, epsilon*R, R);
364 }
365 catch (exception& e) {
366 cerr << e.what() << endl;
367 cerr << "***Major FAIL***" << endl;
368 }
369
370 return 0;
371 }
372
373
testBasicArith(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)374 void testBasicArith(const PubKey& publicKey,
375 const SecKey& secretKey,
376 const EncryptedArrayCx& ea, double epsilon)
377 {
378 if (verbose) cout << "Test Arithmetic ";
379 // Test objects
380
381 Ctxt c1(publicKey), c2(publicKey), c3(publicKey);
382
383 vector<cx_double> vd;
384 vector<cx_double> vd1, vd2, vd3;
385 ea.random(vd1);
386 ea.random(vd2);
387
388 // test encoding of shorter vectors
389 vd1.resize(vd1.size()-2);
390 ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
391 vd1.resize(vd1.size()+2, 0.0);
392
393 ea.encrypt(c2, publicKey, vd2, /*size=*/1.0);
394
395 // Test - Multiplication
396 c1 *= c2;
397 for (long i=0; i<lsize(vd1); i++) vd1[i] *= vd2[i];
398
399 ZZX poly;
400 ea.random(vd3);
401 ea.encode(poly, vd3, /*size=*/1.0);
402 c1.addConstant(poly); // vd1*vd2 + vd3
403 for (long i=0; i<lsize(vd1); i++) vd1[i] += vd3[i];
404
405 // Test encoding, encryption of a single number
406 double xx = NTL::RandomLen_long(16)/double(1L<<16); // random in [0,1]
407 ea.encryptOneNum(c2, publicKey, xx);
408 c1 += c2;
409 for (auto& x : vd1) x += xx;
410
411 // Test - Multiply by a mask
412 vector<long> mask(lsize(vd1), 1);
413 for (long i=0; i*(i+1)<lsize(mask); i++) {
414 mask[i*i] = 0;
415 mask[i*(i+1)] = -1;
416 }
417
418 ea.encode(poly,mask, /*size=*/1.0);
419 c1.multByConstant(poly); // mask*(vd1*vd2 + vd3)
420 for (long i=0; i<lsize(vd1); i++) vd1[i] *= mask[i];
421
422 // Test - Addition
423 ea.random(vd3);
424 ea.encrypt(c3, publicKey, vd3, /*size=*/1.0);
425 c1 += c3;
426 for (long i=0; i<lsize(vd1); i++) vd1[i] += vd3[i];
427
428 c1.negate();
429 c1.addConstant(to_ZZ(1));
430 for (long i=0; i<lsize(vd1); i++) vd1[i] = 1.0 - vd1[i];
431
432 // Diff between approxNums HE scheme and plaintext floating
433 ea.decrypt(c1, secretKey, vd);
434 #ifdef HELIB_DEBUG
435 printVec(cout<<"res=", vd, 10)<<endl;
436 printVec(cout<<"vec=", vd1, 10)<<endl;
437 #endif
438 if (verbose)
439 cout << "(max |res-vec|_{infty}="<< calcMaxDiff(vd, vd1) << "): ";
440
441 if (cx_equals(vd, vd1, conv<double>(epsilon*c1.getPtxtMag())))
442 cout << "GOOD\n";
443 else {
444 cout << "BAD:\n";
445 std::cout << " max(vd)="<<largestCoeff(vd)
446 << ", max(vd1)="<<largestCoeff(vd1)
447 << ", maxDiff="<<calcMaxDiff(vd,vd1) << endl<<endl;
448 }
449 }
450
451
testComplexArith(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)452 void testComplexArith(const PubKey& publicKey,
453 const SecKey& secretKey,
454 const EncryptedArrayCx& ea, double epsilon)
455 {
456
457 // Test complex conjugate
458 Ctxt c1(publicKey), c2(publicKey);
459
460 vector<cx_double> vd;
461 vector<cx_double> vd1, vd2;
462 ea.random(vd1);
463 ea.random(vd2);
464
465 ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
466 ea.encrypt(c2, publicKey, vd2, /*size=*/1.0);
467
468 if (verbose)
469 cout << "Test Conjugate: ";
470 for_each(vd1.begin(), vd1.end(), [](cx_double& d){d=std::conj(d);});
471 c1.complexConj();
472 ea.decrypt(c1, secretKey, vd);
473 #ifdef HELIB_DEBUG
474 printVec(cout<<"vd1=", vd1, 10)<<endl;
475 printVec(cout<<"res=", vd, 10)<<endl;
476 #endif
477 if (cx_equals(vd, vd1, conv<double>(epsilon*c1.getPtxtMag())))
478 cout << "GOOD\n";
479 else {
480 cout << "BAD:\n";
481 std::cout << " max(vd)="<<largestCoeff(vd)
482 << ", max(vd1)="<<largestCoeff(vd1)
483 << ", maxDiff="<<calcMaxDiff(vd,vd1) << endl<<endl;
484 }
485
486 // Test that real and imaginary parts are actually extracted.
487 Ctxt realCtxt(c2), imCtxt(c2);
488 vector<cx_double> realParts(vd2), real_dec;
489 vector<cx_double> imParts(vd2), im_dec;
490
491 if (verbose)
492 cout << "Test Real and Im parts: ";
493 for_each(realParts.begin(), realParts.end(), [](cx_double& d){d=std::real(d);});
494 for_each(imParts.begin(), imParts.end(), [](cx_double& d){d=std::imag(d);});
495
496 ea.extractRealPart(realCtxt);
497 ea.decrypt(realCtxt, secretKey, real_dec);
498
499 ea.extractImPart(imCtxt);
500 ea.decrypt(imCtxt, secretKey, im_dec);
501
502 #ifdef HELIB_DEBUG
503 printVec(cout<<"vd2=", vd2, 10)<<endl;
504 printVec(cout<<"real=", realParts, 10)<<endl;
505 printVec(cout<<"res=", real_dec, 10)<<endl;
506 printVec(cout<<"im=", imParts, 10)<<endl;
507 printVec(cout<<"res=", im_dec, 10)<<endl;
508 #endif
509 if (cx_equals(realParts,real_dec,conv<double>(epsilon*realCtxt.getPtxtMag()))
510 && cx_equals(imParts, im_dec, conv<double>(epsilon*imCtxt.getPtxtMag())))
511 cout << "GOOD\n";
512 else {
513 cout << "BAD:\n";
514 std::cout << " max(re)="<<largestCoeff(realParts)
515 << ", max(re1)="<<largestCoeff(real_dec)
516 << ", maxDiff="<<calcMaxDiff(realParts,real_dec) << endl;
517 std::cout << " max(im)="<<largestCoeff(imParts)
518 << ", max(im1)="<<largestCoeff(im_dec)
519 << ", maxDiff="<<calcMaxDiff(imParts,im_dec) << endl<<endl;
520 }
521 }
522
testRotsNShifts(const PubKey & publicKey,const SecKey & secretKey,const EncryptedArrayCx & ea,double epsilon)523 void testRotsNShifts(const PubKey& publicKey,
524 const SecKey& secretKey,
525 const EncryptedArrayCx& ea, double epsilon)
526 {
527
528 std::srand(std::time(0)); // set seed, current time.
529 int nplaces = rand() % static_cast<int>(ea.size()/2.0) + 1;
530
531 if (verbose)
532 cout << "Test Rotation of " << nplaces << ": ";
533
534 Ctxt c1(publicKey);
535 vector<cx_double> vd1;
536 vector<cx_double> vd_dec;
537 ea.random(vd1);
538 ea.encrypt(c1, publicKey, vd1, /*size=*/1.0);
539
540 #ifdef HELIB_DEBUG
541 printVec(cout<< "vd1=", vd1, 10)<<endl;
542 #endif
543 std::rotate(vd1.begin(), vd1.end()-nplaces, vd1.end());
544 ea.rotate(c1, nplaces);
545 c1.reLinearize();
546 ea.decrypt(c1, secretKey, vd_dec);
547 #ifdef HELIB_DEBUG
548 printVec(cout<< "vd1(rot)=", vd1, 10)<<endl;
549 printVec(cout<<"res: ", vd_dec, 10)<<endl;
550 #endif
551
552 if (cx_equals(vd1, vd_dec, conv<double>(epsilon*c1.getPtxtMag())))
553 cout << "GOOD\n";
554 else {
555 cout << "BAD:\n";
556 std::cout << " max(vd)="<<largestCoeff(vd_dec)
557 << ", max(vd1)="<<largestCoeff(vd1)
558 << ", maxDiff="<<calcMaxDiff(vd_dec,vd1) << endl<<endl;
559 }
560 }
561