1 // Copyright (c) Microsoft Corporation. All rights reserved.
2 // Licensed under the MIT license.
3 
4 #include "seal/batchencoder.h"
5 #include "seal/ckks.h"
6 #include "seal/context.h"
7 #include "seal/decryptor.h"
8 #include "seal/encryptor.h"
9 #include "seal/evaluator.h"
10 #include "seal/keygenerator.h"
11 #include "seal/modulus.h"
12 #include <cstddef>
13 #include <cstdint>
14 #include <ctime>
15 #include <string>
16 #include "gtest/gtest.h"
17 
18 using namespace seal;
19 using namespace std;
20 
21 namespace sealtest
22 {
TEST(EvaluatorTest,BFVEncryptNegateDecrypt)23     TEST(EvaluatorTest, BFVEncryptNegateDecrypt)
24     {
25         EncryptionParameters parms(scheme_type::bfv);
26         Modulus plain_modulus(1 << 6);
27         parms.set_poly_modulus_degree(64);
28         parms.set_plain_modulus(plain_modulus);
29         parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
30 
31         SEALContext context(parms, false, sec_level_type::none);
32         KeyGenerator keygen(context);
33         PublicKey pk;
34         keygen.create_public_key(pk);
35 
36         Encryptor encryptor(context, pk);
37         Evaluator evaluator(context);
38         Decryptor decryptor(context, keygen.secret_key());
39 
40         Ciphertext encrypted;
41         Plaintext plain;
42 
43         plain = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
44         encryptor.encrypt(plain, encrypted);
45         evaluator.negate_inplace(encrypted);
46         decryptor.decrypt(encrypted, plain);
47         ASSERT_EQ(
48             plain.to_string(), "3Fx^28 + 3Fx^25 + 3Fx^21 + 3Fx^20 + 3Fx^18 + 3Fx^14 + 3Fx^12 + 3Fx^10 + 3Fx^9 + 3Fx^6 "
49                                "+ 3Fx^5 + 3Fx^4 + 3Fx^3");
50         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
51 
52         plain = "0";
53         encryptor.encrypt(plain, encrypted);
54         evaluator.negate_inplace(encrypted);
55         decryptor.decrypt(encrypted, plain);
56         ASSERT_EQ(plain.to_string(), "0");
57         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
58 
59         plain = "1";
60         encryptor.encrypt(plain, encrypted);
61         evaluator.negate_inplace(encrypted);
62         decryptor.decrypt(encrypted, plain);
63         ASSERT_EQ(plain.to_string(), "3F");
64         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
65 
66         plain = "3F";
67         encryptor.encrypt(plain, encrypted);
68         evaluator.negate_inplace(encrypted);
69         decryptor.decrypt(encrypted, plain);
70         ASSERT_EQ(plain.to_string(), "1");
71         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
72 
73         plain = "1x^1";
74         encryptor.encrypt(plain, encrypted);
75         evaluator.negate_inplace(encrypted);
76         decryptor.decrypt(encrypted, plain);
77         ASSERT_EQ(plain.to_string(), "3Fx^1");
78         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
79 
80         plain = "3Fx^2 + 3F";
81         encryptor.encrypt(plain, encrypted);
82         evaluator.negate_inplace(encrypted);
83         decryptor.decrypt(encrypted, plain);
84         ASSERT_EQ(plain.to_string(), "1x^2 + 1");
85         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
86     }
87 
TEST(EvaluatorTest,BFVEncryptAddDecrypt)88     TEST(EvaluatorTest, BFVEncryptAddDecrypt)
89     {
90         EncryptionParameters parms(scheme_type::bfv);
91         Modulus plain_modulus(1 << 6);
92         parms.set_poly_modulus_degree(64);
93         parms.set_plain_modulus(plain_modulus);
94         parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
95 
96         SEALContext context(parms, false, sec_level_type::none);
97         KeyGenerator keygen(context);
98         PublicKey pk;
99         keygen.create_public_key(pk);
100 
101         Encryptor encryptor(context, pk);
102         Evaluator evaluator(context);
103         Decryptor decryptor(context, keygen.secret_key());
104 
105         Ciphertext encrypted1;
106         Ciphertext encrypted2;
107         Plaintext plain, plain1, plain2;
108 
109         plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
110         plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
111         encryptor.encrypt(plain1, encrypted1);
112         encryptor.encrypt(plain2, encrypted2);
113         evaluator.add_inplace(encrypted1, encrypted2);
114         decryptor.decrypt(encrypted1, plain);
115         ASSERT_EQ(
116             plain.to_string(), "1x^28 + 1x^25 + 1x^21 + 1x^20 + 2x^18 + 1x^16 + 2x^14 + 1x^12 + 1x^10 + 2x^9 + 1x^8 + "
117                                "1x^6 + 2x^5 + 1x^4 + 1x^3 + 1");
118         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
119         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
120 
121         plain1 = "0";
122         plain2 = "0";
123         encryptor.encrypt(plain1, encrypted1);
124         encryptor.encrypt(plain2, encrypted2);
125         evaluator.add_inplace(encrypted1, encrypted2);
126         decryptor.decrypt(encrypted1, plain);
127         ASSERT_EQ("0", plain.to_string());
128         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
129         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
130 
131         plain1 = "0";
132         plain2 = "1x^2 + 1";
133         encryptor.encrypt(plain1, encrypted1);
134         encryptor.encrypt(plain2, encrypted2);
135         evaluator.add_inplace(encrypted1, encrypted2);
136         decryptor.decrypt(encrypted1, plain);
137         ASSERT_EQ(plain.to_string(), "1x^2 + 1");
138         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
139         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
140 
141         plain1 = "1x^2 + 1";
142         plain2 = "3Fx^1 + 3F";
143         encryptor.encrypt(plain1, encrypted1);
144         encryptor.encrypt(plain2, encrypted2);
145         evaluator.add_inplace(encrypted1, encrypted2);
146         decryptor.decrypt(encrypted1, plain);
147         ASSERT_EQ(plain.to_string(), "1x^2 + 3Fx^1");
148         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
149         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
150 
151         plain1 = "3Fx^2 + 3Fx^1 + 3F";
152         plain2 = "1x^1";
153         encryptor.encrypt(plain1, encrypted1);
154         encryptor.encrypt(plain2, encrypted2);
155         evaluator.add_inplace(encrypted1, encrypted2);
156         decryptor.decrypt(encrypted1, plain);
157         ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F");
158         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
159         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
160 
161         plain1 = "2x^2 + 1x^1 + 3";
162         plain2 = "3x^3 + 4x^2 + 5x^1 + 6";
163         encryptor.encrypt(plain1, encrypted1);
164         encryptor.encrypt(plain2, encrypted2);
165         evaluator.add_inplace(encrypted1, encrypted2);
166         decryptor.decrypt(encrypted1, plain);
167         ASSERT_TRUE(plain.to_string() == "3x^3 + 6x^2 + 6x^1 + 9");
168         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
169         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
170 
171         plain1 = "3x^5 + 1x^4 + 4x^3 + 1";
172         plain2 = "5x^2 + 9x^1 + 2";
173         encryptor.encrypt(plain1, encrypted1);
174         encryptor.encrypt(plain2, encrypted2);
175         evaluator.add_inplace(encrypted1, encrypted2);
176         decryptor.decrypt(encrypted1, plain);
177         ASSERT_TRUE(plain.to_string() == "3x^5 + 1x^4 + 4x^3 + 5x^2 + 9x^1 + 3");
178         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
179         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
180     }
181 
TEST(EvaluatorTest,CKKSEncryptAddDecrypt)182     TEST(EvaluatorTest, CKKSEncryptAddDecrypt)
183     {
184         EncryptionParameters parms(scheme_type::ckks);
185         {
186             // Adding two zero vectors
187             size_t slot_size = 32;
188             parms.set_poly_modulus_degree(slot_size * 2);
189             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 }));
190 
191             SEALContext context(parms, false, sec_level_type::none);
192             KeyGenerator keygen(context);
193             PublicKey pk;
194             keygen.create_public_key(pk);
195 
196             CKKSEncoder encoder(context);
197             Encryptor encryptor(context, pk);
198             Decryptor decryptor(context, keygen.secret_key());
199             Evaluator evaluator(context);
200 
201             Ciphertext encrypted;
202             Plaintext plain;
203             Plaintext plainRes;
204 
205             vector<complex<double>> input(slot_size, 0.0);
206             vector<complex<double>> output(slot_size);
207             const double delta = static_cast<double>(1 << 16);
208             encoder.encode(input, context.first_parms_id(), delta, plain);
209 
210             encryptor.encrypt(plain, encrypted);
211             evaluator.add_inplace(encrypted, encrypted);
212 
213             // Check correctness of encryption
214             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
215 
216             decryptor.decrypt(encrypted, plainRes);
217             encoder.decode(plainRes, output);
218             for (size_t i = 0; i < slot_size; i++)
219             {
220                 auto tmp = abs(input[i].real() - output[i].real());
221                 ASSERT_TRUE(tmp < 0.5);
222             }
223         }
224         {
225             // Adding two random vectors 100 times
226             size_t slot_size = 32;
227             parms.set_poly_modulus_degree(slot_size * 2);
228             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
229 
230             SEALContext context(parms, false, sec_level_type::none);
231             KeyGenerator keygen(context);
232             PublicKey pk;
233             keygen.create_public_key(pk);
234 
235             CKKSEncoder encoder(context);
236             Encryptor encryptor(context, pk);
237             Decryptor decryptor(context, keygen.secret_key());
238             Evaluator evaluator(context);
239 
240             Ciphertext encrypted1;
241             Ciphertext encrypted2;
242             Plaintext plain1;
243             Plaintext plain2;
244             Plaintext plainRes;
245 
246             vector<complex<double>> input1(slot_size, 0.0);
247             vector<complex<double>> input2(slot_size, 0.0);
248             vector<complex<double>> expected(slot_size, 0.0);
249             vector<complex<double>> output(slot_size);
250 
251             int data_bound = (1 << 30);
252             const double delta = static_cast<double>(1 << 16);
253 
254             srand(static_cast<unsigned>(time(NULL)));
255 
256             for (int expCount = 0; expCount < 100; expCount++)
257             {
258                 for (size_t i = 0; i < slot_size; i++)
259                 {
260                     input1[i] = static_cast<double>(rand() % data_bound);
261                     input2[i] = static_cast<double>(rand() % data_bound);
262                     expected[i] = input1[i] + input2[i];
263                 }
264 
265                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
266                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
267 
268                 encryptor.encrypt(plain1, encrypted1);
269                 encryptor.encrypt(plain2, encrypted2);
270                 evaluator.add_inplace(encrypted1, encrypted2);
271 
272                 // Check correctness of encryption
273                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
274 
275                 decryptor.decrypt(encrypted1, plainRes);
276                 encoder.decode(plainRes, output);
277                 for (size_t i = 0; i < slot_size; i++)
278                 {
279                     auto tmp = abs(expected[i].real() - output[i].real());
280                     ASSERT_TRUE(tmp < 0.5);
281                 }
282             }
283         }
284         {
285             // Adding two random vectors 100 times
286             size_t slot_size = 8;
287             parms.set_poly_modulus_degree(64);
288             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 }));
289 
290             SEALContext context(parms, false, sec_level_type::none);
291             KeyGenerator keygen(context);
292             PublicKey pk;
293             keygen.create_public_key(pk);
294 
295             CKKSEncoder encoder(context);
296             Encryptor encryptor(context, pk);
297             Decryptor decryptor(context, keygen.secret_key());
298             Evaluator evaluator(context);
299 
300             Ciphertext encrypted1;
301             Ciphertext encrypted2;
302             Plaintext plain1;
303             Plaintext plain2;
304             Plaintext plainRes;
305 
306             vector<complex<double>> input1(slot_size, 0.0);
307             vector<complex<double>> input2(slot_size, 0.0);
308             vector<complex<double>> expected(slot_size, 0.0);
309             vector<complex<double>> output(slot_size);
310 
311             int data_bound = (1 << 30);
312             const double delta = static_cast<double>(1 << 16);
313 
314             srand(static_cast<unsigned>(time(NULL)));
315 
316             for (int expCount = 0; expCount < 100; expCount++)
317             {
318                 for (size_t i = 0; i < slot_size; i++)
319                 {
320                     input1[i] = static_cast<double>(rand() % data_bound);
321                     input2[i] = static_cast<double>(rand() % data_bound);
322                     expected[i] = input1[i] + input2[i];
323                 }
324 
325                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
326                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
327 
328                 encryptor.encrypt(plain1, encrypted1);
329                 encryptor.encrypt(plain2, encrypted2);
330                 evaluator.add_inplace(encrypted1, encrypted2);
331 
332                 // Check correctness of encryption
333                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
334 
335                 decryptor.decrypt(encrypted1, plainRes);
336                 encoder.decode(plainRes, output);
337                 for (size_t i = 0; i < slot_size; i++)
338                 {
339                     auto tmp = abs(expected[i].real() - output[i].real());
340                     ASSERT_TRUE(tmp < 0.5);
341                 }
342             }
343         }
344     }
TEST(EvaluatorTest,CKKSEncryptAddPlainDecrypt)345     TEST(EvaluatorTest, CKKSEncryptAddPlainDecrypt)
346     {
347         EncryptionParameters parms(scheme_type::ckks);
348         {
349             // Adding two zero vectors
350             size_t slot_size = 32;
351             parms.set_poly_modulus_degree(slot_size * 2);
352             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 }));
353 
354             SEALContext context(parms, false, sec_level_type::none);
355             KeyGenerator keygen(context);
356             PublicKey pk;
357             keygen.create_public_key(pk);
358 
359             CKKSEncoder encoder(context);
360             Encryptor encryptor(context, pk);
361             Decryptor decryptor(context, keygen.secret_key());
362             Evaluator evaluator(context);
363 
364             Ciphertext encrypted;
365             Plaintext plain;
366             Plaintext plainRes;
367 
368             vector<complex<double>> input(slot_size, 0.0);
369             vector<complex<double>> output(slot_size);
370             const double delta = static_cast<double>(1 << 16);
371             encoder.encode(input, context.first_parms_id(), delta, plain);
372 
373             encryptor.encrypt(plain, encrypted);
374             evaluator.add_plain_inplace(encrypted, plain);
375 
376             // Check correctness of encryption
377             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
378 
379             decryptor.decrypt(encrypted, plainRes);
380             encoder.decode(plainRes, output);
381             for (size_t i = 0; i < slot_size; i++)
382             {
383                 auto tmp = abs(input[i].real() - output[i].real());
384                 ASSERT_TRUE(tmp < 0.5);
385             }
386         }
387         {
388             // Adding two random vectors 50 times
389             size_t slot_size = 32;
390             parms.set_poly_modulus_degree(slot_size * 2);
391             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
392 
393             SEALContext context(parms, false, sec_level_type::none);
394             KeyGenerator keygen(context);
395             PublicKey pk;
396             keygen.create_public_key(pk);
397 
398             CKKSEncoder encoder(context);
399             Encryptor encryptor(context, pk);
400             Decryptor decryptor(context, keygen.secret_key());
401             Evaluator evaluator(context);
402 
403             Ciphertext encrypted1;
404             Plaintext plain1;
405             Plaintext plain2;
406             Plaintext plainRes;
407 
408             vector<complex<double>> input1(slot_size, 0.0);
409             vector<complex<double>> input2(slot_size, 0.0);
410             vector<complex<double>> expected(slot_size, 0.0);
411             vector<complex<double>> output(slot_size);
412 
413             int data_bound = (1 << 8);
414             const double delta = static_cast<double>(1ULL << 16);
415 
416             srand(static_cast<unsigned>(time(NULL)));
417 
418             for (int expCount = 0; expCount < 50; expCount++)
419             {
420                 for (size_t i = 0; i < slot_size; i++)
421                 {
422                     input1[i] = static_cast<double>(rand() % data_bound);
423                     input2[i] = static_cast<double>(rand() % data_bound);
424                     expected[i] = input1[i] + input2[i];
425                 }
426 
427                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
428                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
429 
430                 encryptor.encrypt(plain1, encrypted1);
431                 evaluator.add_plain_inplace(encrypted1, plain2);
432 
433                 // Check correctness of encryption
434                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
435 
436                 decryptor.decrypt(encrypted1, plainRes);
437                 encoder.decode(plainRes, output);
438                 for (size_t i = 0; i < slot_size; i++)
439                 {
440                     auto tmp = abs(expected[i].real() - output[i].real());
441                     ASSERT_TRUE(tmp < 0.5);
442                 }
443             }
444         }
445         {
446             // Adding two random vectors 50 times
447             size_t slot_size = 32;
448             parms.set_poly_modulus_degree(slot_size * 2);
449             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
450 
451             SEALContext context(parms, false, sec_level_type::none);
452             KeyGenerator keygen(context);
453             PublicKey pk;
454             keygen.create_public_key(pk);
455 
456             CKKSEncoder encoder(context);
457             Encryptor encryptor(context, pk);
458             Decryptor decryptor(context, keygen.secret_key());
459             Evaluator evaluator(context);
460 
461             Ciphertext encrypted1;
462             Plaintext plain1;
463             Plaintext plain2;
464             Plaintext plainRes;
465 
466             vector<complex<double>> input1(slot_size, 0.0);
467             double input2;
468             vector<complex<double>> expected(slot_size, 0.0);
469             vector<complex<double>> output(slot_size);
470 
471             int data_bound = (1 << 8);
472             const double delta = static_cast<double>(1ULL << 16);
473 
474             srand(static_cast<unsigned>(time(NULL)));
475 
476             for (int expCount = 0; expCount < 50; expCount++)
477             {
478                 input2 = static_cast<double>(rand() % (data_bound * data_bound)) / data_bound;
479                 for (size_t i = 0; i < slot_size; i++)
480                 {
481                     input1[i] = static_cast<double>(rand() % data_bound);
482                     expected[i] = input1[i] + input2;
483                 }
484 
485                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
486                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
487 
488                 encryptor.encrypt(plain1, encrypted1);
489                 evaluator.add_plain_inplace(encrypted1, plain2);
490 
491                 // Check correctness of encryption
492                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
493 
494                 decryptor.decrypt(encrypted1, plainRes);
495                 encoder.decode(plainRes, output);
496                 for (size_t i = 0; i < slot_size; i++)
497                 {
498                     auto tmp = abs(expected[i].real() - output[i].real());
499                     ASSERT_TRUE(tmp < 0.5);
500                 }
501             }
502         }
503         {
504             // Adding two random vectors 50 times
505             size_t slot_size = 8;
506             parms.set_poly_modulus_degree(64);
507             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 }));
508 
509             SEALContext context(parms, false, sec_level_type::none);
510             KeyGenerator keygen(context);
511             PublicKey pk;
512             keygen.create_public_key(pk);
513 
514             CKKSEncoder encoder(context);
515             Encryptor encryptor(context, pk);
516             Decryptor decryptor(context, keygen.secret_key());
517             Evaluator evaluator(context);
518 
519             Ciphertext encrypted1;
520             Plaintext plain1;
521             Plaintext plain2;
522             Plaintext plainRes;
523 
524             vector<complex<double>> input1(slot_size, 0.0);
525             double input2;
526             vector<complex<double>> expected(slot_size, 0.0);
527             vector<complex<double>> output(slot_size);
528 
529             int data_bound = (1 << 8);
530             const double delta = static_cast<double>(1ULL << 16);
531 
532             srand(static_cast<unsigned>(time(NULL)));
533 
534             for (int expCount = 0; expCount < 50; expCount++)
535             {
536                 input2 = static_cast<double>(rand() % (data_bound * data_bound)) / data_bound;
537                 for (size_t i = 0; i < slot_size; i++)
538                 {
539                     input1[i] = static_cast<double>(rand() % data_bound);
540                     expected[i] = input1[i] + input2;
541                 }
542 
543                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
544                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
545 
546                 encryptor.encrypt(plain1, encrypted1);
547                 evaluator.add_plain_inplace(encrypted1, plain2);
548 
549                 // Check correctness of encryption
550                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
551 
552                 decryptor.decrypt(encrypted1, plainRes);
553                 encoder.decode(plainRes, output);
554                 for (size_t i = 0; i < slot_size; i++)
555                 {
556                     auto tmp = abs(expected[i].real() - output[i].real());
557                     ASSERT_TRUE(tmp < 0.5);
558                 }
559             }
560         }
561     }
562 
TEST(EvaluatorTest,CKKSEncryptSubPlainDecrypt)563     TEST(EvaluatorTest, CKKSEncryptSubPlainDecrypt)
564     {
565         EncryptionParameters parms(scheme_type::ckks);
566         {
567             // Subtracting two zero vectors
568             size_t slot_size = 32;
569             parms.set_poly_modulus_degree(slot_size * 2);
570             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 }));
571 
572             SEALContext context(parms, false, sec_level_type::none);
573             KeyGenerator keygen(context);
574             PublicKey pk;
575             keygen.create_public_key(pk);
576 
577             CKKSEncoder encoder(context);
578             Encryptor encryptor(context, pk);
579             Decryptor decryptor(context, keygen.secret_key());
580             Evaluator evaluator(context);
581 
582             Ciphertext encrypted;
583             Plaintext plain;
584             Plaintext plainRes;
585 
586             vector<complex<double>> input(slot_size, 0.0);
587             vector<complex<double>> output(slot_size);
588             const double delta = static_cast<double>(1 << 16);
589             encoder.encode(input, context.first_parms_id(), delta, plain);
590 
591             encryptor.encrypt(plain, encrypted);
592             evaluator.add_plain_inplace(encrypted, plain);
593 
594             // Check correctness of encryption
595             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
596 
597             decryptor.decrypt(encrypted, plainRes);
598             encoder.decode(plainRes, output);
599             for (size_t i = 0; i < slot_size; i++)
600             {
601                 auto tmp = abs(input[i].real() - output[i].real());
602                 ASSERT_TRUE(tmp < 0.5);
603             }
604         }
605         {
606             // Subtracting two random vectors 100 times
607             size_t slot_size = 32;
608             parms.set_poly_modulus_degree(slot_size * 2);
609             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
610 
611             SEALContext context(parms, false, sec_level_type::none);
612             KeyGenerator keygen(context);
613             PublicKey pk;
614             keygen.create_public_key(pk);
615 
616             CKKSEncoder encoder(context);
617             Encryptor encryptor(context, pk);
618             Decryptor decryptor(context, keygen.secret_key());
619             Evaluator evaluator(context);
620 
621             Ciphertext encrypted1;
622             Plaintext plain1;
623             Plaintext plain2;
624             Plaintext plainRes;
625 
626             vector<complex<double>> input1(slot_size, 0.0);
627             vector<complex<double>> input2(slot_size, 0.0);
628             vector<complex<double>> expected(slot_size, 0.0);
629             vector<complex<double>> output(slot_size);
630 
631             int data_bound = (1 << 8);
632             const double delta = static_cast<double>(1ULL << 16);
633 
634             srand(static_cast<unsigned>(time(NULL)));
635 
636             for (int expCount = 0; expCount < 100; expCount++)
637             {
638                 for (size_t i = 0; i < slot_size; i++)
639                 {
640                     input1[i] = static_cast<double>(rand() % data_bound);
641                     input2[i] = static_cast<double>(rand() % data_bound);
642                     expected[i] = input1[i] - input2[i];
643                 }
644 
645                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
646                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
647 
648                 encryptor.encrypt(plain1, encrypted1);
649                 evaluator.sub_plain_inplace(encrypted1, plain2);
650 
651                 // Check correctness of encryption
652                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
653 
654                 decryptor.decrypt(encrypted1, plainRes);
655                 encoder.decode(plainRes, output);
656                 for (size_t i = 0; i < slot_size; i++)
657                 {
658                     auto tmp = abs(expected[i].real() - output[i].real());
659                     ASSERT_TRUE(tmp < 0.5);
660                 }
661             }
662         }
663         {
664             // Subtracting two random vectors 100 times
665             size_t slot_size = 8;
666             parms.set_poly_modulus_degree(64);
667             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 }));
668 
669             SEALContext context(parms, false, sec_level_type::none);
670             KeyGenerator keygen(context);
671             PublicKey pk;
672             keygen.create_public_key(pk);
673 
674             CKKSEncoder encoder(context);
675             Encryptor encryptor(context, pk);
676             Decryptor decryptor(context, keygen.secret_key());
677             Evaluator evaluator(context);
678 
679             Ciphertext encrypted1;
680             Plaintext plain1;
681             Plaintext plain2;
682             Plaintext plainRes;
683 
684             vector<complex<double>> input1(slot_size, 0.0);
685             vector<complex<double>> input2(slot_size, 0.0);
686             vector<complex<double>> expected(slot_size, 0.0);
687             vector<complex<double>> output(slot_size);
688 
689             int data_bound = (1 << 8);
690             const double delta = static_cast<double>(1ULL << 16);
691 
692             srand(static_cast<unsigned>(time(NULL)));
693 
694             for (int expCount = 0; expCount < 100; expCount++)
695             {
696                 for (size_t i = 0; i < slot_size; i++)
697                 {
698                     input1[i] = static_cast<double>(rand() % data_bound);
699                     input2[i] = static_cast<double>(rand() % data_bound);
700                     expected[i] = input1[i] - input2[i];
701                 }
702 
703                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
704                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
705 
706                 encryptor.encrypt(plain1, encrypted1);
707                 evaluator.sub_plain_inplace(encrypted1, plain2);
708 
709                 // Check correctness of encryption
710                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
711 
712                 decryptor.decrypt(encrypted1, plainRes);
713                 encoder.decode(plainRes, output);
714                 for (size_t i = 0; i < slot_size; i++)
715                 {
716                     auto tmp = abs(expected[i].real() - output[i].real());
717                     ASSERT_TRUE(tmp < 0.5);
718                 }
719             }
720         }
721     }
722 
TEST(EvaluatorTest,BFVEncryptSubDecrypt)723     TEST(EvaluatorTest, BFVEncryptSubDecrypt)
724     {
725         EncryptionParameters parms(scheme_type::bfv);
726         Modulus plain_modulus(1 << 6);
727         parms.set_poly_modulus_degree(64);
728         parms.set_plain_modulus(plain_modulus);
729         parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
730 
731         SEALContext context(parms, false, sec_level_type::none);
732         KeyGenerator keygen(context);
733         PublicKey pk;
734         keygen.create_public_key(pk);
735 
736         Encryptor encryptor(context, pk);
737         Evaluator evaluator(context);
738         Decryptor decryptor(context, keygen.secret_key());
739 
740         Ciphertext encrypted1;
741         Ciphertext encrypted2;
742         Plaintext plain, plain1, plain2;
743 
744         plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
745         plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
746         encryptor.encrypt(plain1, encrypted1);
747         encryptor.encrypt(plain2, encrypted2);
748         evaluator.sub_inplace(encrypted1, encrypted2);
749         decryptor.decrypt(encrypted1, plain);
750         ASSERT_EQ(
751             plain.to_string(),
752             "1x^28 + 1x^25 + 1x^21 + 1x^20 + 3Fx^16 + 1x^12 + 1x^10 + 3Fx^8 + 1x^6 + 1x^4 + 1x^3 + 3F");
753         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
754         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
755 
756         plain1 = "0";
757         plain2 = "0";
758         encryptor.encrypt(plain1, encrypted1);
759         encryptor.encrypt(plain2, encrypted2);
760         evaluator.sub_inplace(encrypted1, encrypted2);
761         decryptor.decrypt(encrypted1, plain);
762         ASSERT_EQ(plain.to_string(), "0");
763         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
764         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
765 
766         plain1 = "0";
767         plain2 = "1x^2 + 1";
768         encryptor.encrypt(plain1, encrypted1);
769         encryptor.encrypt(plain2, encrypted2);
770         evaluator.sub_inplace(encrypted1, encrypted2);
771         decryptor.decrypt(encrypted1, plain);
772         ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F");
773         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
774         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
775 
776         plain1 = "1x^2 + 1";
777         plain2 = "3Fx^1 + 3F";
778         encryptor.encrypt(plain1, encrypted1);
779         encryptor.encrypt(plain2, encrypted2);
780         evaluator.sub_inplace(encrypted1, encrypted2);
781         decryptor.decrypt(encrypted1, plain);
782         ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 2");
783         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
784         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
785 
786         plain1 = "3Fx^2 + 3Fx^1 + 3F";
787         plain2 = "1x^1";
788         encryptor.encrypt(plain1, encrypted1);
789         encryptor.encrypt(plain2, encrypted2);
790         evaluator.sub_inplace(encrypted1, encrypted2);
791         decryptor.decrypt(encrypted1, plain);
792         ASSERT_EQ(plain.to_string(), "3Fx^2 + 3Ex^1 + 3F");
793         ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
794         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
795     }
796 
TEST(EvaluatorTest,BFVEncryptAddPlainDecrypt)797     TEST(EvaluatorTest, BFVEncryptAddPlainDecrypt)
798     {
799         EncryptionParameters parms(scheme_type::bfv);
800         Modulus plain_modulus(1 << 6);
801         parms.set_poly_modulus_degree(64);
802         parms.set_plain_modulus(plain_modulus);
803         parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
804 
805         SEALContext context(parms, false, sec_level_type::none);
806         KeyGenerator keygen(context);
807         PublicKey pk;
808         keygen.create_public_key(pk);
809 
810         Encryptor encryptor(context, pk);
811         Evaluator evaluator(context);
812         Decryptor decryptor(context, keygen.secret_key());
813 
814         Ciphertext encrypted1;
815         Ciphertext encrypted2;
816         Plaintext plain, plain1, plain2;
817 
818         plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
819         plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
820         encryptor.encrypt(plain1, encrypted1);
821         evaluator.add_plain_inplace(encrypted1, plain2);
822         decryptor.decrypt(encrypted1, plain);
823         ASSERT_EQ(
824             plain.to_string(), "1x^28 + 1x^25 + 1x^21 + 1x^20 + 2x^18 + 1x^16 + 2x^14 + 1x^12 + 1x^10 + 2x^9 + 1x^8 + "
825                                "1x^6 + 2x^5 + 1x^4 + 1x^3 + 1");
826         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
827 
828         plain1 = "0";
829         plain2 = "0";
830         encryptor.encrypt(plain1, encrypted1);
831         evaluator.add_plain_inplace(encrypted1, plain2);
832         decryptor.decrypt(encrypted1, plain);
833         ASSERT_EQ(plain.to_string(), "0");
834         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
835 
836         plain1 = "0";
837         plain2 = "1x^2 + 1";
838         encryptor.encrypt(plain1, encrypted1);
839         evaluator.add_plain_inplace(encrypted1, plain2);
840         decryptor.decrypt(encrypted1, plain);
841         ASSERT_EQ(plain.to_string(), "1x^2 + 1");
842         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
843 
844         plain1 = "1x^2 + 1";
845         plain2 = "3Fx^1 + 3F";
846         encryptor.encrypt(plain1, encrypted1);
847         evaluator.add_plain_inplace(encrypted1, plain2);
848         decryptor.decrypt(encrypted1, plain);
849         ASSERT_EQ(plain.to_string(), "1x^2 + 3Fx^1");
850         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
851 
852         plain1 = "3Fx^2 + 3Fx^1 + 3F";
853         plain2 = "1x^2 + 1x^1 + 1";
854         encryptor.encrypt(plain1, encrypted1);
855         evaluator.add_plain_inplace(encrypted1, plain2);
856         decryptor.decrypt(encrypted1, plain);
857         ASSERT_EQ(plain.to_string(), "0");
858         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
859     }
860 
TEST(EvaluatorTest,BFVEncryptSubPlainDecrypt)861     TEST(EvaluatorTest, BFVEncryptSubPlainDecrypt)
862     {
863         EncryptionParameters parms(scheme_type::bfv);
864         Modulus plain_modulus(1 << 6);
865         parms.set_poly_modulus_degree(64);
866         parms.set_plain_modulus(plain_modulus);
867         parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
868 
869         SEALContext context(parms, false, sec_level_type::none);
870         KeyGenerator keygen(context);
871         PublicKey pk;
872         keygen.create_public_key(pk);
873 
874         Encryptor encryptor(context, pk);
875         Evaluator evaluator(context);
876         Decryptor decryptor(context, keygen.secret_key());
877 
878         Ciphertext encrypted1;
879         Plaintext plain, plain1, plain2;
880 
881         plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
882         plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
883         encryptor.encrypt(plain1, encrypted1);
884         evaluator.sub_plain_inplace(encrypted1, plain2);
885         decryptor.decrypt(encrypted1, plain);
886         ASSERT_EQ(
887             plain.to_string(),
888             "1x^28 + 1x^25 + 1x^21 + 1x^20 + 3Fx^16 + 1x^12 + 1x^10 + 3Fx^8 + 1x^6 + 1x^4 + 1x^3 + 3F");
889         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
890 
891         plain1 = "0";
892         plain2 = "0";
893         encryptor.encrypt(plain1, encrypted1);
894         evaluator.sub_plain_inplace(encrypted1, plain2);
895         decryptor.decrypt(encrypted1, plain);
896         ASSERT_EQ(plain.to_string(), "0");
897         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
898 
899         plain1 = "0";
900         plain2 = "1x^2 + 1";
901         encryptor.encrypt(plain1, encrypted1);
902         evaluator.sub_plain_inplace(encrypted1, plain2);
903         decryptor.decrypt(encrypted1, plain);
904         ASSERT_EQ(plain.to_string(), "3Fx^2 + 3F");
905         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
906 
907         plain1 = "1x^2 + 1";
908         plain2 = "3Fx^1 + 3F";
909         encryptor.encrypt(plain1, encrypted1);
910         evaluator.sub_plain_inplace(encrypted1, plain2);
911         decryptor.decrypt(encrypted1, plain);
912         ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 2");
913         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
914 
915         plain1 = "3Fx^2 + 3Fx^1 + 3F";
916         plain2 = "1x^1";
917         encryptor.encrypt(plain1, encrypted1);
918         evaluator.sub_plain_inplace(encrypted1, plain2);
919         decryptor.decrypt(encrypted1, plain);
920         ASSERT_EQ(plain.to_string(), "3Fx^2 + 3Ex^1 + 3F");
921         ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
922     }
923 
TEST(EvaluatorTest,BFVEncryptMultiplyPlainDecrypt)924     TEST(EvaluatorTest, BFVEncryptMultiplyPlainDecrypt)
925     {
926         {
927             EncryptionParameters parms(scheme_type::bfv);
928             Modulus plain_modulus(1 << 6);
929             parms.set_poly_modulus_degree(64);
930             parms.set_plain_modulus(plain_modulus);
931             parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
932 
933             SEALContext context(parms, false, sec_level_type::none);
934             KeyGenerator keygen(context);
935             PublicKey pk;
936             keygen.create_public_key(pk);
937 
938             Encryptor encryptor(context, pk);
939             Evaluator evaluator(context);
940             Decryptor decryptor(context, keygen.secret_key());
941 
942             Ciphertext encrypted;
943             Plaintext plain, plain1, plain2;
944 
945             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
946             plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
947             encryptor.encrypt(plain1, encrypted);
948             evaluator.multiply_plain_inplace(encrypted, plain2);
949             decryptor.decrypt(encrypted, plain);
950             ASSERT_EQ(
951                 plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + "
952                                    "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + "
953                                    "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + "
954                                    "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
955             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
956 
957             plain1 = "0";
958             plain2 = "1x^2 + 1";
959             encryptor.encrypt(plain1, encrypted);
960             evaluator.multiply_plain_inplace(encrypted, plain2);
961             decryptor.decrypt(encrypted, plain);
962             ASSERT_EQ(plain.to_string(), "0");
963             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
964 
965             plain1 = "1x^2 + 1x^1 + 1";
966             plain2 = "1x^2";
967             encryptor.encrypt(plain1, encrypted);
968             evaluator.multiply_plain_inplace(encrypted, plain2);
969             decryptor.decrypt(encrypted, plain);
970             ASSERT_EQ(plain.to_string(), "1x^4 + 1x^3 + 1x^2");
971             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
972 
973             plain1 = "1x^2 + 1x^1 + 1";
974             plain2 = "1x^1";
975             encryptor.encrypt(plain1, encrypted);
976             evaluator.multiply_plain_inplace(encrypted, plain2);
977             decryptor.decrypt(encrypted, plain);
978             ASSERT_EQ(plain.to_string(), "1x^3 + 1x^2 + 1x^1");
979             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
980 
981             plain1 = "1x^2 + 1x^1 + 1";
982             plain2 = "1";
983             encryptor.encrypt(plain1, encrypted);
984             evaluator.multiply_plain_inplace(encrypted, plain2);
985             decryptor.decrypt(encrypted, plain);
986             ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1");
987             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
988 
989             plain1 = "1x^2 + 1";
990             plain2 = "3Fx^1 + 3F";
991             encryptor.encrypt(plain1, encrypted);
992             evaluator.multiply_plain_inplace(encrypted, plain2);
993             decryptor.decrypt(encrypted, plain);
994             ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F");
995             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
996 
997             plain1 = "3Fx^2 + 3Fx^1 + 3F";
998             plain2 = "1x^1";
999             encryptor.encrypt(plain1, encrypted);
1000             evaluator.multiply_plain_inplace(encrypted, plain2);
1001             decryptor.decrypt(encrypted, plain);
1002             ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1");
1003             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1004         }
1005         {
1006             EncryptionParameters parms(scheme_type::bfv);
1007             Modulus plain_modulus((1ULL << 20) - 1);
1008             parms.set_poly_modulus_degree(64);
1009             parms.set_plain_modulus(plain_modulus);
1010             parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 }));
1011 
1012             SEALContext context(parms, false, sec_level_type::none);
1013             KeyGenerator keygen(context);
1014             PublicKey pk;
1015             keygen.create_public_key(pk);
1016 
1017             Encryptor encryptor(context, pk);
1018             Evaluator evaluator(context);
1019             Decryptor decryptor(context, keygen.secret_key());
1020 
1021             Ciphertext encrypted;
1022             Plaintext plain, plain1, plain2;
1023 
1024             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
1025             plain2 = "1";
1026             encryptor.encrypt(plain1, encrypted);
1027             evaluator.multiply_plain_inplace(encrypted, plain2);
1028             decryptor.decrypt(encrypted, plain);
1029             ASSERT_EQ(
1030                 plain.to_string(),
1031                 "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
1032             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1033 
1034             plain2 = "5";
1035             evaluator.multiply_plain_inplace(encrypted, plain2);
1036             decryptor.decrypt(encrypted, plain);
1037             ASSERT_EQ(
1038                 plain.to_string(),
1039                 "5x^28 + 5x^25 + 5x^21 + 5x^20 + 5x^18 + 5x^14 + 5x^12 + 5x^10 + 5x^9 + 5x^6 + 5x^5 + 5x^4 + 5x^3");
1040             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1041         }
1042         {
1043             EncryptionParameters parms(scheme_type::bfv);
1044             Modulus plain_modulus((1ULL << 40) - 1);
1045             parms.set_poly_modulus_degree(64);
1046             parms.set_plain_modulus(plain_modulus);
1047             parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 }));
1048 
1049             SEALContext context(parms, false, sec_level_type::none);
1050             KeyGenerator keygen(context);
1051             PublicKey pk;
1052             keygen.create_public_key(pk);
1053 
1054             Encryptor encryptor(context, pk);
1055             Evaluator evaluator(context);
1056             Decryptor decryptor(context, keygen.secret_key());
1057 
1058             Ciphertext encrypted;
1059             Plaintext plain, plain1, plain2;
1060 
1061             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
1062             plain2 = "1";
1063             encryptor.encrypt(plain1, encrypted);
1064             evaluator.multiply_plain_inplace(encrypted, plain2);
1065             decryptor.decrypt(encrypted, plain);
1066             ASSERT_EQ(
1067                 plain.to_string(),
1068                 "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
1069             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1070 
1071             plain2 = "5";
1072             evaluator.multiply_plain_inplace(encrypted, plain2);
1073             decryptor.decrypt(encrypted, plain);
1074             ASSERT_EQ(
1075                 plain.to_string(),
1076                 "5x^28 + 5x^25 + 5x^21 + 5x^20 + 5x^18 + 5x^14 + 5x^12 + 5x^10 + 5x^9 + 5x^6 + 5x^5 + 5x^4 + 5x^3");
1077             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1078         }
1079         {
1080             EncryptionParameters parms(scheme_type::bfv);
1081             Modulus plain_modulus(PlainModulus::Batching(64, 20));
1082             parms.set_poly_modulus_degree(64);
1083             parms.set_plain_modulus(plain_modulus);
1084             parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30, 30 }));
1085 
1086             SEALContext context(parms, false, sec_level_type::none);
1087             KeyGenerator keygen(context);
1088             PublicKey pk;
1089             keygen.create_public_key(pk);
1090 
1091             BatchEncoder batch_encoder(context);
1092             Encryptor encryptor(context, pk);
1093             Evaluator evaluator(context);
1094             Decryptor decryptor(context, keygen.secret_key());
1095 
1096             Ciphertext encrypted;
1097             Plaintext plain;
1098             vector<int64_t> result;
1099 
1100             batch_encoder.encode(vector<int64_t>(batch_encoder.slot_count(), 7), plain);
1101             encryptor.encrypt(plain, encrypted);
1102             evaluator.multiply_plain_inplace(encrypted, plain);
1103             decryptor.decrypt(encrypted, plain);
1104             batch_encoder.decode(plain, result);
1105             ASSERT_TRUE(vector<int64_t>(batch_encoder.slot_count(), 49) == result);
1106             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1107 
1108             batch_encoder.encode(vector<int64_t>(batch_encoder.slot_count(), -7), plain);
1109             encryptor.encrypt(plain, encrypted);
1110             evaluator.multiply_plain_inplace(encrypted, plain);
1111             decryptor.decrypt(encrypted, plain);
1112             batch_encoder.decode(plain, result);
1113             ASSERT_TRUE(vector<int64_t>(batch_encoder.slot_count(), 49) == result);
1114             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1115         }
1116         {
1117             EncryptionParameters parms(scheme_type::bfv);
1118             Modulus plain_modulus(PlainModulus::Batching(64, 40));
1119             parms.set_poly_modulus_degree(64);
1120             parms.set_plain_modulus(plain_modulus);
1121             parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30, 30, 30, 30 }));
1122 
1123             SEALContext context(parms, false, sec_level_type::none);
1124             KeyGenerator keygen(context);
1125             PublicKey pk;
1126             keygen.create_public_key(pk);
1127 
1128             BatchEncoder batch_encoder(context);
1129             Encryptor encryptor(context, pk);
1130             Evaluator evaluator(context);
1131             Decryptor decryptor(context, keygen.secret_key());
1132 
1133             Ciphertext encrypted;
1134             Plaintext plain;
1135             vector<int64_t> result;
1136 
1137             // First test with constant plaintext
1138             batch_encoder.encode(vector<int64_t>(batch_encoder.slot_count(), 7), plain);
1139             encryptor.encrypt(plain, encrypted);
1140             evaluator.multiply_plain_inplace(encrypted, plain);
1141             decryptor.decrypt(encrypted, plain);
1142             batch_encoder.decode(plain, result);
1143             ASSERT_TRUE(vector<int64_t>(batch_encoder.slot_count(), 49) == result);
1144             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1145 
1146             batch_encoder.encode(vector<int64_t>(batch_encoder.slot_count(), -7), plain);
1147             encryptor.encrypt(plain, encrypted);
1148             evaluator.multiply_plain_inplace(encrypted, plain);
1149             decryptor.decrypt(encrypted, plain);
1150             batch_encoder.decode(plain, result);
1151             ASSERT_TRUE(vector<int64_t>(batch_encoder.slot_count(), 49) == result);
1152             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1153 
1154             // Now test a non-constant plaintext
1155             vector<int64_t> input(batch_encoder.slot_count() - 1, 7);
1156             input.push_back(1);
1157             vector<int64_t> true_result(batch_encoder.slot_count() - 1, 49);
1158             true_result.push_back(1);
1159             batch_encoder.encode(input, plain);
1160             encryptor.encrypt(plain, encrypted);
1161             evaluator.multiply_plain_inplace(encrypted, plain);
1162             decryptor.decrypt(encrypted, plain);
1163             batch_encoder.decode(plain, result);
1164             ASSERT_TRUE(true_result == result);
1165             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1166 
1167             input = vector<int64_t>(batch_encoder.slot_count() - 1, -7);
1168             input.push_back(1);
1169             batch_encoder.encode(input, plain);
1170             encryptor.encrypt(plain, encrypted);
1171             evaluator.multiply_plain_inplace(encrypted, plain);
1172             decryptor.decrypt(encrypted, plain);
1173             batch_encoder.decode(plain, result);
1174             ASSERT_TRUE(true_result == result);
1175             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1176         }
1177     }
1178 
TEST(EvaluatorTest,BFVEncryptMultiplyDecrypt)1179     TEST(EvaluatorTest, BFVEncryptMultiplyDecrypt)
1180     {
1181         {
1182             EncryptionParameters parms(scheme_type::bfv);
1183             Modulus plain_modulus(1 << 6);
1184             parms.set_poly_modulus_degree(64);
1185             parms.set_plain_modulus(plain_modulus);
1186             parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 }));
1187 
1188             SEALContext context(parms, false, sec_level_type::none);
1189             KeyGenerator keygen(context);
1190             PublicKey pk;
1191             keygen.create_public_key(pk);
1192 
1193             Encryptor encryptor(context, pk);
1194             Evaluator evaluator(context);
1195             Decryptor decryptor(context, keygen.secret_key());
1196 
1197             Ciphertext encrypted1;
1198             Ciphertext encrypted2;
1199             Plaintext plain, plain1, plain2;
1200 
1201             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
1202             plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
1203             encryptor.encrypt(plain1, encrypted1);
1204             encryptor.encrypt(plain2, encrypted2);
1205             evaluator.multiply_inplace(encrypted1, encrypted2);
1206             decryptor.decrypt(encrypted1, plain);
1207             ASSERT_EQ(
1208                 plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + "
1209                                    "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + "
1210                                    "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + "
1211                                    "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
1212             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1213             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1214 
1215             plain1 = "0";
1216             plain2 = "0";
1217             encryptor.encrypt(plain1, encrypted1);
1218             encryptor.encrypt(plain2, encrypted2);
1219             evaluator.multiply_inplace(encrypted1, encrypted2);
1220             decryptor.decrypt(encrypted1, plain);
1221             ASSERT_EQ(plain.to_string(), "0");
1222             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1223             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1224 
1225             plain1 = "0";
1226             plain2 = "1x^2 + 1";
1227             encryptor.encrypt(plain1, encrypted1);
1228             encryptor.encrypt(plain2, encrypted2);
1229             evaluator.multiply_inplace(encrypted1, encrypted2);
1230             decryptor.decrypt(encrypted1, plain);
1231             ASSERT_EQ(plain.to_string(), "0");
1232             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1233             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1234 
1235             plain1 = "1x^2 + 1x^1 + 1";
1236             plain2 = "1";
1237             encryptor.encrypt(plain1, encrypted1);
1238             encryptor.encrypt(plain2, encrypted2);
1239             evaluator.multiply_inplace(encrypted1, encrypted2);
1240             decryptor.decrypt(encrypted1, plain);
1241             ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1");
1242             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1243             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1244 
1245             plain1 = "1x^2 + 1";
1246             plain2 = "3Fx^1 + 3F";
1247             encryptor.encrypt(plain1, encrypted1);
1248             encryptor.encrypt(plain2, encrypted2);
1249             evaluator.multiply_inplace(encrypted1, encrypted2);
1250             decryptor.decrypt(encrypted1, plain);
1251             ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F");
1252             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1253             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1254 
1255             plain1 = "1x^16";
1256             plain2 = "1x^8";
1257             encryptor.encrypt(plain1, encrypted1);
1258             encryptor.encrypt(plain2, encrypted2);
1259             evaluator.multiply_inplace(encrypted1, encrypted2);
1260             decryptor.decrypt(encrypted1, plain);
1261             ASSERT_EQ(plain.to_string(), "1x^24");
1262             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1263             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1264         }
1265         {
1266             EncryptionParameters parms(scheme_type::bfv);
1267             Modulus plain_modulus((1ULL << 60) - 1);
1268             parms.set_poly_modulus_degree(64);
1269             parms.set_plain_modulus(plain_modulus);
1270             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60, 60 }));
1271 
1272             SEALContext context(parms, false, sec_level_type::none);
1273             KeyGenerator keygen(context);
1274             PublicKey pk;
1275             keygen.create_public_key(pk);
1276 
1277             Encryptor encryptor(context, pk);
1278             Evaluator evaluator(context);
1279             Decryptor decryptor(context, keygen.secret_key());
1280 
1281             Ciphertext encrypted1;
1282             Ciphertext encrypted2;
1283             Plaintext plain, plain1, plain2;
1284 
1285             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
1286             plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
1287             encryptor.encrypt(plain1, encrypted1);
1288             encryptor.encrypt(plain2, encrypted2);
1289             evaluator.multiply_inplace(encrypted1, encrypted2);
1290             decryptor.decrypt(encrypted1, plain);
1291             ASSERT_EQ(
1292                 plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + "
1293                                    "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + "
1294                                    "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + "
1295                                    "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
1296             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1297             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1298 
1299             plain1 = "0";
1300             plain2 = "0";
1301             encryptor.encrypt(plain1, encrypted1);
1302             encryptor.encrypt(plain2, encrypted2);
1303             evaluator.multiply_inplace(encrypted1, encrypted2);
1304             decryptor.decrypt(encrypted1, plain);
1305             ASSERT_EQ(plain.to_string(), "0");
1306             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1307             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1308 
1309             plain1 = "0";
1310             plain2 = "1x^2 + 1";
1311             encryptor.encrypt(plain1, encrypted1);
1312             encryptor.encrypt(plain2, encrypted2);
1313             evaluator.multiply_inplace(encrypted1, encrypted2);
1314             decryptor.decrypt(encrypted1, plain);
1315             ASSERT_EQ(plain.to_string(), "0");
1316             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1317             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1318 
1319             plain1 = "1x^2 + 1x^1 + 1";
1320             plain2 = "1";
1321             encryptor.encrypt(plain1, encrypted1);
1322             encryptor.encrypt(plain2, encrypted2);
1323             evaluator.multiply_inplace(encrypted1, encrypted2);
1324             decryptor.decrypt(encrypted1, plain);
1325             ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1");
1326             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1327             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1328 
1329             plain1 = "1x^2 + 1";
1330             plain2 = "FFFFFFFFFFFFFFEx^1 + FFFFFFFFFFFFFFE";
1331             encryptor.encrypt(plain1, encrypted1);
1332             encryptor.encrypt(plain2, encrypted2);
1333             evaluator.multiply_inplace(encrypted1, encrypted2);
1334             decryptor.decrypt(encrypted1, plain);
1335             ASSERT_EQ(
1336                 plain.to_string(), "FFFFFFFFFFFFFFEx^3 + FFFFFFFFFFFFFFEx^2 + FFFFFFFFFFFFFFEx^1 + FFFFFFFFFFFFFFE");
1337             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1338             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1339 
1340             plain1 = "1x^16";
1341             plain2 = "1x^8";
1342             encryptor.encrypt(plain1, encrypted1);
1343             encryptor.encrypt(plain2, encrypted2);
1344             evaluator.multiply_inplace(encrypted1, encrypted2);
1345             decryptor.decrypt(encrypted1, plain);
1346             ASSERT_EQ(plain.to_string(), "1x^24");
1347             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1348             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1349         }
1350         {
1351             EncryptionParameters parms(scheme_type::bfv);
1352             Modulus plain_modulus(1 << 6);
1353             parms.set_poly_modulus_degree(128);
1354             parms.set_plain_modulus(plain_modulus);
1355             parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 }));
1356 
1357             SEALContext context(parms, false, sec_level_type::none);
1358             KeyGenerator keygen(context);
1359             PublicKey pk;
1360             keygen.create_public_key(pk);
1361 
1362             Encryptor encryptor(context, pk);
1363             Evaluator evaluator(context);
1364             Decryptor decryptor(context, keygen.secret_key());
1365 
1366             Ciphertext encrypted1;
1367             Ciphertext encrypted2;
1368             Plaintext plain, plain1, plain2;
1369 
1370             plain1 = "1x^28 + 1x^25 + 1x^21 + 1x^20 + 1x^18 + 1x^14 + 1x^12 + 1x^10 + 1x^9 + 1x^6 + 1x^5 + 1x^4 + 1x^3";
1371             plain2 = "1x^18 + 1x^16 + 1x^14 + 1x^9 + 1x^8 + 1x^5 + 1";
1372             encryptor.encrypt(plain1, encrypted1);
1373             encryptor.encrypt(plain2, encrypted2);
1374             evaluator.multiply_inplace(encrypted1, encrypted2);
1375             decryptor.decrypt(encrypted1, plain);
1376             ASSERT_EQ(
1377                 plain.to_string(), "1x^46 + 1x^44 + 1x^43 + 1x^42 + 1x^41 + 2x^39 + 1x^38 + 2x^37 + 3x^36 + 1x^35 + "
1378                                    "3x^34 + 2x^33 + 2x^32 + 4x^30 + 2x^29 + 5x^28 + 2x^27 + 4x^26 + 3x^25 + 2x^24 + "
1379                                    "4x^23 + 3x^22 + 4x^21 + 4x^20 + 4x^19 + 4x^18 + 3x^17 + 2x^15 + 4x^14 + 2x^13 + "
1380                                    "3x^12 + 2x^11 + 2x^10 + 2x^9 + 1x^8 + 1x^6 + 1x^5 + 1x^4 + 1x^3");
1381             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1382             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1383 
1384             plain1 = "0";
1385             plain2 = "0";
1386             encryptor.encrypt(plain1, encrypted1);
1387             encryptor.encrypt(plain2, encrypted2);
1388             evaluator.multiply_inplace(encrypted1, encrypted2);
1389             decryptor.decrypt(encrypted1, plain);
1390             ASSERT_EQ(plain.to_string(), "0");
1391             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1392             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1393 
1394             plain1 = "0";
1395             plain2 = "1x^2 + 1";
1396             encryptor.encrypt(plain1, encrypted1);
1397             encryptor.encrypt(plain2, encrypted2);
1398             evaluator.multiply_inplace(encrypted1, encrypted2);
1399             decryptor.decrypt(encrypted1, plain);
1400             ASSERT_EQ(plain.to_string(), "0");
1401             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1402             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1403 
1404             plain1 = "1x^2 + 1x^1 + 1";
1405             plain2 = "1";
1406             encryptor.encrypt(plain1, encrypted1);
1407             encryptor.encrypt(plain2, encrypted2);
1408             evaluator.multiply_inplace(encrypted1, encrypted2);
1409             decryptor.decrypt(encrypted1, plain);
1410             ASSERT_EQ(plain.to_string(), "1x^2 + 1x^1 + 1");
1411             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1412             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1413 
1414             plain1 = "1x^2 + 1";
1415             plain2 = "3Fx^1 + 3F";
1416             encryptor.encrypt(plain1, encrypted1);
1417             encryptor.encrypt(plain2, encrypted2);
1418             evaluator.multiply_inplace(encrypted1, encrypted2);
1419             decryptor.decrypt(encrypted1, plain);
1420             ASSERT_EQ(plain.to_string(), "3Fx^3 + 3Fx^2 + 3Fx^1 + 3F");
1421             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1422             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1423 
1424             plain1 = "1x^16";
1425             plain2 = "1x^8";
1426             encryptor.encrypt(plain1, encrypted1);
1427             encryptor.encrypt(plain2, encrypted2);
1428             evaluator.multiply_inplace(encrypted1, encrypted2);
1429             decryptor.decrypt(encrypted1, plain);
1430             ASSERT_EQ(plain.to_string(), "1x^24");
1431             ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id());
1432             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1433         }
1434         {
1435             EncryptionParameters parms(scheme_type::bfv);
1436             Modulus plain_modulus(1 << 8);
1437             parms.set_poly_modulus_degree(128);
1438             parms.set_plain_modulus(plain_modulus);
1439             parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 }));
1440 
1441             SEALContext context(parms, false, sec_level_type::none);
1442             KeyGenerator keygen(context);
1443             PublicKey pk;
1444             keygen.create_public_key(pk);
1445 
1446             Encryptor encryptor(context, pk);
1447             Evaluator evaluator(context);
1448             Decryptor decryptor(context, keygen.secret_key());
1449 
1450             Ciphertext encrypted1;
1451             Plaintext plain, plain1;
1452 
1453             plain1 = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1";
1454             encryptor.encrypt(plain1, encrypted1);
1455             evaluator.multiply(encrypted1, encrypted1, encrypted1);
1456             evaluator.multiply(encrypted1, encrypted1, encrypted1);
1457             decryptor.decrypt(encrypted1, plain);
1458             ASSERT_EQ(
1459                 plain.to_string(), "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + "
1460                                    "6Cx^15 + 70x^14 + 74x^13 + 71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + "
1461                                    "26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1");
1462             ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1463         }
1464     }
1465 
1466 #include "seal/randomgen.h"
TEST(EvaluatorTest,BFVRelinearize)1467     TEST(EvaluatorTest, BFVRelinearize)
1468     {
1469         EncryptionParameters parms(scheme_type::bfv);
1470         Modulus plain_modulus(1 << 6);
1471         parms.set_poly_modulus_degree(128);
1472         parms.set_plain_modulus(plain_modulus);
1473         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40 }));
1474 
1475         SEALContext context(parms, true, sec_level_type::none);
1476         KeyGenerator keygen(context);
1477         PublicKey pk;
1478         keygen.create_public_key(pk);
1479         RelinKeys rlk;
1480         keygen.create_relin_keys(rlk);
1481 
1482         Encryptor encryptor(context, pk);
1483         Evaluator evaluator(context);
1484         Decryptor decryptor(context, keygen.secret_key());
1485 
1486         Ciphertext encrypted(context);
1487         Ciphertext encrypted2(context);
1488 
1489         Plaintext plain;
1490         Plaintext plain2;
1491 
1492         plain = 0;
1493         encryptor.encrypt(plain, encrypted);
1494         evaluator.square_inplace(encrypted);
1495         evaluator.relinearize_inplace(encrypted, rlk);
1496         decryptor.decrypt(encrypted, plain2);
1497         ASSERT_TRUE(plain == plain2);
1498 
1499         encryptor.encrypt(plain, encrypted);
1500         evaluator.square_inplace(encrypted);
1501         evaluator.relinearize_inplace(encrypted, rlk);
1502         evaluator.square_inplace(encrypted);
1503         evaluator.relinearize_inplace(encrypted, rlk);
1504         decryptor.decrypt(encrypted, plain2);
1505         ASSERT_TRUE(plain == plain2);
1506 
1507         plain = "1x^10 + 2";
1508         encryptor.encrypt(plain, encrypted);
1509         evaluator.square_inplace(encrypted);
1510         evaluator.relinearize_inplace(encrypted, rlk);
1511         decryptor.decrypt(encrypted, plain2);
1512         ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4");
1513 
1514         encryptor.encrypt(plain, encrypted);
1515         evaluator.square_inplace(encrypted);
1516         evaluator.relinearize_inplace(encrypted, rlk);
1517         evaluator.square_inplace(encrypted);
1518         evaluator.relinearize_inplace(encrypted, rlk);
1519         decryptor.decrypt(encrypted, plain2);
1520         ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10");
1521 
1522         // Relinearization with modulus switching
1523         plain = "1x^10 + 2";
1524         encryptor.encrypt(plain, encrypted);
1525         evaluator.square_inplace(encrypted);
1526         evaluator.relinearize_inplace(encrypted, rlk);
1527         evaluator.mod_switch_to_next_inplace(encrypted);
1528         decryptor.decrypt(encrypted, plain2);
1529         ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4");
1530 
1531         encryptor.encrypt(plain, encrypted);
1532         evaluator.square_inplace(encrypted);
1533         evaluator.relinearize_inplace(encrypted, rlk);
1534         evaluator.mod_switch_to_next_inplace(encrypted);
1535         evaluator.square_inplace(encrypted);
1536         evaluator.relinearize_inplace(encrypted, rlk);
1537         evaluator.mod_switch_to_next_inplace(encrypted);
1538         decryptor.decrypt(encrypted, plain2);
1539         ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10");
1540     }
1541 
TEST(EvaluatorTest,CKKSEncryptNaiveMultiplyDecrypt)1542     TEST(EvaluatorTest, CKKSEncryptNaiveMultiplyDecrypt)
1543     {
1544         EncryptionParameters parms(scheme_type::ckks);
1545         {
1546             // Multiplying two zero vectors
1547             size_t slot_size = 32;
1548             parms.set_poly_modulus_degree(slot_size * 2);
1549             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30 }));
1550 
1551             SEALContext context(parms, false, sec_level_type::none);
1552             KeyGenerator keygen(context);
1553             PublicKey pk;
1554             keygen.create_public_key(pk);
1555 
1556             CKKSEncoder encoder(context);
1557             Encryptor encryptor(context, pk);
1558             Decryptor decryptor(context, keygen.secret_key());
1559             Evaluator evaluator(context);
1560 
1561             Ciphertext encrypted;
1562             Plaintext plain;
1563             Plaintext plainRes;
1564 
1565             vector<complex<double>> input(slot_size, 0.0);
1566             vector<complex<double>> output(slot_size);
1567             const double delta = static_cast<double>(1 << 30);
1568             encoder.encode(input, context.first_parms_id(), delta, plain);
1569 
1570             encryptor.encrypt(plain, encrypted);
1571             evaluator.multiply_inplace(encrypted, encrypted);
1572 
1573             // Check correctness of encryption
1574             ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
1575 
1576             decryptor.decrypt(encrypted, plainRes);
1577             encoder.decode(plainRes, output);
1578             for (size_t i = 0; i < slot_size; i++)
1579             {
1580                 auto tmp = abs(input[i].real() - output[i].real());
1581                 ASSERT_TRUE(tmp < 0.5);
1582             }
1583         }
1584         {
1585             // Multiplying two random vectors
1586             size_t slot_size = 32;
1587             parms.set_poly_modulus_degree(slot_size * 2);
1588             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
1589 
1590             SEALContext context(parms, false, sec_level_type::none);
1591             KeyGenerator keygen(context);
1592             PublicKey pk;
1593             keygen.create_public_key(pk);
1594 
1595             CKKSEncoder encoder(context);
1596             Encryptor encryptor(context, pk);
1597             Decryptor decryptor(context, keygen.secret_key());
1598             Evaluator evaluator(context);
1599 
1600             Ciphertext encrypted1;
1601             Ciphertext encrypted2;
1602             Plaintext plain1;
1603             Plaintext plain2;
1604             Plaintext plainRes;
1605 
1606             vector<complex<double>> input1(slot_size, 0.0);
1607             vector<complex<double>> input2(slot_size, 0.0);
1608             vector<complex<double>> expected(slot_size, 0.0);
1609             vector<complex<double>> output(slot_size);
1610             const double delta = static_cast<double>(1ULL << 40);
1611 
1612             int data_bound = (1 << 10);
1613             srand(static_cast<unsigned>(time(NULL)));
1614 
1615             for (int round = 0; round < 100; round++)
1616             {
1617                 for (size_t i = 0; i < slot_size; i++)
1618                 {
1619                     input1[i] = static_cast<double>(rand() % data_bound);
1620                     input2[i] = static_cast<double>(rand() % data_bound);
1621                     expected[i] = input1[i] * input2[i];
1622                 }
1623                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1624                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
1625 
1626                 encryptor.encrypt(plain1, encrypted1);
1627                 encryptor.encrypt(plain2, encrypted2);
1628                 evaluator.multiply_inplace(encrypted1, encrypted2);
1629 
1630                 // Check correctness of encryption
1631                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1632 
1633                 decryptor.decrypt(encrypted1, plainRes);
1634                 encoder.decode(plainRes, output);
1635                 for (size_t i = 0; i < slot_size; i++)
1636                 {
1637                     auto tmp = abs(expected[i].real() - output[i].real());
1638                     ASSERT_TRUE(tmp < 0.5);
1639                 }
1640             }
1641         }
1642         {
1643             // Multiplying two random vectors
1644             size_t slot_size = 16;
1645             parms.set_poly_modulus_degree(64);
1646             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 }));
1647 
1648             SEALContext context(parms, false, sec_level_type::none);
1649             KeyGenerator keygen(context);
1650             PublicKey pk;
1651             keygen.create_public_key(pk);
1652 
1653             CKKSEncoder encoder(context);
1654             Encryptor encryptor(context, pk);
1655             Decryptor decryptor(context, keygen.secret_key());
1656             Evaluator evaluator(context);
1657 
1658             Ciphertext encrypted1;
1659             Ciphertext encrypted2;
1660             Plaintext plain1;
1661             Plaintext plain2;
1662             Plaintext plainRes;
1663 
1664             vector<complex<double>> input1(slot_size, 0.0);
1665             vector<complex<double>> input2(slot_size, 0.0);
1666             vector<complex<double>> expected(slot_size, 0.0);
1667             vector<complex<double>> output(slot_size);
1668             const double delta = static_cast<double>(1ULL << 40);
1669 
1670             int data_bound = (1 << 10);
1671             srand(static_cast<unsigned>(time(NULL)));
1672 
1673             for (int round = 0; round < 100; round++)
1674             {
1675                 for (size_t i = 0; i < slot_size; i++)
1676                 {
1677                     input1[i] = static_cast<double>(rand() % data_bound);
1678                     input2[i] = static_cast<double>(rand() % data_bound);
1679                     expected[i] = input1[i] * input2[i];
1680                 }
1681                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1682                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
1683 
1684                 encryptor.encrypt(plain1, encrypted1);
1685                 encryptor.encrypt(plain2, encrypted2);
1686                 evaluator.multiply_inplace(encrypted1, encrypted2);
1687 
1688                 // Check correctness of encryption
1689                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1690 
1691                 decryptor.decrypt(encrypted1, plainRes);
1692                 encoder.decode(plainRes, output);
1693                 for (size_t i = 0; i < slot_size; i++)
1694                 {
1695                     auto tmp = abs(expected[i].real() - output[i].real());
1696                     ASSERT_TRUE(tmp < 0.5);
1697                 }
1698             }
1699         }
1700     }
1701 
TEST(EvaluatorTest,CKKSEncryptMultiplyByNumberDecrypt)1702     TEST(EvaluatorTest, CKKSEncryptMultiplyByNumberDecrypt)
1703     {
1704         EncryptionParameters parms(scheme_type::ckks);
1705         {
1706             // Multiplying two random vectors by an integer
1707             size_t slot_size = 32;
1708             parms.set_poly_modulus_degree(slot_size * 2);
1709             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 40 }));
1710 
1711             SEALContext context(parms, false, sec_level_type::none);
1712             KeyGenerator keygen(context);
1713             PublicKey pk;
1714             keygen.create_public_key(pk);
1715 
1716             CKKSEncoder encoder(context);
1717             Encryptor encryptor(context, pk);
1718             Decryptor decryptor(context, keygen.secret_key());
1719             Evaluator evaluator(context);
1720 
1721             Ciphertext encrypted1;
1722             Plaintext plain1;
1723             Plaintext plain2;
1724             Plaintext plainRes;
1725 
1726             vector<complex<double>> input1(slot_size, 0.0);
1727             int64_t input2;
1728             vector<complex<double>> expected(slot_size, 0.0);
1729 
1730             int data_bound = (1 << 10);
1731             srand(static_cast<unsigned>(time(NULL)));
1732 
1733             for (int iExp = 0; iExp < 50; iExp++)
1734             {
1735                 input2 = max(rand() % data_bound, 1);
1736                 for (size_t i = 0; i < slot_size; i++)
1737                 {
1738                     input1[i] = static_cast<double>(rand() % data_bound);
1739                     expected[i] = input1[i] * static_cast<double>(input2);
1740                 }
1741 
1742                 vector<complex<double>> output(slot_size);
1743                 const double delta = static_cast<double>(1ULL << 40);
1744                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1745                 encoder.encode(input2, context.first_parms_id(), plain2);
1746 
1747                 encryptor.encrypt(plain1, encrypted1);
1748                 evaluator.multiply_plain_inplace(encrypted1, plain2);
1749 
1750                 // Check correctness of encryption
1751                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1752 
1753                 decryptor.decrypt(encrypted1, plainRes);
1754                 encoder.decode(plainRes, output);
1755                 for (size_t i = 0; i < slot_size; i++)
1756                 {
1757                     auto tmp = abs(expected[i].real() - output[i].real());
1758                     ASSERT_TRUE(tmp < 0.5);
1759                 }
1760             }
1761         }
1762         {
1763             // Multiplying two random vectors by an integer
1764             size_t slot_size = 8;
1765             parms.set_poly_modulus_degree(64);
1766             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 }));
1767 
1768             SEALContext context(parms, false, sec_level_type::none);
1769             KeyGenerator keygen(context);
1770             PublicKey pk;
1771             keygen.create_public_key(pk);
1772 
1773             CKKSEncoder encoder(context);
1774             Encryptor encryptor(context, pk);
1775             Decryptor decryptor(context, keygen.secret_key());
1776             Evaluator evaluator(context);
1777 
1778             Ciphertext encrypted1;
1779             Plaintext plain1;
1780             Plaintext plain2;
1781             Plaintext plainRes;
1782 
1783             vector<complex<double>> input1(slot_size, 0.0);
1784             int64_t input2;
1785             vector<complex<double>> expected(slot_size, 0.0);
1786 
1787             int data_bound = (1 << 10);
1788             srand(static_cast<unsigned>(time(NULL)));
1789 
1790             for (int iExp = 0; iExp < 50; iExp++)
1791             {
1792                 input2 = max(rand() % data_bound, 1);
1793                 for (size_t i = 0; i < slot_size; i++)
1794                 {
1795                     input1[i] = static_cast<double>(rand() % data_bound);
1796                     expected[i] = input1[i] * static_cast<double>(input2);
1797                 }
1798 
1799                 vector<complex<double>> output(slot_size);
1800                 const double delta = static_cast<double>(1ULL << 40);
1801                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1802                 encoder.encode(input2, context.first_parms_id(), plain2);
1803 
1804                 encryptor.encrypt(plain1, encrypted1);
1805                 evaluator.multiply_plain_inplace(encrypted1, plain2);
1806 
1807                 // Check correctness of encryption
1808                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1809 
1810                 decryptor.decrypt(encrypted1, plainRes);
1811                 encoder.decode(plainRes, output);
1812                 for (size_t i = 0; i < slot_size; i++)
1813                 {
1814                     auto tmp = abs(expected[i].real() - output[i].real());
1815                     ASSERT_TRUE(tmp < 0.5);
1816                 }
1817             }
1818         }
1819         {
1820             // Multiplying two random vectors by a double
1821             size_t slot_size = 32;
1822             parms.set_poly_modulus_degree(slot_size * 2);
1823             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
1824 
1825             SEALContext context(parms, false, sec_level_type::none);
1826             KeyGenerator keygen(context);
1827             PublicKey pk;
1828             keygen.create_public_key(pk);
1829 
1830             CKKSEncoder encoder(context);
1831             Encryptor encryptor(context, pk);
1832             Decryptor decryptor(context, keygen.secret_key());
1833             Evaluator evaluator(context);
1834 
1835             Ciphertext encrypted1;
1836             Plaintext plain1;
1837             Plaintext plain2;
1838             Plaintext plainRes;
1839 
1840             vector<complex<double>> input1(slot_size, 0.0);
1841             double input2;
1842             vector<complex<double>> expected(slot_size, 0.0);
1843             vector<complex<double>> output(slot_size);
1844 
1845             int data_bound = (1 << 10);
1846             srand(static_cast<unsigned>(time(NULL)));
1847 
1848             for (int iExp = 0; iExp < 50; iExp++)
1849             {
1850                 input2 = static_cast<double>(rand() % (data_bound * data_bound)) / static_cast<double>(data_bound);
1851                 for (size_t i = 0; i < slot_size; i++)
1852                 {
1853                     input1[i] = static_cast<double>(rand() % data_bound);
1854                     expected[i] = input1[i] * input2;
1855                 }
1856 
1857                 const double delta = static_cast<double>(1ULL << 40);
1858                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1859                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
1860 
1861                 encryptor.encrypt(plain1, encrypted1);
1862                 evaluator.multiply_plain_inplace(encrypted1, plain2);
1863 
1864                 // Check correctness of encryption
1865                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1866 
1867                 decryptor.decrypt(encrypted1, plainRes);
1868                 encoder.decode(plainRes, output);
1869                 for (size_t i = 0; i < slot_size; i++)
1870                 {
1871                     auto tmp = abs(expected[i].real() - output[i].real());
1872                     ASSERT_TRUE(tmp < 0.5);
1873                 }
1874             }
1875         }
1876         {
1877             // Multiplying two random vectors by a double
1878             size_t slot_size = 16;
1879             parms.set_poly_modulus_degree(64);
1880             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 }));
1881 
1882             SEALContext context(parms, false, sec_level_type::none);
1883             KeyGenerator keygen(context);
1884             PublicKey pk;
1885             keygen.create_public_key(pk);
1886 
1887             CKKSEncoder encoder(context);
1888             Encryptor encryptor(context, pk);
1889             Decryptor decryptor(context, keygen.secret_key());
1890             Evaluator evaluator(context);
1891 
1892             Ciphertext encrypted1;
1893             Plaintext plain1;
1894             Plaintext plain2;
1895             Plaintext plainRes;
1896 
1897             vector<complex<double>> input1(slot_size, 2.1);
1898             double input2;
1899             vector<complex<double>> expected(slot_size, 2.1);
1900             vector<complex<double>> output(slot_size);
1901 
1902             int data_bound = (1 << 10);
1903             srand(static_cast<unsigned>(time(NULL)));
1904 
1905             for (int iExp = 0; iExp < 50; iExp++)
1906             {
1907                 input2 = static_cast<double>(rand() % (data_bound * data_bound)) / static_cast<double>(data_bound);
1908                 for (size_t i = 0; i < slot_size; i++)
1909                 {
1910                     input1[i] = static_cast<double>(rand() % data_bound);
1911                     expected[i] = input1[i] * input2;
1912                 }
1913 
1914                 const double delta = static_cast<double>(1ULL << 40);
1915                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1916                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
1917 
1918                 encryptor.encrypt(plain1, encrypted1);
1919                 evaluator.multiply_plain_inplace(encrypted1, plain2);
1920 
1921                 // Check correctness of encryption
1922                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1923 
1924                 decryptor.decrypt(encrypted1, plainRes);
1925                 encoder.decode(plainRes, output);
1926                 for (size_t i = 0; i < slot_size; i++)
1927                 {
1928                     auto tmp = abs(expected[i].real() - output[i].real());
1929                     ASSERT_TRUE(tmp < 0.5);
1930                 }
1931             }
1932         }
1933     }
1934 
TEST(EvaluatorTest,CKKSEncryptMultiplyRelinDecrypt)1935     TEST(EvaluatorTest, CKKSEncryptMultiplyRelinDecrypt)
1936     {
1937         EncryptionParameters parms(scheme_type::ckks);
1938         {
1939             // Multiplying two random vectors 50 times
1940             size_t slot_size = 32;
1941             parms.set_poly_modulus_degree(slot_size * 2);
1942             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
1943 
1944             SEALContext context(parms, false, sec_level_type::none);
1945             KeyGenerator keygen(context);
1946             PublicKey pk;
1947             keygen.create_public_key(pk);
1948             RelinKeys rlk;
1949             keygen.create_relin_keys(rlk);
1950 
1951             CKKSEncoder encoder(context);
1952             Encryptor encryptor(context, pk);
1953             Decryptor decryptor(context, keygen.secret_key());
1954             Evaluator evaluator(context);
1955 
1956             Ciphertext encrypted1;
1957             Ciphertext encrypted2;
1958             Ciphertext encryptedRes;
1959             Plaintext plain1;
1960             Plaintext plain2;
1961             Plaintext plainRes;
1962 
1963             vector<complex<double>> input1(slot_size, 0.0);
1964             vector<complex<double>> input2(slot_size, 0.0);
1965             vector<complex<double>> expected(slot_size, 0.0);
1966             int data_bound = 1 << 10;
1967 
1968             for (int round = 0; round < 50; round++)
1969             {
1970                 srand(static_cast<unsigned>(time(NULL)));
1971                 for (size_t i = 0; i < slot_size; i++)
1972                 {
1973                     input1[i] = static_cast<double>(rand() % data_bound);
1974                     input2[i] = static_cast<double>(rand() % data_bound);
1975                     expected[i] = input1[i] * input2[i];
1976                 }
1977 
1978                 vector<complex<double>> output(slot_size);
1979                 const double delta = static_cast<double>(1ULL << 40);
1980                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
1981                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
1982 
1983                 encryptor.encrypt(plain1, encrypted1);
1984                 encryptor.encrypt(plain2, encrypted2);
1985 
1986                 // Check correctness of encryption
1987                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
1988                 // Check correctness of encryption
1989                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
1990 
1991                 evaluator.multiply_inplace(encrypted1, encrypted2);
1992                 evaluator.relinearize_inplace(encrypted1, rlk);
1993 
1994                 decryptor.decrypt(encrypted1, plainRes);
1995                 encoder.decode(plainRes, output);
1996                 for (size_t i = 0; i < slot_size; i++)
1997                 {
1998                     auto tmp = abs(expected[i].real() - output[i].real());
1999                     ASSERT_TRUE(tmp < 0.5);
2000                 }
2001             }
2002         }
2003         {
2004             // Multiplying two random vectors 50 times
2005             size_t slot_size = 32;
2006             parms.set_poly_modulus_degree(slot_size * 2);
2007             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 }));
2008 
2009             SEALContext context(parms, false, sec_level_type::none);
2010             KeyGenerator keygen(context);
2011             PublicKey pk;
2012             keygen.create_public_key(pk);
2013             RelinKeys rlk;
2014             keygen.create_relin_keys(rlk);
2015 
2016             CKKSEncoder encoder(context);
2017             Encryptor encryptor(context, pk);
2018             Decryptor decryptor(context, keygen.secret_key());
2019             Evaluator evaluator(context);
2020 
2021             Ciphertext encrypted1;
2022             Ciphertext encrypted2;
2023             Ciphertext encryptedRes;
2024             Plaintext plain1;
2025             Plaintext plain2;
2026             Plaintext plainRes;
2027 
2028             vector<complex<double>> input1(slot_size, 0.0);
2029             vector<complex<double>> input2(slot_size, 0.0);
2030             vector<complex<double>> expected(slot_size, 0.0);
2031             int data_bound = 1 << 10;
2032 
2033             for (int round = 0; round < 50; round++)
2034             {
2035                 srand(static_cast<unsigned>(time(NULL)));
2036                 for (size_t i = 0; i < slot_size; i++)
2037                 {
2038                     input1[i] = static_cast<double>(rand() % data_bound);
2039                     input2[i] = static_cast<double>(rand() % data_bound);
2040                     expected[i] = input1[i] * input2[i];
2041                 }
2042 
2043                 vector<complex<double>> output(slot_size);
2044                 const double delta = static_cast<double>(1ULL << 40);
2045                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2046                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2047 
2048                 encryptor.encrypt(plain1, encrypted1);
2049                 encryptor.encrypt(plain2, encrypted2);
2050 
2051                 // Check correctness of encryption
2052                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2053                 // Check correctness of encryption
2054                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2055 
2056                 evaluator.multiply_inplace(encrypted1, encrypted2);
2057                 evaluator.relinearize_inplace(encrypted1, rlk);
2058 
2059                 decryptor.decrypt(encrypted1, plainRes);
2060                 encoder.decode(plainRes, output);
2061                 for (size_t i = 0; i < slot_size; i++)
2062                 {
2063                     auto tmp = abs(expected[i].real() - output[i].real());
2064                     ASSERT_TRUE(tmp < 0.5);
2065                 }
2066             }
2067         }
2068         {
2069             // Multiplying two random vectors 50 times
2070             size_t slot_size = 2;
2071             parms.set_poly_modulus_degree(8);
2072             parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 30, 30, 30 }));
2073 
2074             SEALContext context(parms, false, sec_level_type::none);
2075             KeyGenerator keygen(context);
2076             PublicKey pk;
2077             keygen.create_public_key(pk);
2078             RelinKeys rlk;
2079             keygen.create_relin_keys(rlk);
2080 
2081             CKKSEncoder encoder(context);
2082             Encryptor encryptor(context, pk);
2083             Decryptor decryptor(context, keygen.secret_key());
2084             Evaluator evaluator(context);
2085 
2086             Ciphertext encrypted1;
2087             Ciphertext encrypted2;
2088             Ciphertext encryptedRes;
2089             Plaintext plain1;
2090             Plaintext plain2;
2091             Plaintext plainRes;
2092 
2093             vector<complex<double>> input1(slot_size, 0.0);
2094             vector<complex<double>> input2(slot_size, 0.0);
2095             vector<complex<double>> expected(slot_size, 0.0);
2096             vector<complex<double>> output(slot_size);
2097             int data_bound = 1 << 10;
2098             const double delta = static_cast<double>(1ULL << 40);
2099 
2100             for (int round = 0; round < 50; round++)
2101             {
2102                 srand(static_cast<unsigned>(time(NULL)));
2103                 for (size_t i = 0; i < slot_size; i++)
2104                 {
2105                     input1[i] = static_cast<double>(rand() % data_bound);
2106                     input2[i] = static_cast<double>(rand() % data_bound);
2107                     expected[i] = input1[i] * input2[i];
2108                 }
2109 
2110                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2111                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2112 
2113                 encryptor.encrypt(plain1, encrypted1);
2114                 encryptor.encrypt(plain2, encrypted2);
2115 
2116                 // Check correctness of encryption
2117                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2118                 // Check correctness of encryption
2119                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2120 
2121                 evaluator.multiply_inplace(encrypted1, encrypted2);
2122                 // Evaluator.relinearize_inplace(encrypted1, rlk);
2123 
2124                 decryptor.decrypt(encrypted1, plainRes);
2125                 encoder.decode(plainRes, output);
2126                 for (size_t i = 0; i < slot_size; i++)
2127                 {
2128                     auto tmp = abs(expected[i].real() - output[i].real());
2129                     ASSERT_TRUE(tmp < 0.5);
2130                 }
2131             }
2132         }
2133     }
2134 
TEST(EvaluatorTest,CKKSEncryptSquareRelinDecrypt)2135     TEST(EvaluatorTest, CKKSEncryptSquareRelinDecrypt)
2136     {
2137         EncryptionParameters parms(scheme_type::ckks);
2138         {
2139             // Squaring two random vectors 100 times
2140             size_t slot_size = 32;
2141             parms.set_poly_modulus_degree(slot_size * 2);
2142             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 }));
2143 
2144             SEALContext context(parms, false, sec_level_type::none);
2145             KeyGenerator keygen(context);
2146             PublicKey pk;
2147             keygen.create_public_key(pk);
2148             RelinKeys rlk;
2149             keygen.create_relin_keys(rlk);
2150 
2151             CKKSEncoder encoder(context);
2152             Encryptor encryptor(context, pk);
2153             Decryptor decryptor(context, keygen.secret_key());
2154             Evaluator evaluator(context);
2155 
2156             Ciphertext encrypted;
2157             Plaintext plain;
2158             Plaintext plainRes;
2159 
2160             vector<complex<double>> input(slot_size, 0.0);
2161             vector<complex<double>> expected(slot_size, 0.0);
2162 
2163             int data_bound = 1 << 7;
2164             srand(static_cast<unsigned>(time(NULL)));
2165 
2166             for (int round = 0; round < 100; round++)
2167             {
2168                 for (size_t i = 0; i < slot_size; i++)
2169                 {
2170                     input[i] = static_cast<double>(rand() % data_bound);
2171                     expected[i] = input[i] * input[i];
2172                 }
2173 
2174                 vector<complex<double>> output(slot_size);
2175                 const double delta = static_cast<double>(1ULL << 40);
2176                 encoder.encode(input, context.first_parms_id(), delta, plain);
2177 
2178                 encryptor.encrypt(plain, encrypted);
2179 
2180                 // Check correctness of encryption
2181                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2182 
2183                 // Evaluator.square_inplace(encrypted);
2184                 evaluator.multiply_inplace(encrypted, encrypted);
2185                 evaluator.relinearize_inplace(encrypted, rlk);
2186 
2187                 decryptor.decrypt(encrypted, plainRes);
2188                 encoder.decode(plainRes, output);
2189                 for (size_t i = 0; i < slot_size; i++)
2190                 {
2191                     auto tmp = abs(expected[i].real() - output[i].real());
2192                     ASSERT_TRUE(tmp < 0.5);
2193                 }
2194             }
2195         }
2196         {
2197             // Squaring two random vectors 100 times
2198             size_t slot_size = 32;
2199             parms.set_poly_modulus_degree(slot_size * 2);
2200             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 }));
2201 
2202             SEALContext context(parms, false, sec_level_type::none);
2203             KeyGenerator keygen(context);
2204             PublicKey pk;
2205             keygen.create_public_key(pk);
2206             RelinKeys rlk;
2207             keygen.create_relin_keys(rlk);
2208 
2209             CKKSEncoder encoder(context);
2210             Encryptor encryptor(context, pk);
2211             Decryptor decryptor(context, keygen.secret_key());
2212             Evaluator evaluator(context);
2213 
2214             Ciphertext encrypted;
2215             Plaintext plain;
2216             Plaintext plainRes;
2217 
2218             vector<complex<double>> input(slot_size, 0.0);
2219             vector<complex<double>> expected(slot_size, 0.0);
2220 
2221             int data_bound = 1 << 7;
2222             srand(static_cast<unsigned>(time(NULL)));
2223 
2224             for (int round = 0; round < 100; round++)
2225             {
2226                 for (size_t i = 0; i < slot_size; i++)
2227                 {
2228                     input[i] = static_cast<double>(rand() % data_bound);
2229                     expected[i] = input[i] * input[i];
2230                 }
2231 
2232                 vector<complex<double>> output(slot_size);
2233                 const double delta = static_cast<double>(1ULL << 40);
2234                 encoder.encode(input, context.first_parms_id(), delta, plain);
2235 
2236                 encryptor.encrypt(plain, encrypted);
2237 
2238                 // Check correctness of encryption
2239                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2240 
2241                 // Evaluator.square_inplace(encrypted);
2242                 evaluator.multiply_inplace(encrypted, encrypted);
2243                 evaluator.relinearize_inplace(encrypted, rlk);
2244 
2245                 decryptor.decrypt(encrypted, plainRes);
2246                 encoder.decode(plainRes, output);
2247                 for (size_t i = 0; i < slot_size; i++)
2248                 {
2249                     auto tmp = abs(expected[i].real() - output[i].real());
2250                     ASSERT_TRUE(tmp < 0.5);
2251                 }
2252             }
2253         }
2254         {
2255             // Squaring two random vectors 100 times
2256             size_t slot_size = 16;
2257             parms.set_poly_modulus_degree(64);
2258             parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 30, 30, 30 }));
2259 
2260             SEALContext context(parms, false, sec_level_type::none);
2261             KeyGenerator keygen(context);
2262             PublicKey pk;
2263             keygen.create_public_key(pk);
2264             RelinKeys rlk;
2265             keygen.create_relin_keys(rlk);
2266 
2267             CKKSEncoder encoder(context);
2268             Encryptor encryptor(context, pk);
2269             Decryptor decryptor(context, keygen.secret_key());
2270             Evaluator evaluator(context);
2271 
2272             Ciphertext encrypted;
2273             Plaintext plain;
2274             Plaintext plainRes;
2275 
2276             vector<complex<double>> input(slot_size, 0.0);
2277             vector<complex<double>> expected(slot_size, 0.0);
2278 
2279             int data_bound = 1 << 7;
2280             srand(static_cast<unsigned>(time(NULL)));
2281 
2282             for (int round = 0; round < 100; round++)
2283             {
2284                 for (size_t i = 0; i < slot_size; i++)
2285                 {
2286                     input[i] = static_cast<double>(rand() % data_bound);
2287                     expected[i] = input[i] * input[i];
2288                 }
2289 
2290                 vector<complex<double>> output(slot_size);
2291                 const double delta = static_cast<double>(1ULL << 40);
2292                 encoder.encode(input, context.first_parms_id(), delta, plain);
2293 
2294                 encryptor.encrypt(plain, encrypted);
2295 
2296                 // Check correctness of encryption
2297                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2298 
2299                 // Evaluator.square_inplace(encrypted);
2300                 evaluator.multiply_inplace(encrypted, encrypted);
2301                 evaluator.relinearize_inplace(encrypted, rlk);
2302 
2303                 decryptor.decrypt(encrypted, plainRes);
2304                 encoder.decode(plainRes, output);
2305                 for (size_t i = 0; i < slot_size; i++)
2306                 {
2307                     auto tmp = abs(expected[i].real() - output[i].real());
2308                     ASSERT_TRUE(tmp < 0.5);
2309                 }
2310             }
2311         }
2312     }
2313 
TEST(EvaluatorTest,CKKSEncryptMultiplyRelinRescaleDecrypt)2314     TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleDecrypt)
2315     {
2316         EncryptionParameters parms(scheme_type::ckks);
2317         {
2318             // Multiplying two random vectors 100 times
2319             size_t slot_size = 64;
2320             parms.set_poly_modulus_degree(slot_size * 2);
2321             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30, 30 }));
2322 
2323             SEALContext context(parms, true, sec_level_type::none);
2324             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2325             KeyGenerator keygen(context);
2326             PublicKey pk;
2327             keygen.create_public_key(pk);
2328             RelinKeys rlk;
2329             keygen.create_relin_keys(rlk);
2330 
2331             CKKSEncoder encoder(context);
2332             Encryptor encryptor(context, pk);
2333             Decryptor decryptor(context, keygen.secret_key());
2334             Evaluator evaluator(context);
2335 
2336             Ciphertext encrypted1;
2337             Ciphertext encrypted2;
2338             Ciphertext encryptedRes;
2339             Plaintext plain1;
2340             Plaintext plain2;
2341             Plaintext plainRes;
2342 
2343             vector<complex<double>> input1(slot_size, 0.0);
2344             vector<complex<double>> input2(slot_size, 0.0);
2345             vector<complex<double>> expected(slot_size, 0.0);
2346 
2347             for (int round = 0; round < 100; round++)
2348             {
2349                 int data_bound = 1 << 7;
2350                 srand(static_cast<unsigned>(time(NULL)));
2351                 for (size_t i = 0; i < slot_size; i++)
2352                 {
2353                     input1[i] = static_cast<double>(rand() % data_bound);
2354                     input2[i] = static_cast<double>(rand() % data_bound);
2355                     expected[i] = input1[i] * input2[i];
2356                 }
2357 
2358                 vector<complex<double>> output(slot_size);
2359                 double delta = static_cast<double>(1ULL << 40);
2360                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2361                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2362 
2363                 encryptor.encrypt(plain1, encrypted1);
2364                 encryptor.encrypt(plain2, encrypted2);
2365 
2366                 // Check correctness of encryption
2367                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2368                 // Check correctness of encryption
2369                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2370 
2371                 evaluator.multiply_inplace(encrypted1, encrypted2);
2372                 evaluator.relinearize_inplace(encrypted1, rlk);
2373                 evaluator.rescale_to_next_inplace(encrypted1);
2374 
2375                 // Check correctness of modulus switching
2376                 ASSERT_TRUE(encrypted1.parms_id() == next_parms_id);
2377 
2378                 decryptor.decrypt(encrypted1, plainRes);
2379                 encoder.decode(plainRes, output);
2380                 for (size_t i = 0; i < slot_size; i++)
2381                 {
2382                     auto tmp = abs(expected[i].real() - output[i].real());
2383                     ASSERT_TRUE(tmp < 0.5);
2384                 }
2385             }
2386         }
2387         {
2388             // Multiplying two random vectors 100 times
2389             size_t slot_size = 16;
2390             parms.set_poly_modulus_degree(128);
2391             parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30, 30 }));
2392 
2393             SEALContext context(parms, true, sec_level_type::none);
2394             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2395             KeyGenerator keygen(context);
2396             PublicKey pk;
2397             keygen.create_public_key(pk);
2398             RelinKeys rlk;
2399             keygen.create_relin_keys(rlk);
2400 
2401             CKKSEncoder encoder(context);
2402             Encryptor encryptor(context, pk);
2403             Decryptor decryptor(context, keygen.secret_key());
2404             Evaluator evaluator(context);
2405 
2406             Ciphertext encrypted1;
2407             Ciphertext encrypted2;
2408             Ciphertext encryptedRes;
2409             Plaintext plain1;
2410             Plaintext plain2;
2411             Plaintext plainRes;
2412 
2413             vector<complex<double>> input1(slot_size, 0.0);
2414             vector<complex<double>> input2(slot_size, 0.0);
2415             vector<complex<double>> expected(slot_size, 0.0);
2416 
2417             for (int round = 0; round < 100; round++)
2418             {
2419                 int data_bound = 1 << 7;
2420                 srand(static_cast<unsigned>(time(NULL)));
2421                 for (size_t i = 0; i < slot_size; i++)
2422                 {
2423                     input1[i] = static_cast<double>(rand() % data_bound);
2424                     input2[i] = static_cast<double>(rand() % data_bound);
2425                     expected[i] = input1[i] * input2[i];
2426                 }
2427 
2428                 vector<complex<double>> output(slot_size);
2429                 double delta = static_cast<double>(1ULL << 40);
2430                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2431                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2432 
2433                 encryptor.encrypt(plain1, encrypted1);
2434                 encryptor.encrypt(plain2, encrypted2);
2435 
2436                 // Check correctness of encryption
2437                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2438                 // Check correctness of encryption
2439                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2440 
2441                 evaluator.multiply_inplace(encrypted1, encrypted2);
2442                 evaluator.relinearize_inplace(encrypted1, rlk);
2443                 evaluator.rescale_to_next_inplace(encrypted1);
2444 
2445                 // Check correctness of modulus switching
2446                 ASSERT_TRUE(encrypted1.parms_id() == next_parms_id);
2447 
2448                 decryptor.decrypt(encrypted1, plainRes);
2449                 encoder.decode(plainRes, output);
2450                 for (size_t i = 0; i < slot_size; i++)
2451                 {
2452                     auto tmp = abs(expected[i].real() - output[i].real());
2453                     ASSERT_TRUE(tmp < 0.5);
2454                 }
2455             }
2456         }
2457         {
2458             // Multiplying two random vectors 100 times
2459             size_t slot_size = 16;
2460             parms.set_poly_modulus_degree(128);
2461             parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 60, 60, 60, 60 }));
2462 
2463             SEALContext context(parms, true, sec_level_type::none);
2464             KeyGenerator keygen(context);
2465             PublicKey pk;
2466             keygen.create_public_key(pk);
2467             RelinKeys rlk;
2468             keygen.create_relin_keys(rlk);
2469 
2470             CKKSEncoder encoder(context);
2471             Encryptor encryptor(context, pk);
2472             Decryptor decryptor(context, keygen.secret_key());
2473             Evaluator evaluator(context);
2474 
2475             Ciphertext encrypted1;
2476             Ciphertext encrypted2;
2477             Ciphertext encryptedRes;
2478             Plaintext plain1;
2479             Plaintext plain2;
2480             Plaintext plainRes;
2481 
2482             vector<complex<double>> input1(slot_size, 0.0);
2483             vector<complex<double>> input2(slot_size, 0.0);
2484             vector<complex<double>> expected(slot_size, 0.0);
2485 
2486             for (int round = 0; round < 100; round++)
2487             {
2488                 int data_bound = 1 << 7;
2489                 srand(static_cast<unsigned>(time(NULL)));
2490                 for (size_t i = 0; i < slot_size; i++)
2491                 {
2492                     input1[i] = static_cast<double>(rand() % data_bound);
2493                     input2[i] = static_cast<double>(rand() % data_bound);
2494                     expected[i] = input1[i] * input2[i] * input2[i];
2495                 }
2496 
2497                 vector<complex<double>> output(slot_size);
2498                 double delta = static_cast<double>(1ULL << 60);
2499                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2500                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2501 
2502                 encryptor.encrypt(plain1, encrypted1);
2503                 encryptor.encrypt(plain2, encrypted2);
2504 
2505                 // Check correctness of encryption
2506                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2507                 // Check correctness of encryption
2508                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2509 
2510                 evaluator.multiply_inplace(encrypted1, encrypted2);
2511                 evaluator.relinearize_inplace(encrypted1, rlk);
2512                 evaluator.multiply_inplace(encrypted1, encrypted2);
2513                 evaluator.relinearize_inplace(encrypted1, rlk);
2514 
2515                 // Scale down by two levels
2516                 auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id();
2517                 evaluator.rescale_to_inplace(encrypted1, target_parms);
2518 
2519                 // Check correctness of modulus switching
2520                 ASSERT_TRUE(encrypted1.parms_id() == target_parms);
2521 
2522                 decryptor.decrypt(encrypted1, plainRes);
2523                 encoder.decode(plainRes, output);
2524                 for (size_t i = 0; i < slot_size; i++)
2525                 {
2526                     auto tmp = abs(expected[i].real() - output[i].real());
2527                     ASSERT_TRUE(tmp < 0.5);
2528                 }
2529             }
2530 
2531             // Test with inverted order: rescale then relin
2532             for (int round = 0; round < 100; round++)
2533             {
2534                 int data_bound = 1 << 7;
2535                 srand(static_cast<unsigned>(time(NULL)));
2536                 for (size_t i = 0; i < slot_size; i++)
2537                 {
2538                     input1[i] = static_cast<double>(rand() % data_bound);
2539                     input2[i] = static_cast<double>(rand() % data_bound);
2540                     expected[i] = input1[i] * input2[i] * input2[i];
2541                 }
2542 
2543                 vector<complex<double>> output(slot_size);
2544                 double delta = static_cast<double>(1ULL << 50);
2545                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2546                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2547 
2548                 encryptor.encrypt(plain1, encrypted1);
2549                 encryptor.encrypt(plain2, encrypted2);
2550 
2551                 // Check correctness of encryption
2552                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2553                 // Check correctness of encryption
2554                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2555 
2556                 evaluator.multiply_inplace(encrypted1, encrypted2);
2557                 evaluator.relinearize_inplace(encrypted1, rlk);
2558                 evaluator.multiply_inplace(encrypted1, encrypted2);
2559 
2560                 // Scale down by two levels
2561                 auto target_parms = context.first_context_data()->next_context_data()->next_context_data()->parms_id();
2562                 evaluator.rescale_to_inplace(encrypted1, target_parms);
2563 
2564                 // Relinearize now
2565                 evaluator.relinearize_inplace(encrypted1, rlk);
2566 
2567                 // Check correctness of modulus switching
2568                 ASSERT_TRUE(encrypted1.parms_id() == target_parms);
2569 
2570                 decryptor.decrypt(encrypted1, plainRes);
2571                 encoder.decode(plainRes, output);
2572                 for (size_t i = 0; i < slot_size; i++)
2573                 {
2574                     auto tmp = abs(expected[i].real() - output[i].real());
2575                     ASSERT_TRUE(tmp < 0.5);
2576                 }
2577             }
2578         }
2579     }
2580 
TEST(EvaluatorTest,CKKSEncryptSquareRelinRescaleDecrypt)2581     TEST(EvaluatorTest, CKKSEncryptSquareRelinRescaleDecrypt)
2582     {
2583         EncryptionParameters parms(scheme_type::ckks);
2584         {
2585             // Squaring two random vectors 100 times
2586             size_t slot_size = 64;
2587             parms.set_poly_modulus_degree(slot_size * 2);
2588             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 }));
2589 
2590             SEALContext context(parms, true, sec_level_type::none);
2591             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2592             KeyGenerator keygen(context);
2593             PublicKey pk;
2594             keygen.create_public_key(pk);
2595             RelinKeys rlk;
2596             keygen.create_relin_keys(rlk);
2597 
2598             CKKSEncoder encoder(context);
2599             Encryptor encryptor(context, pk);
2600             Decryptor decryptor(context, keygen.secret_key());
2601             Evaluator evaluator(context);
2602 
2603             Ciphertext encrypted;
2604             Plaintext plain;
2605             Plaintext plainRes;
2606 
2607             vector<complex<double>> input(slot_size, 0.0);
2608             vector<complex<double>> output(slot_size);
2609             vector<complex<double>> expected(slot_size, 0.0);
2610             int data_bound = 1 << 8;
2611 
2612             for (int round = 0; round < 100; round++)
2613             {
2614                 srand(static_cast<unsigned>(time(NULL)));
2615                 for (size_t i = 0; i < slot_size; i++)
2616                 {
2617                     input[i] = static_cast<double>(rand() % data_bound);
2618                     expected[i] = input[i] * input[i];
2619                 }
2620 
2621                 double delta = static_cast<double>(1ULL << 40);
2622                 encoder.encode(input, context.first_parms_id(), delta, plain);
2623 
2624                 encryptor.encrypt(plain, encrypted);
2625 
2626                 // Check correctness of encryption
2627                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2628 
2629                 evaluator.square_inplace(encrypted);
2630                 evaluator.relinearize_inplace(encrypted, rlk);
2631                 evaluator.rescale_to_next_inplace(encrypted);
2632 
2633                 // Check correctness of modulus switching
2634                 ASSERT_TRUE(encrypted.parms_id() == next_parms_id);
2635 
2636                 decryptor.decrypt(encrypted, plainRes);
2637                 encoder.decode(plainRes, output);
2638                 for (size_t i = 0; i < slot_size; i++)
2639                 {
2640                     auto tmp = abs(expected[i].real() - output[i].real());
2641                     ASSERT_TRUE(tmp < 0.5);
2642                 }
2643             }
2644         }
2645         {
2646             // Squaring two random vectors 100 times
2647             size_t slot_size = 16;
2648             parms.set_poly_modulus_degree(128);
2649             parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 }));
2650 
2651             SEALContext context(parms, true, sec_level_type::none);
2652             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2653             KeyGenerator keygen(context);
2654             PublicKey pk;
2655             keygen.create_public_key(pk);
2656             RelinKeys rlk;
2657             keygen.create_relin_keys(rlk);
2658 
2659             CKKSEncoder encoder(context);
2660             Encryptor encryptor(context, pk);
2661             Decryptor decryptor(context, keygen.secret_key());
2662             Evaluator evaluator(context);
2663 
2664             Ciphertext encrypted;
2665             Plaintext plain;
2666             Plaintext plainRes;
2667 
2668             vector<complex<double>> input(slot_size, 0.0);
2669             vector<complex<double>> output(slot_size);
2670             vector<complex<double>> expected(slot_size, 0.0);
2671             int data_bound = 1 << 8;
2672 
2673             for (int round = 0; round < 100; round++)
2674             {
2675                 srand(static_cast<unsigned>(time(NULL)));
2676                 for (size_t i = 0; i < slot_size; i++)
2677                 {
2678                     input[i] = static_cast<double>(rand() % data_bound);
2679                     expected[i] = input[i] * input[i];
2680                 }
2681 
2682                 double delta = static_cast<double>(1ULL << 40);
2683                 encoder.encode(input, context.first_parms_id(), delta, plain);
2684 
2685                 encryptor.encrypt(plain, encrypted);
2686 
2687                 // Check correctness of encryption
2688                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2689 
2690                 evaluator.square_inplace(encrypted);
2691                 evaluator.relinearize_inplace(encrypted, rlk);
2692                 evaluator.rescale_to_next_inplace(encrypted);
2693 
2694                 // Check correctness of modulus switching
2695                 ASSERT_TRUE(encrypted.parms_id() == next_parms_id);
2696 
2697                 decryptor.decrypt(encrypted, plainRes);
2698                 encoder.decode(plainRes, output);
2699                 for (size_t i = 0; i < slot_size; i++)
2700                 {
2701                     auto tmp = abs(expected[i].real() - output[i].real());
2702                     ASSERT_TRUE(tmp < 0.5);
2703                 }
2704             }
2705         }
2706     }
TEST(EvaluatorTest,CKKSEncryptModSwitchDecrypt)2707     TEST(EvaluatorTest, CKKSEncryptModSwitchDecrypt)
2708     {
2709         EncryptionParameters parms(scheme_type::ckks);
2710         {
2711             // Modulus switching without rescaling for random vectors
2712             size_t slot_size = 64;
2713             parms.set_poly_modulus_degree(slot_size * 2);
2714             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60, 60, 60 }));
2715 
2716             SEALContext context(parms, true, sec_level_type::none);
2717             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2718             KeyGenerator keygen(context);
2719             PublicKey pk;
2720             keygen.create_public_key(pk);
2721 
2722             CKKSEncoder encoder(context);
2723             Encryptor encryptor(context, pk);
2724             Decryptor decryptor(context, keygen.secret_key());
2725             Evaluator evaluator(context);
2726 
2727             int data_bound = 1 << 30;
2728             srand(static_cast<unsigned>(time(NULL)));
2729 
2730             vector<complex<double>> input(slot_size, 0.0);
2731             vector<complex<double>> output(slot_size);
2732 
2733             Ciphertext encrypted;
2734             Plaintext plain;
2735             Plaintext plainRes;
2736 
2737             for (int round = 0; round < 100; round++)
2738             {
2739                 for (size_t i = 0; i < slot_size; i++)
2740                 {
2741                     input[i] = static_cast<double>(rand() % data_bound);
2742                 }
2743 
2744                 double delta = static_cast<double>(1ULL << 40);
2745                 encoder.encode(input, context.first_parms_id(), delta, plain);
2746 
2747                 encryptor.encrypt(plain, encrypted);
2748 
2749                 // Check correctness of encryption
2750                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2751 
2752                 // Not inplace
2753                 Ciphertext destination;
2754                 evaluator.mod_switch_to_next(encrypted, destination);
2755 
2756                 // Check correctness of modulus switching
2757                 ASSERT_TRUE(destination.parms_id() == next_parms_id);
2758 
2759                 decryptor.decrypt(destination, plainRes);
2760                 encoder.decode(plainRes, output);
2761 
2762                 for (size_t i = 0; i < slot_size; i++)
2763                 {
2764                     auto tmp = abs(input[i].real() - output[i].real());
2765                     ASSERT_TRUE(tmp < 0.5);
2766                 }
2767 
2768                 // Inplace
2769                 evaluator.mod_switch_to_next_inplace(encrypted);
2770 
2771                 // Check correctness of modulus switching
2772                 ASSERT_TRUE(encrypted.parms_id() == next_parms_id);
2773 
2774                 decryptor.decrypt(encrypted, plainRes);
2775                 encoder.decode(plainRes, output);
2776                 for (size_t i = 0; i < slot_size; i++)
2777                 {
2778                     auto tmp = abs(input[i].real() - output[i].real());
2779                     ASSERT_TRUE(tmp < 0.5);
2780                 }
2781             }
2782         }
2783         {
2784             // Modulus switching without rescaling for random vectors
2785             size_t slot_size = 32;
2786             parms.set_poly_modulus_degree(slot_size * 2);
2787             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40, 40 }));
2788 
2789             SEALContext context(parms, true, sec_level_type::none);
2790             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2791             KeyGenerator keygen(context);
2792             PublicKey pk;
2793             keygen.create_public_key(pk);
2794 
2795             CKKSEncoder encoder(context);
2796             Encryptor encryptor(context, pk);
2797             Decryptor decryptor(context, keygen.secret_key());
2798             Evaluator evaluator(context);
2799 
2800             int data_bound = 1 << 30;
2801             srand(static_cast<unsigned>(time(NULL)));
2802 
2803             vector<complex<double>> input(slot_size, 0.0);
2804             vector<complex<double>> output(slot_size);
2805 
2806             Ciphertext encrypted;
2807             Plaintext plain;
2808             Plaintext plainRes;
2809 
2810             for (int round = 0; round < 100; round++)
2811             {
2812                 for (size_t i = 0; i < slot_size; i++)
2813                 {
2814                     input[i] = static_cast<double>(rand() % data_bound);
2815                 }
2816 
2817                 double delta = static_cast<double>(1ULL << 40);
2818                 encoder.encode(input, context.first_parms_id(), delta, plain);
2819 
2820                 encryptor.encrypt(plain, encrypted);
2821 
2822                 // Check correctness of encryption
2823                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2824 
2825                 // Not inplace
2826                 Ciphertext destination;
2827                 evaluator.mod_switch_to_next(encrypted, destination);
2828 
2829                 // Check correctness of modulus switching
2830                 ASSERT_TRUE(destination.parms_id() == next_parms_id);
2831 
2832                 decryptor.decrypt(destination, plainRes);
2833                 encoder.decode(plainRes, output);
2834 
2835                 for (size_t i = 0; i < slot_size; i++)
2836                 {
2837                     auto tmp = abs(input[i].real() - output[i].real());
2838                     ASSERT_TRUE(tmp < 0.5);
2839                 }
2840 
2841                 // Inplace
2842                 evaluator.mod_switch_to_next_inplace(encrypted);
2843 
2844                 // Check correctness of modulus switching
2845                 ASSERT_TRUE(encrypted.parms_id() == next_parms_id);
2846 
2847                 decryptor.decrypt(encrypted, plainRes);
2848                 encoder.decode(plainRes, output);
2849                 for (size_t i = 0; i < slot_size; i++)
2850                 {
2851                     auto tmp = abs(input[i].real() - output[i].real());
2852                     ASSERT_TRUE(tmp < 0.5);
2853                 }
2854             }
2855         }
2856         {
2857             // Modulus switching without rescaling for random vectors
2858             size_t slot_size = 32;
2859             parms.set_poly_modulus_degree(128);
2860             parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40, 40 }));
2861 
2862             SEALContext context(parms, true, sec_level_type::none);
2863             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2864             KeyGenerator keygen(context);
2865             PublicKey pk;
2866             keygen.create_public_key(pk);
2867 
2868             CKKSEncoder encoder(context);
2869             Encryptor encryptor(context, pk);
2870             Decryptor decryptor(context, keygen.secret_key());
2871             Evaluator evaluator(context);
2872 
2873             int data_bound = 1 << 30;
2874             srand(static_cast<unsigned>(time(NULL)));
2875 
2876             vector<complex<double>> input(slot_size, 0.0);
2877             vector<complex<double>> output(slot_size);
2878 
2879             Ciphertext encrypted;
2880             Plaintext plain;
2881             Plaintext plainRes;
2882 
2883             for (int round = 0; round < 100; round++)
2884             {
2885                 for (size_t i = 0; i < slot_size; i++)
2886                 {
2887                     input[i] = static_cast<double>(rand() % data_bound);
2888                 }
2889 
2890                 double delta = static_cast<double>(1ULL << 40);
2891                 encoder.encode(input, context.first_parms_id(), delta, plain);
2892 
2893                 encryptor.encrypt(plain, encrypted);
2894 
2895                 // Check correctness of encryption
2896                 ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
2897 
2898                 // Not inplace
2899                 Ciphertext destination;
2900                 evaluator.mod_switch_to_next(encrypted, destination);
2901 
2902                 // Check correctness of modulus switching
2903                 ASSERT_TRUE(destination.parms_id() == next_parms_id);
2904 
2905                 decryptor.decrypt(destination, plainRes);
2906                 encoder.decode(plainRes, output);
2907 
2908                 for (size_t i = 0; i < slot_size; i++)
2909                 {
2910                     auto tmp = abs(input[i].real() - output[i].real());
2911                     ASSERT_TRUE(tmp < 0.5);
2912                 }
2913 
2914                 // Inplace
2915                 evaluator.mod_switch_to_next_inplace(encrypted);
2916 
2917                 // Check correctness of modulus switching
2918                 ASSERT_TRUE(encrypted.parms_id() == next_parms_id);
2919 
2920                 decryptor.decrypt(encrypted, plainRes);
2921                 encoder.decode(plainRes, output);
2922                 for (size_t i = 0; i < slot_size; i++)
2923                 {
2924                     auto tmp = abs(input[i].real() - output[i].real());
2925                     ASSERT_TRUE(tmp < 0.5);
2926                 }
2927             }
2928         }
2929     }
TEST(EvaluatorTest,CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt)2930     TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt)
2931     {
2932         EncryptionParameters parms(scheme_type::ckks);
2933         {
2934             // Multiplication and addition without rescaling for random vectors
2935             size_t slot_size = 64;
2936             parms.set_poly_modulus_degree(slot_size * 2);
2937             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 }));
2938 
2939             SEALContext context(parms, true, sec_level_type::none);
2940             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
2941             KeyGenerator keygen(context);
2942             PublicKey pk;
2943             keygen.create_public_key(pk);
2944             RelinKeys rlk;
2945             keygen.create_relin_keys(rlk);
2946 
2947             CKKSEncoder encoder(context);
2948             Encryptor encryptor(context, pk);
2949             Decryptor decryptor(context, keygen.secret_key());
2950             Evaluator evaluator(context);
2951 
2952             Ciphertext encrypted1;
2953             Ciphertext encrypted2;
2954             Ciphertext encrypted3;
2955             Plaintext plain1;
2956             Plaintext plain2;
2957             Plaintext plain3;
2958             Plaintext plainRes;
2959 
2960             vector<complex<double>> input1(slot_size, 0.0);
2961             vector<complex<double>> input2(slot_size, 0.0);
2962             vector<complex<double>> input3(slot_size, 0.0);
2963             vector<complex<double>> expected(slot_size, 0.0);
2964 
2965             for (int round = 0; round < 100; round++)
2966             {
2967                 int data_bound = 1 << 8;
2968                 srand(static_cast<unsigned>(time(NULL)));
2969                 for (size_t i = 0; i < slot_size; i++)
2970                 {
2971                     input1[i] = static_cast<double>(rand() % data_bound);
2972                     input2[i] = static_cast<double>(rand() % data_bound);
2973                     expected[i] = input1[i] * input2[i] + input3[i];
2974                 }
2975 
2976                 vector<complex<double>> output(slot_size);
2977                 double delta = static_cast<double>(1ULL << 40);
2978                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
2979                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
2980                 encoder.encode(input3, context.first_parms_id(), delta * delta, plain3);
2981 
2982                 encryptor.encrypt(plain1, encrypted1);
2983                 encryptor.encrypt(plain2, encrypted2);
2984                 encryptor.encrypt(plain3, encrypted3);
2985 
2986                 // Check correctness of encryption
2987                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
2988                 // Check correctness of encryption
2989                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
2990                 // Check correctness of encryption
2991                 ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id());
2992 
2993                 // Enc1*enc2
2994                 evaluator.multiply_inplace(encrypted1, encrypted2);
2995                 evaluator.relinearize_inplace(encrypted1, rlk);
2996                 evaluator.rescale_to_next_inplace(encrypted1);
2997 
2998                 // Check correctness of modulus switching with rescaling
2999                 ASSERT_TRUE(encrypted1.parms_id() == next_parms_id);
3000 
3001                 // Move enc3 to the level of enc1 * enc2
3002                 evaluator.rescale_to_inplace(encrypted3, next_parms_id);
3003 
3004                 // Enc1*enc2 + enc3
3005                 evaluator.add_inplace(encrypted1, encrypted3);
3006 
3007                 decryptor.decrypt(encrypted1, plainRes);
3008                 encoder.decode(plainRes, output);
3009                 for (size_t i = 0; i < slot_size; i++)
3010                 {
3011                     auto tmp = abs(expected[i].real() - output[i].real());
3012                     ASSERT_TRUE(tmp < 0.5);
3013                 }
3014             }
3015         }
3016         {
3017             // Multiplication and addition without rescaling for random vectors
3018             size_t slot_size = 16;
3019             parms.set_poly_modulus_degree(128);
3020             parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 }));
3021 
3022             SEALContext context(parms, true, sec_level_type::none);
3023             auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
3024             KeyGenerator keygen(context);
3025             PublicKey pk;
3026             keygen.create_public_key(pk);
3027             RelinKeys rlk;
3028             keygen.create_relin_keys(rlk);
3029 
3030             CKKSEncoder encoder(context);
3031             Encryptor encryptor(context, pk);
3032             Decryptor decryptor(context, keygen.secret_key());
3033             Evaluator evaluator(context);
3034 
3035             Ciphertext encrypted1;
3036             Ciphertext encrypted2;
3037             Ciphertext encrypted3;
3038             Plaintext plain1;
3039             Plaintext plain2;
3040             Plaintext plain3;
3041             Plaintext plainRes;
3042 
3043             vector<complex<double>> input1(slot_size, 0.0);
3044             vector<complex<double>> input2(slot_size, 0.0);
3045             vector<complex<double>> input3(slot_size, 0.0);
3046             vector<complex<double>> expected(slot_size, 0.0);
3047             vector<complex<double>> output(slot_size);
3048 
3049             for (int round = 0; round < 100; round++)
3050             {
3051                 int data_bound = 1 << 8;
3052                 srand(static_cast<unsigned>(time(NULL)));
3053                 for (size_t i = 0; i < slot_size; i++)
3054                 {
3055                     input1[i] = static_cast<double>(rand() % data_bound);
3056                     input2[i] = static_cast<double>(rand() % data_bound);
3057                     expected[i] = input1[i] * input2[i] + input3[i];
3058                 }
3059 
3060                 double delta = static_cast<double>(1ULL << 40);
3061                 encoder.encode(input1, context.first_parms_id(), delta, plain1);
3062                 encoder.encode(input2, context.first_parms_id(), delta, plain2);
3063                 encoder.encode(input3, context.first_parms_id(), delta * delta, plain3);
3064 
3065                 encryptor.encrypt(plain1, encrypted1);
3066                 encryptor.encrypt(plain2, encrypted2);
3067                 encryptor.encrypt(plain3, encrypted3);
3068 
3069                 // Check correctness of encryption
3070                 ASSERT_TRUE(encrypted1.parms_id() == context.first_parms_id());
3071                 // Check correctness of encryption
3072                 ASSERT_TRUE(encrypted2.parms_id() == context.first_parms_id());
3073                 // Check correctness of encryption
3074                 ASSERT_TRUE(encrypted3.parms_id() == context.first_parms_id());
3075 
3076                 // Enc1*enc2
3077                 evaluator.multiply_inplace(encrypted1, encrypted2);
3078                 evaluator.relinearize_inplace(encrypted1, rlk);
3079                 evaluator.rescale_to_next_inplace(encrypted1);
3080 
3081                 // Check correctness of modulus switching with rescaling
3082                 ASSERT_TRUE(encrypted1.parms_id() == next_parms_id);
3083 
3084                 // Move enc3 to the level of enc1 * enc2
3085                 evaluator.rescale_to_inplace(encrypted3, next_parms_id);
3086 
3087                 // Enc1*enc2 + enc3
3088                 evaluator.add_inplace(encrypted1, encrypted3);
3089 
3090                 decryptor.decrypt(encrypted1, plainRes);
3091                 encoder.decode(plainRes, output);
3092                 for (size_t i = 0; i < slot_size; i++)
3093                 {
3094                     auto tmp = abs(expected[i].real() - output[i].real());
3095                     ASSERT_TRUE(tmp < 0.5);
3096                 }
3097             }
3098         }
3099     }
TEST(EvaluatorTest,CKKSEncryptRotateDecrypt)3100     TEST(EvaluatorTest, CKKSEncryptRotateDecrypt)
3101     {
3102         EncryptionParameters parms(scheme_type::ckks);
3103         {
3104             // Maximal number of slots
3105             size_t slot_size = 4;
3106             parms.set_poly_modulus_degree(slot_size * 2);
3107             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 }));
3108 
3109             SEALContext context(parms, false, sec_level_type::none);
3110             KeyGenerator keygen(context);
3111             PublicKey pk;
3112             keygen.create_public_key(pk);
3113             GaloisKeys glk;
3114             keygen.create_galois_keys(glk);
3115 
3116             Encryptor encryptor(context, pk);
3117             Evaluator evaluator(context);
3118             Decryptor decryptor(context, keygen.secret_key());
3119             CKKSEncoder encoder(context);
3120             const double delta = static_cast<double>(1ULL << 30);
3121 
3122             Ciphertext encrypted;
3123             Plaintext plain;
3124 
3125             vector<complex<double>> input{ complex<double>(1, 1), complex<double>(2, 2), complex<double>(3, 3),
3126                                            complex<double>(4, 4) };
3127             input.resize(slot_size);
3128 
3129             vector<complex<double>> output(slot_size, 0);
3130 
3131             encoder.encode(input, context.first_parms_id(), delta, plain);
3132             int shift = 1;
3133             encryptor.encrypt(plain, encrypted);
3134             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3135             decryptor.decrypt(encrypted, plain);
3136             encoder.decode(plain, output);
3137             for (size_t i = 0; i < slot_size; i++)
3138             {
3139                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3140                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3141             }
3142 
3143             encoder.encode(input, context.first_parms_id(), delta, plain);
3144             shift = 2;
3145             encryptor.encrypt(plain, encrypted);
3146             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3147             decryptor.decrypt(encrypted, plain);
3148             encoder.decode(plain, output);
3149             for (size_t i = 0; i < slot_size; i++)
3150             {
3151                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3152                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3153             }
3154 
3155             encoder.encode(input, context.first_parms_id(), delta, plain);
3156             shift = 3;
3157             encryptor.encrypt(plain, encrypted);
3158             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3159             decryptor.decrypt(encrypted, plain);
3160             encoder.decode(plain, output);
3161             for (size_t i = 0; i < slot_size; i++)
3162             {
3163                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3164                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3165             }
3166 
3167             encoder.encode(input, context.first_parms_id(), delta, plain);
3168             encryptor.encrypt(plain, encrypted);
3169             evaluator.complex_conjugate_inplace(encrypted, glk);
3170             decryptor.decrypt(encrypted, plain);
3171             encoder.decode(plain, output);
3172             for (size_t i = 0; i < slot_size; i++)
3173             {
3174                 ASSERT_EQ(input[i].real(), round(output[i].real()));
3175                 ASSERT_EQ(-input[i].imag(), round(output[i].imag()));
3176             }
3177         }
3178         {
3179             size_t slot_size = 32;
3180             parms.set_poly_modulus_degree(64);
3181             parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 }));
3182 
3183             SEALContext context(parms, false, sec_level_type::none);
3184             KeyGenerator keygen(context);
3185             PublicKey pk;
3186             keygen.create_public_key(pk);
3187             GaloisKeys glk;
3188             keygen.create_galois_keys(glk);
3189 
3190             Encryptor encryptor(context, pk);
3191             Evaluator evaluator(context);
3192             Decryptor decryptor(context, keygen.secret_key());
3193             CKKSEncoder encoder(context);
3194             const double delta = static_cast<double>(1ULL << 30);
3195 
3196             Ciphertext encrypted;
3197             Plaintext plain;
3198 
3199             vector<complex<double>> input{ complex<double>(1, 1), complex<double>(2, 2), complex<double>(3, 3),
3200                                            complex<double>(4, 4) };
3201             input.resize(slot_size);
3202 
3203             vector<complex<double>> output(slot_size, 0);
3204 
3205             encoder.encode(input, context.first_parms_id(), delta, plain);
3206             int shift = 1;
3207             encryptor.encrypt(plain, encrypted);
3208             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3209             decryptor.decrypt(encrypted, plain);
3210             encoder.decode(plain, output);
3211             for (size_t i = 0; i < input.size(); i++)
3212             {
3213                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3214                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3215             }
3216 
3217             encoder.encode(input, context.first_parms_id(), delta, plain);
3218             shift = 2;
3219             encryptor.encrypt(plain, encrypted);
3220             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3221             decryptor.decrypt(encrypted, plain);
3222             encoder.decode(plain, output);
3223             for (size_t i = 0; i < slot_size; i++)
3224             {
3225                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3226                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3227             }
3228 
3229             encoder.encode(input, context.first_parms_id(), delta, plain);
3230             shift = 3;
3231             encryptor.encrypt(plain, encrypted);
3232             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3233             decryptor.decrypt(encrypted, plain);
3234             encoder.decode(plain, output);
3235             for (size_t i = 0; i < slot_size; i++)
3236             {
3237                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3238                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3239             }
3240 
3241             encoder.encode(input, context.first_parms_id(), delta, plain);
3242             encryptor.encrypt(plain, encrypted);
3243             evaluator.complex_conjugate_inplace(encrypted, glk);
3244             decryptor.decrypt(encrypted, plain);
3245             encoder.decode(plain, output);
3246             for (size_t i = 0; i < slot_size; i++)
3247             {
3248                 ASSERT_EQ(round(input[i].real()), round(output[i].real()));
3249                 ASSERT_EQ(round(-input[i].imag()), round(output[i].imag()));
3250             }
3251         }
3252     }
3253 
TEST(EvaluatorTest,CKKSEncryptRescaleRotateDecrypt)3254     TEST(EvaluatorTest, CKKSEncryptRescaleRotateDecrypt)
3255     {
3256         EncryptionParameters parms(scheme_type::ckks);
3257         {
3258             // Maximal number of slots
3259             size_t slot_size = 4;
3260             parms.set_poly_modulus_degree(slot_size * 2);
3261             parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 }));
3262 
3263             SEALContext context(parms, true, sec_level_type::none);
3264             KeyGenerator keygen(context);
3265             PublicKey pk;
3266             keygen.create_public_key(pk);
3267             GaloisKeys glk;
3268             keygen.create_galois_keys(glk);
3269 
3270             Encryptor encryptor(context, pk);
3271             Evaluator evaluator(context);
3272             Decryptor decryptor(context, keygen.secret_key());
3273             CKKSEncoder encoder(context);
3274             const double delta = pow(2.0, 70);
3275 
3276             Ciphertext encrypted;
3277             Plaintext plain;
3278 
3279             vector<complex<double>> input{ complex<double>(1, 1), complex<double>(2, 2), complex<double>(3, 3),
3280                                            complex<double>(4, 4) };
3281             input.resize(slot_size);
3282 
3283             vector<complex<double>> output(slot_size, 0);
3284 
3285             encoder.encode(input, context.first_parms_id(), delta, plain);
3286             int shift = 1;
3287             encryptor.encrypt(plain, encrypted);
3288             evaluator.rescale_to_next_inplace(encrypted);
3289             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3290             decryptor.decrypt(encrypted, plain);
3291             encoder.decode(plain, output);
3292             for (size_t i = 0; i < slot_size; i++)
3293             {
3294                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3295                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3296             }
3297 
3298             encoder.encode(input, context.first_parms_id(), delta, plain);
3299             shift = 2;
3300             encryptor.encrypt(plain, encrypted);
3301             evaluator.rescale_to_next_inplace(encrypted);
3302             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3303             decryptor.decrypt(encrypted, plain);
3304             encoder.decode(plain, output);
3305             for (size_t i = 0; i < slot_size; i++)
3306             {
3307                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3308                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3309             }
3310 
3311             encoder.encode(input, context.first_parms_id(), delta, plain);
3312             shift = 3;
3313             encryptor.encrypt(plain, encrypted);
3314             evaluator.rescale_to_next_inplace(encrypted);
3315             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3316             decryptor.decrypt(encrypted, plain);
3317             encoder.decode(plain, output);
3318             for (size_t i = 0; i < slot_size; i++)
3319             {
3320                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].real(), round(output[i].real()));
3321                 ASSERT_EQ(input[(i + static_cast<size_t>(shift)) % slot_size].imag(), round(output[i].imag()));
3322             }
3323 
3324             encoder.encode(input, context.first_parms_id(), delta, plain);
3325             encryptor.encrypt(plain, encrypted);
3326             evaluator.rescale_to_next_inplace(encrypted);
3327             evaluator.complex_conjugate_inplace(encrypted, glk);
3328             decryptor.decrypt(encrypted, plain);
3329             encoder.decode(plain, output);
3330             for (size_t i = 0; i < slot_size; i++)
3331             {
3332                 ASSERT_EQ(input[i].real(), round(output[i].real()));
3333                 ASSERT_EQ(-input[i].imag(), round(output[i].imag()));
3334             }
3335         }
3336         {
3337             size_t slot_size = 32;
3338             parms.set_poly_modulus_degree(64);
3339             parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 }));
3340 
3341             SEALContext context(parms, true, sec_level_type::none);
3342             KeyGenerator keygen(context);
3343             PublicKey pk;
3344             keygen.create_public_key(pk);
3345             GaloisKeys glk;
3346             keygen.create_galois_keys(glk);
3347 
3348             Encryptor encryptor(context, pk);
3349             Evaluator evaluator(context);
3350             Decryptor decryptor(context, keygen.secret_key());
3351             CKKSEncoder encoder(context);
3352             const double delta = pow(2, 70);
3353 
3354             Ciphertext encrypted;
3355             Plaintext plain;
3356 
3357             vector<complex<double>> input{ complex<double>(1, 1), complex<double>(2, 2), complex<double>(3, 3),
3358                                            complex<double>(4, 4) };
3359             input.resize(slot_size);
3360 
3361             vector<complex<double>> output(slot_size, 0);
3362 
3363             encoder.encode(input, context.first_parms_id(), delta, plain);
3364             int shift = 1;
3365             encryptor.encrypt(plain, encrypted);
3366             evaluator.rescale_to_next_inplace(encrypted);
3367             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3368             decryptor.decrypt(encrypted, plain);
3369             encoder.decode(plain, output);
3370             for (size_t i = 0; i < slot_size; i++)
3371             {
3372                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3373                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3374             }
3375 
3376             encoder.encode(input, context.first_parms_id(), delta, plain);
3377             shift = 2;
3378             encryptor.encrypt(plain, encrypted);
3379             evaluator.rescale_to_next_inplace(encrypted);
3380             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3381             decryptor.decrypt(encrypted, plain);
3382             encoder.decode(plain, output);
3383             for (size_t i = 0; i < slot_size; i++)
3384             {
3385                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3386                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3387             }
3388 
3389             encoder.encode(input, context.first_parms_id(), delta, plain);
3390             shift = 3;
3391             encryptor.encrypt(plain, encrypted);
3392             evaluator.rescale_to_next_inplace(encrypted);
3393             evaluator.rotate_vector_inplace(encrypted, shift, glk);
3394             decryptor.decrypt(encrypted, plain);
3395             encoder.decode(plain, output);
3396             for (size_t i = 0; i < slot_size; i++)
3397             {
3398                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].real()), round(output[i].real()));
3399                 ASSERT_EQ(round(input[(i + static_cast<size_t>(shift)) % slot_size].imag()), round(output[i].imag()));
3400             }
3401 
3402             encoder.encode(input, context.first_parms_id(), delta, plain);
3403             encryptor.encrypt(plain, encrypted);
3404             evaluator.rescale_to_next_inplace(encrypted);
3405             evaluator.complex_conjugate_inplace(encrypted, glk);
3406             decryptor.decrypt(encrypted, plain);
3407             encoder.decode(plain, output);
3408             for (size_t i = 0; i < slot_size; i++)
3409             {
3410                 ASSERT_EQ(round(input[i].real()), round(output[i].real()));
3411                 ASSERT_EQ(round(-input[i].imag()), round(output[i].imag()));
3412             }
3413         }
3414     }
3415 
TEST(EvaluatorTest,BFVEncryptSquareDecrypt)3416     TEST(EvaluatorTest, BFVEncryptSquareDecrypt)
3417     {
3418         EncryptionParameters parms(scheme_type::bfv);
3419         Modulus plain_modulus(1 << 8);
3420         parms.set_poly_modulus_degree(128);
3421         parms.set_plain_modulus(plain_modulus);
3422         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 }));
3423 
3424         SEALContext context(parms, false, sec_level_type::none);
3425         KeyGenerator keygen(context);
3426         PublicKey pk;
3427         keygen.create_public_key(pk);
3428 
3429         Encryptor encryptor(context, pk);
3430         Evaluator evaluator(context);
3431         Decryptor decryptor(context, keygen.secret_key());
3432 
3433         Ciphertext encrypted;
3434         Plaintext plain;
3435 
3436         plain = "1";
3437         encryptor.encrypt(plain, encrypted);
3438         evaluator.square_inplace(encrypted);
3439         decryptor.decrypt(encrypted, plain);
3440         ASSERT_EQ(plain.to_string(), "1");
3441         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3442 
3443         plain = "0";
3444         encryptor.encrypt(plain, encrypted);
3445         evaluator.square_inplace(encrypted);
3446         decryptor.decrypt(encrypted, plain);
3447         ASSERT_EQ(plain.to_string(), "0");
3448         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3449 
3450         plain = "FFx^2 + FF";
3451         encryptor.encrypt(plain, encrypted);
3452         evaluator.square_inplace(encrypted);
3453         decryptor.decrypt(encrypted, plain);
3454         ASSERT_EQ(plain.to_string(), "1x^4 + 2x^2 + 1");
3455         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3456 
3457         plain = "FF";
3458         encryptor.encrypt(plain, encrypted);
3459         evaluator.square_inplace(encrypted);
3460         decryptor.decrypt(encrypted, plain);
3461         ASSERT_EQ(plain.to_string(), "1");
3462         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3463 
3464         plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1";
3465         encryptor.encrypt(plain, encrypted);
3466         evaluator.square_inplace(encrypted);
3467         decryptor.decrypt(encrypted, plain);
3468         ASSERT_EQ(
3469             plain.to_string(),
3470             "1x^12 + 2x^11 + 3x^10 + 4x^9 + 3x^8 + 4x^7 + 5x^6 + 4x^5 + 4x^4 + 2x^3 + 1x^2 + 2x^1 + 1");
3471         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3472 
3473         plain = "1x^16";
3474         encryptor.encrypt(plain, encrypted);
3475         evaluator.square_inplace(encrypted);
3476         decryptor.decrypt(encrypted, plain);
3477         ASSERT_EQ(plain.to_string(), "1x^32");
3478         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3479 
3480         plain = "1x^6 + 1x^5 + 1x^4 + 1x^3 + 1x^1 + 1";
3481         encryptor.encrypt(plain, encrypted);
3482         evaluator.square_inplace(encrypted);
3483         evaluator.square_inplace(encrypted);
3484         decryptor.decrypt(encrypted, plain);
3485         ASSERT_EQ(
3486             plain.to_string(),
3487             "1x^24 + 4x^23 + Ax^22 + 14x^21 + 1Fx^20 + 2Cx^19 + 3Cx^18 + 4Cx^17 + 5Fx^16 + 6Cx^15 + 70x^14 + 74x^13 + "
3488             "71x^12 + 6Cx^11 + 64x^10 + 50x^9 + 40x^8 + 34x^7 + 26x^6 + 1Cx^5 + 11x^4 + 8x^3 + 6x^2 + 4x^1 + 1");
3489         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3490     }
3491 
TEST(EvaluatorTest,BFVEncryptMultiplyManyDecrypt)3492     TEST(EvaluatorTest, BFVEncryptMultiplyManyDecrypt)
3493     {
3494         EncryptionParameters parms(scheme_type::bfv);
3495         Modulus plain_modulus(1 << 6);
3496         parms.set_poly_modulus_degree(128);
3497         parms.set_plain_modulus(plain_modulus);
3498         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 }));
3499 
3500         SEALContext context(parms, false, sec_level_type::none);
3501         KeyGenerator keygen(context);
3502         PublicKey pk;
3503         keygen.create_public_key(pk);
3504         RelinKeys rlk;
3505         keygen.create_relin_keys(rlk);
3506 
3507         Encryptor encryptor(context, pk);
3508         Evaluator evaluator(context);
3509         Decryptor decryptor(context, keygen.secret_key());
3510 
3511         Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product;
3512         Plaintext plain, plain1, plain2, plain3, plain4;
3513 
3514         plain1 = "1x^2 + 1";
3515         plain2 = "1x^2 + 1x^1";
3516         plain3 = "1x^2 + 1x^1 + 1";
3517         encryptor.encrypt(plain1, encrypted1);
3518         encryptor.encrypt(plain2, encrypted2);
3519         encryptor.encrypt(plain3, encrypted3);
3520         vector<Ciphertext> encrypteds{ encrypted1, encrypted2, encrypted3 };
3521         evaluator.multiply_many(encrypteds, rlk, product);
3522         ASSERT_EQ(3, encrypteds.size());
3523         decryptor.decrypt(product, plain);
3524         ASSERT_EQ(plain.to_string(), "1x^6 + 2x^5 + 3x^4 + 3x^3 + 2x^2 + 1x^1");
3525         ASSERT_TRUE(encrypted1.parms_id() == product.parms_id());
3526         ASSERT_TRUE(encrypted2.parms_id() == product.parms_id());
3527         ASSERT_TRUE(encrypted3.parms_id() == product.parms_id());
3528         ASSERT_TRUE(product.parms_id() == context.first_parms_id());
3529 
3530         plain1 = "3Fx^3 + 3F";
3531         plain2 = "3Fx^4 + 3F";
3532         encryptor.encrypt(plain1, encrypted1);
3533         encryptor.encrypt(plain2, encrypted2);
3534         encrypteds = { encrypted1, encrypted2 };
3535         evaluator.multiply_many(encrypteds, rlk, product);
3536         ASSERT_EQ(2, encrypteds.size());
3537         decryptor.decrypt(product, plain);
3538         ASSERT_EQ(plain.to_string(), "1x^7 + 1x^4 + 1x^3 + 1");
3539         ASSERT_TRUE(encrypted1.parms_id() == product.parms_id());
3540         ASSERT_TRUE(encrypted2.parms_id() == product.parms_id());
3541         ASSERT_TRUE(product.parms_id() == context.first_parms_id());
3542 
3543         plain1 = "1x^1";
3544         plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F";
3545         plain3 = "1x^2 + 1x^1 + 1";
3546         encryptor.encrypt(plain1, encrypted1);
3547         encryptor.encrypt(plain2, encrypted2);
3548         encryptor.encrypt(plain3, encrypted3);
3549         encrypteds = { encrypted1, encrypted2, encrypted3 };
3550         evaluator.multiply_many(encrypteds, rlk, product);
3551         ASSERT_EQ(3, encrypteds.size());
3552         decryptor.decrypt(product, plain);
3553         ASSERT_EQ(plain.to_string(), "3Fx^7 + 3Ex^6 + 3Dx^5 + 3Dx^4 + 3Dx^3 + 3Ex^2 + 3Fx^1");
3554         ASSERT_TRUE(encrypted1.parms_id() == product.parms_id());
3555         ASSERT_TRUE(encrypted2.parms_id() == product.parms_id());
3556         ASSERT_TRUE(encrypted3.parms_id() == product.parms_id());
3557         ASSERT_TRUE(product.parms_id() == context.first_parms_id());
3558 
3559         plain1 = "1";
3560         plain2 = "3F";
3561         plain3 = "1";
3562         plain4 = "3F";
3563         encryptor.encrypt(plain1, encrypted1);
3564         encryptor.encrypt(plain2, encrypted2);
3565         encryptor.encrypt(plain3, encrypted3);
3566         encryptor.encrypt(plain4, encrypted4);
3567         encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 };
3568         evaluator.multiply_many(encrypteds, rlk, product);
3569         ASSERT_EQ(4, encrypteds.size());
3570         decryptor.decrypt(product, plain);
3571         ASSERT_EQ(plain.to_string(), "1");
3572         ASSERT_TRUE(encrypted1.parms_id() == product.parms_id());
3573         ASSERT_TRUE(encrypted2.parms_id() == product.parms_id());
3574         ASSERT_TRUE(encrypted3.parms_id() == product.parms_id());
3575         ASSERT_TRUE(encrypted4.parms_id() == product.parms_id());
3576         ASSERT_TRUE(product.parms_id() == context.first_parms_id());
3577 
3578         plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1";
3579         plain2 = "0";
3580         plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1";
3581         plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1";
3582         encryptor.encrypt(plain1, encrypted1);
3583         encryptor.encrypt(plain2, encrypted2);
3584         encryptor.encrypt(plain3, encrypted3);
3585         encryptor.encrypt(plain4, encrypted4);
3586         encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 };
3587         evaluator.multiply_many(encrypteds, rlk, product);
3588         ASSERT_EQ(4, encrypteds.size());
3589         decryptor.decrypt(product, plain);
3590         ASSERT_EQ(plain.to_string(), "0");
3591         ASSERT_TRUE(encrypted1.parms_id() == product.parms_id());
3592         ASSERT_TRUE(encrypted2.parms_id() == product.parms_id());
3593         ASSERT_TRUE(encrypted3.parms_id() == product.parms_id());
3594         ASSERT_TRUE(encrypted4.parms_id() == product.parms_id());
3595         ASSERT_TRUE(product.parms_id() == context.first_parms_id());
3596     }
3597 
TEST(EvaluatorTest,BFVEncryptExponentiateDecrypt)3598     TEST(EvaluatorTest, BFVEncryptExponentiateDecrypt)
3599     {
3600         EncryptionParameters parms(scheme_type::bfv);
3601         Modulus plain_modulus(1 << 6);
3602         parms.set_poly_modulus_degree(128);
3603         parms.set_plain_modulus(plain_modulus);
3604         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 }));
3605 
3606         SEALContext context(parms, false, sec_level_type::none);
3607         KeyGenerator keygen(context);
3608         PublicKey pk;
3609         keygen.create_public_key(pk);
3610         RelinKeys rlk;
3611         keygen.create_relin_keys(rlk);
3612 
3613         Encryptor encryptor(context, pk);
3614         Evaluator evaluator(context);
3615         Decryptor decryptor(context, keygen.secret_key());
3616 
3617         Ciphertext encrypted;
3618         Plaintext plain;
3619 
3620         plain = "1x^2 + 1";
3621         encryptor.encrypt(plain, encrypted);
3622         evaluator.exponentiate_inplace(encrypted, 1, rlk);
3623         decryptor.decrypt(encrypted, plain);
3624         ASSERT_EQ(plain.to_string(), "1x^2 + 1");
3625         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3626 
3627         plain = "1x^2 + 1x^1 + 1";
3628         encryptor.encrypt(plain, encrypted);
3629         evaluator.exponentiate_inplace(encrypted, 2, rlk);
3630         decryptor.decrypt(encrypted, plain);
3631         ASSERT_EQ(plain.to_string(), "1x^4 + 2x^3 + 3x^2 + 2x^1 + 1");
3632         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3633 
3634         plain = "3Fx^2 + 3Fx^1 + 3F";
3635         encryptor.encrypt(plain, encrypted);
3636         evaluator.exponentiate_inplace(encrypted, 3, rlk);
3637         decryptor.decrypt(encrypted, plain);
3638         ASSERT_EQ(plain.to_string(), "3Fx^6 + 3Dx^5 + 3Ax^4 + 39x^3 + 3Ax^2 + 3Dx^1 + 3F");
3639         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3640 
3641         plain = "1x^8";
3642         encryptor.encrypt(plain, encrypted);
3643         evaluator.exponentiate_inplace(encrypted, 4, rlk);
3644         decryptor.decrypt(encrypted, plain);
3645         ASSERT_EQ(plain.to_string(), "1x^32");
3646         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3647     }
3648 
TEST(EvaluatorTest,BFVEncryptAddManyDecrypt)3649     TEST(EvaluatorTest, BFVEncryptAddManyDecrypt)
3650     {
3651         EncryptionParameters parms(scheme_type::bfv);
3652         Modulus plain_modulus(1 << 6);
3653         parms.set_poly_modulus_degree(128);
3654         parms.set_plain_modulus(plain_modulus);
3655         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 }));
3656 
3657         SEALContext context(parms, false, sec_level_type::none);
3658         KeyGenerator keygen(context);
3659         PublicKey pk;
3660         keygen.create_public_key(pk);
3661 
3662         Encryptor encryptor(context, pk);
3663         Evaluator evaluator(context);
3664         Decryptor decryptor(context, keygen.secret_key());
3665 
3666         Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum;
3667         Plaintext plain, plain1, plain2, plain3, plain4;
3668 
3669         plain1 = "1x^2 + 1";
3670         plain2 = "1x^2 + 1x^1";
3671         plain3 = "1x^2 + 1x^1 + 1";
3672         encryptor.encrypt(plain1, encrypted1);
3673         encryptor.encrypt(plain2, encrypted2);
3674         encryptor.encrypt(plain3, encrypted3);
3675         vector<Ciphertext> encrypteds = { encrypted1, encrypted2, encrypted3 };
3676         evaluator.add_many(encrypteds, sum);
3677         decryptor.decrypt(sum, plain);
3678         ASSERT_EQ(plain.to_string(), "3x^2 + 2x^1 + 2");
3679         ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id());
3680         ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id());
3681         ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id());
3682         ASSERT_TRUE(sum.parms_id() == context.first_parms_id());
3683 
3684         plain1 = "3Fx^3 + 3F";
3685         plain2 = "3Fx^4 + 3F";
3686         encryptor.encrypt(plain1, encrypted1);
3687         encryptor.encrypt(plain2, encrypted2);
3688         encrypteds = {
3689             encrypted1,
3690             encrypted2,
3691         };
3692         evaluator.add_many(encrypteds, sum);
3693         decryptor.decrypt(sum, plain);
3694         ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 3E");
3695         ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id());
3696         ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id());
3697         ASSERT_TRUE(sum.parms_id() == context.first_parms_id());
3698 
3699         plain1 = "1x^1";
3700         plain2 = "3Fx^4 + 3Fx^3 + 3Fx^2 + 3Fx^1 + 3F";
3701         plain3 = "1x^2 + 1x^1 + 1";
3702         encryptor.encrypt(plain1, encrypted1);
3703         encryptor.encrypt(plain2, encrypted2);
3704         encryptor.encrypt(plain3, encrypted3);
3705         encrypteds = { encrypted1, encrypted2, encrypted3 };
3706         evaluator.add_many(encrypteds, sum);
3707         decryptor.decrypt(sum, plain);
3708         ASSERT_EQ(plain.to_string(), "3Fx^4 + 3Fx^3 + 1x^1");
3709         ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id());
3710         ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id());
3711         ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id());
3712         ASSERT_TRUE(sum.parms_id() == context.first_parms_id());
3713 
3714         plain1 = "1";
3715         plain2 = "3F";
3716         plain3 = "1";
3717         plain4 = "3F";
3718         encryptor.encrypt(plain1, encrypted1);
3719         encryptor.encrypt(plain2, encrypted2);
3720         encryptor.encrypt(plain3, encrypted3);
3721         encryptor.encrypt(plain4, encrypted4);
3722         encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 };
3723         evaluator.add_many(encrypteds, sum);
3724         decryptor.decrypt(sum, plain);
3725         ASSERT_EQ(plain.to_string(), "0");
3726         ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id());
3727         ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id());
3728         ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id());
3729         ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id());
3730         ASSERT_TRUE(sum.parms_id() == context.first_parms_id());
3731 
3732         plain1 = "1x^16 + 1x^15 + 1x^8 + 1x^7 + 1x^6 + 1x^3 + 1x^2 + 1";
3733         plain2 = "0";
3734         plain3 = "1x^13 + 1x^12 + 1x^5 + 1x^4 + 1x^3 + 1";
3735         plain4 = "1x^15 + 1x^10 + 1x^9 + 1x^8 + 1x^2 + 1x^1 + 1";
3736         encryptor.encrypt(plain1, encrypted1);
3737         encryptor.encrypt(plain2, encrypted2);
3738         encryptor.encrypt(plain3, encrypted3);
3739         encryptor.encrypt(plain4, encrypted4);
3740         encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 };
3741         evaluator.add_many(encrypteds, sum);
3742         decryptor.decrypt(sum, plain);
3743         ASSERT_EQ(
3744             plain.to_string(),
3745             "1x^16 + 2x^15 + 1x^13 + 1x^12 + 1x^10 + 1x^9 + 2x^8 + 1x^7 + 1x^6 + 1x^5 + 1x^4 + 2x^3 + 2x^2 + 1x^1 + 3");
3746         ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id());
3747         ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id());
3748         ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id());
3749         ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id());
3750         ASSERT_TRUE(sum.parms_id() == context.first_parms_id());
3751     }
3752 
TEST(EvaluatorTest,TransformPlainToNTT)3753     TEST(EvaluatorTest, TransformPlainToNTT)
3754     {
3755         EncryptionParameters parms(scheme_type::bfv);
3756         Modulus plain_modulus(1 << 6);
3757         parms.set_poly_modulus_degree(128);
3758         parms.set_plain_modulus(plain_modulus);
3759         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 }));
3760         SEALContext context(parms, true, sec_level_type::none);
3761 
3762         Evaluator evaluator(context);
3763         Plaintext plain("0");
3764         ASSERT_FALSE(plain.is_ntt_form());
3765         evaluator.transform_to_ntt_inplace(plain, context.first_parms_id());
3766         ASSERT_TRUE(plain.is_zero());
3767         ASSERT_TRUE(plain.is_ntt_form());
3768         ASSERT_TRUE(plain.parms_id() == context.first_parms_id());
3769 
3770         plain.release();
3771         plain = "0";
3772         ASSERT_FALSE(plain.is_ntt_form());
3773         auto next_parms_id = context.first_context_data()->next_context_data()->parms_id();
3774         evaluator.transform_to_ntt_inplace(plain, next_parms_id);
3775         ASSERT_TRUE(plain.is_zero());
3776         ASSERT_TRUE(plain.is_ntt_form());
3777         ASSERT_TRUE(plain.parms_id() == next_parms_id);
3778 
3779         plain.release();
3780         plain = "1";
3781         ASSERT_FALSE(plain.is_ntt_form());
3782         evaluator.transform_to_ntt_inplace(plain, context.first_parms_id());
3783         for (size_t i = 0; i < 256; i++)
3784         {
3785             ASSERT_TRUE(plain[i] == uint64_t(1));
3786         }
3787         ASSERT_TRUE(plain.is_ntt_form());
3788         ASSERT_TRUE(plain.parms_id() == context.first_parms_id());
3789 
3790         plain.release();
3791         plain = "1";
3792         ASSERT_FALSE(plain.is_ntt_form());
3793         evaluator.transform_to_ntt_inplace(plain, next_parms_id);
3794         for (size_t i = 0; i < 128; i++)
3795         {
3796             ASSERT_TRUE(plain[i] == uint64_t(1));
3797         }
3798         ASSERT_TRUE(plain.is_ntt_form());
3799         ASSERT_TRUE(plain.parms_id() == next_parms_id);
3800 
3801         plain.release();
3802         plain = "2";
3803         ASSERT_FALSE(plain.is_ntt_form());
3804         evaluator.transform_to_ntt_inplace(plain, context.first_parms_id());
3805         for (size_t i = 0; i < 256; i++)
3806         {
3807             ASSERT_TRUE(plain[i] == uint64_t(2));
3808         }
3809         ASSERT_TRUE(plain.is_ntt_form());
3810         ASSERT_TRUE(plain.parms_id() == context.first_parms_id());
3811 
3812         plain.release();
3813         plain = "2";
3814         evaluator.transform_to_ntt_inplace(plain, next_parms_id);
3815         for (size_t i = 0; i < 128; i++)
3816         {
3817             ASSERT_TRUE(plain[i] == uint64_t(2));
3818         }
3819         ASSERT_TRUE(plain.is_ntt_form());
3820         ASSERT_TRUE(plain.parms_id() == next_parms_id);
3821     }
3822 
TEST(EvaluatorTest,TransformEncryptedToFromNTT)3823     TEST(EvaluatorTest, TransformEncryptedToFromNTT)
3824     {
3825         EncryptionParameters parms(scheme_type::bfv);
3826         Modulus plain_modulus(1 << 6);
3827         parms.set_poly_modulus_degree(128);
3828         parms.set_plain_modulus(plain_modulus);
3829         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 }));
3830 
3831         SEALContext context(parms, false, sec_level_type::none);
3832         KeyGenerator keygen(context);
3833         PublicKey pk;
3834         keygen.create_public_key(pk);
3835 
3836         Encryptor encryptor(context, pk);
3837         Evaluator evaluator(context);
3838         Decryptor decryptor(context, keygen.secret_key());
3839 
3840         Plaintext plain;
3841         Ciphertext encrypted;
3842         plain = "0";
3843         encryptor.encrypt(plain, encrypted);
3844         evaluator.transform_to_ntt_inplace(encrypted);
3845         evaluator.transform_from_ntt_inplace(encrypted);
3846         decryptor.decrypt(encrypted, plain);
3847         ASSERT_TRUE(plain.to_string() == "0");
3848         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3849 
3850         plain = "1";
3851         encryptor.encrypt(plain, encrypted);
3852         evaluator.transform_to_ntt_inplace(encrypted);
3853         evaluator.transform_from_ntt_inplace(encrypted);
3854         decryptor.decrypt(encrypted, plain);
3855         ASSERT_TRUE(plain.to_string() == "1");
3856         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3857 
3858         plain = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5";
3859         encryptor.encrypt(plain, encrypted);
3860         evaluator.transform_to_ntt_inplace(encrypted);
3861         evaluator.transform_from_ntt_inplace(encrypted);
3862         decryptor.decrypt(encrypted, plain);
3863         ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5");
3864         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3865     }
3866 
TEST(EvaluatorTest,BFVEncryptMultiplyPlainNTTDecrypt)3867     TEST(EvaluatorTest, BFVEncryptMultiplyPlainNTTDecrypt)
3868     {
3869         EncryptionParameters parms(scheme_type::bfv);
3870         Modulus plain_modulus(1 << 6);
3871         parms.set_poly_modulus_degree(128);
3872         parms.set_plain_modulus(plain_modulus);
3873         parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 }));
3874 
3875         SEALContext context(parms, false, sec_level_type::none);
3876         KeyGenerator keygen(context);
3877         PublicKey pk;
3878         keygen.create_public_key(pk);
3879 
3880         Encryptor encryptor(context, pk);
3881         Evaluator evaluator(context);
3882         Decryptor decryptor(context, keygen.secret_key());
3883 
3884         Plaintext plain;
3885         Plaintext plain_multiplier;
3886         Ciphertext encrypted;
3887 
3888         plain = 0;
3889         encryptor.encrypt(plain, encrypted);
3890         evaluator.transform_to_ntt_inplace(encrypted);
3891         plain_multiplier = 1;
3892         evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id());
3893         evaluator.multiply_plain_inplace(encrypted, plain_multiplier);
3894         evaluator.transform_from_ntt_inplace(encrypted);
3895         decryptor.decrypt(encrypted, plain);
3896         ASSERT_TRUE(plain.to_string() == "0");
3897         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3898 
3899         plain = 2;
3900         encryptor.encrypt(plain, encrypted);
3901         evaluator.transform_to_ntt_inplace(encrypted);
3902         plain_multiplier.release();
3903         plain_multiplier = 3;
3904         evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id());
3905         evaluator.multiply_plain_inplace(encrypted, plain_multiplier);
3906         evaluator.transform_from_ntt_inplace(encrypted);
3907         decryptor.decrypt(encrypted, plain);
3908         ASSERT_TRUE(plain.to_string() == "6");
3909         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3910 
3911         plain = 1;
3912         encryptor.encrypt(plain, encrypted);
3913         evaluator.transform_to_ntt_inplace(encrypted);
3914         plain_multiplier.release();
3915         plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5";
3916         evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id());
3917         evaluator.multiply_plain_inplace(encrypted, plain_multiplier);
3918         evaluator.transform_from_ntt_inplace(encrypted);
3919         decryptor.decrypt(encrypted, plain);
3920         ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5");
3921         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3922 
3923         plain = "1x^20";
3924         encryptor.encrypt(plain, encrypted);
3925         evaluator.transform_to_ntt_inplace(encrypted);
3926         plain_multiplier.release();
3927         plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5";
3928         evaluator.transform_to_ntt_inplace(plain_multiplier, context.first_parms_id());
3929         evaluator.multiply_plain_inplace(encrypted, plain_multiplier);
3930         evaluator.transform_from_ntt_inplace(encrypted);
3931         decryptor.decrypt(encrypted, plain);
3932         ASSERT_TRUE(
3933             plain.to_string() ==
3934             "Fx^30 + Ex^29 + Dx^28 + Cx^27 + Bx^26 + Ax^25 + 1x^24 + 2x^23 + 3x^22 + 4x^21 + 5x^20");
3935         ASSERT_TRUE(encrypted.parms_id() == context.first_parms_id());
3936     }
3937 
TEST(EvaluatorTest,BFVEncryptApplyGaloisDecrypt)3938     TEST(EvaluatorTest, BFVEncryptApplyGaloisDecrypt)
3939     {
3940         EncryptionParameters parms(scheme_type::bfv);
3941         Modulus plain_modulus(257);
3942         parms.set_poly_modulus_degree(8);
3943         parms.set_plain_modulus(plain_modulus);
3944         parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 }));
3945 
3946         SEALContext context(parms, false, sec_level_type::none);
3947         KeyGenerator keygen(context);
3948         PublicKey pk;
3949         keygen.create_public_key(pk);
3950         GaloisKeys glk;
3951         keygen.create_galois_keys(vector<uint32_t>{ 1, 3, 5, 15 }, glk);
3952 
3953         Encryptor encryptor(context, pk);
3954         Evaluator evaluator(context);
3955         Decryptor decryptor(context, keygen.secret_key());
3956 
3957         Plaintext plain("1");
3958         Ciphertext encrypted;
3959         encryptor.encrypt(plain, encrypted);
3960         evaluator.apply_galois_inplace(encrypted, 1, glk);
3961         decryptor.decrypt(encrypted, plain);
3962         ASSERT_TRUE("1" == plain.to_string());
3963         evaluator.apply_galois_inplace(encrypted, 3, glk);
3964         decryptor.decrypt(encrypted, plain);
3965         ASSERT_TRUE("1" == plain.to_string());
3966         evaluator.apply_galois_inplace(encrypted, 5, glk);
3967         decryptor.decrypt(encrypted, plain);
3968         ASSERT_TRUE("1" == plain.to_string());
3969         evaluator.apply_galois_inplace(encrypted, 15, glk);
3970         decryptor.decrypt(encrypted, plain);
3971         ASSERT_TRUE("1" == plain.to_string());
3972 
3973         plain = "1x^1";
3974         encryptor.encrypt(plain, encrypted);
3975         evaluator.apply_galois_inplace(encrypted, 1, glk);
3976         decryptor.decrypt(encrypted, plain);
3977         ASSERT_TRUE("1x^1" == plain.to_string());
3978         evaluator.apply_galois_inplace(encrypted, 3, glk);
3979         decryptor.decrypt(encrypted, plain);
3980         ASSERT_TRUE("1x^3" == plain.to_string());
3981         evaluator.apply_galois_inplace(encrypted, 5, glk);
3982         decryptor.decrypt(encrypted, plain);
3983         ASSERT_TRUE("100x^7" == plain.to_string());
3984         evaluator.apply_galois_inplace(encrypted, 15, glk);
3985         decryptor.decrypt(encrypted, plain);
3986         ASSERT_TRUE("1x^1" == plain.to_string());
3987 
3988         plain = "1x^2";
3989         encryptor.encrypt(plain, encrypted);
3990         evaluator.apply_galois_inplace(encrypted, 1, glk);
3991         decryptor.decrypt(encrypted, plain);
3992         ASSERT_TRUE("1x^2" == plain.to_string());
3993         evaluator.apply_galois_inplace(encrypted, 3, glk);
3994         decryptor.decrypt(encrypted, plain);
3995         ASSERT_TRUE("1x^6" == plain.to_string());
3996         evaluator.apply_galois_inplace(encrypted, 5, glk);
3997         decryptor.decrypt(encrypted, plain);
3998         ASSERT_TRUE("100x^6" == plain.to_string());
3999         evaluator.apply_galois_inplace(encrypted, 15, glk);
4000         decryptor.decrypt(encrypted, plain);
4001         ASSERT_TRUE("1x^2" == plain.to_string());
4002 
4003         plain = "1x^3 + 2x^2 + 1x^1 + 1";
4004         encryptor.encrypt(plain, encrypted);
4005         evaluator.apply_galois_inplace(encrypted, 1, glk);
4006         decryptor.decrypt(encrypted, plain);
4007         ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string());
4008         evaluator.apply_galois_inplace(encrypted, 3, glk);
4009         decryptor.decrypt(encrypted, plain);
4010         ASSERT_TRUE("2x^6 + 1x^3 + 100x^1 + 1" == plain.to_string());
4011         evaluator.apply_galois_inplace(encrypted, 5, glk);
4012         decryptor.decrypt(encrypted, plain);
4013         ASSERT_TRUE("100x^7 + FFx^6 + 100x^5 + 1" == plain.to_string());
4014         evaluator.apply_galois_inplace(encrypted, 15, glk);
4015         decryptor.decrypt(encrypted, plain);
4016         ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string());
4017     }
4018 
TEST(EvaluatorTest,BFVEncryptRotateMatrixDecrypt)4019     TEST(EvaluatorTest, BFVEncryptRotateMatrixDecrypt)
4020     {
4021         EncryptionParameters parms(scheme_type::bfv);
4022         Modulus plain_modulus(257);
4023         parms.set_poly_modulus_degree(8);
4024         parms.set_plain_modulus(plain_modulus);
4025         parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 }));
4026 
4027         SEALContext context(parms, false, sec_level_type::none);
4028         KeyGenerator keygen(context);
4029         PublicKey pk;
4030         keygen.create_public_key(pk);
4031         GaloisKeys glk;
4032         keygen.create_galois_keys(glk);
4033 
4034         Encryptor encryptor(context, pk);
4035         Evaluator evaluator(context);
4036         Decryptor decryptor(context, keygen.secret_key());
4037         BatchEncoder batch_encoder(context);
4038 
4039         Plaintext plain;
4040         vector<uint64_t> plain_vec{ 1, 2, 3, 4, 5, 6, 7, 8 };
4041         batch_encoder.encode(plain_vec, plain);
4042         Ciphertext encrypted;
4043         encryptor.encrypt(plain, encrypted);
4044 
4045         evaluator.rotate_columns_inplace(encrypted, glk);
4046         decryptor.decrypt(encrypted, plain);
4047         batch_encoder.decode(plain, plain_vec);
4048         ASSERT_TRUE((plain_vec == vector<uint64_t>{ 5, 6, 7, 8, 1, 2, 3, 4 }));
4049 
4050         evaluator.rotate_rows_inplace(encrypted, -1, glk);
4051         decryptor.decrypt(encrypted, plain);
4052         batch_encoder.decode(plain, plain_vec);
4053         ASSERT_TRUE((plain_vec == vector<uint64_t>{ 8, 5, 6, 7, 4, 1, 2, 3 }));
4054 
4055         evaluator.rotate_rows_inplace(encrypted, 2, glk);
4056         decryptor.decrypt(encrypted, plain);
4057         batch_encoder.decode(plain, plain_vec);
4058         ASSERT_TRUE((plain_vec == vector<uint64_t>{ 6, 7, 8, 5, 2, 3, 4, 1 }));
4059 
4060         evaluator.rotate_columns_inplace(encrypted, glk);
4061         decryptor.decrypt(encrypted, plain);
4062         batch_encoder.decode(plain, plain_vec);
4063         ASSERT_TRUE((plain_vec == vector<uint64_t>{ 2, 3, 4, 1, 6, 7, 8, 5 }));
4064 
4065         evaluator.rotate_rows_inplace(encrypted, 0, glk);
4066         decryptor.decrypt(encrypted, plain);
4067         batch_encoder.decode(plain, plain_vec);
4068         ASSERT_TRUE((plain_vec == vector<uint64_t>{ 2, 3, 4, 1, 6, 7, 8, 5 }));
4069     }
TEST(EvaluatorTest,BFVEncryptModSwitchToNextDecrypt)4070     TEST(EvaluatorTest, BFVEncryptModSwitchToNextDecrypt)
4071     {
4072         // The common parameters: the plaintext and the polynomial moduli
4073         Modulus plain_modulus(1 << 6);
4074 
4075         // The parameters and the context of the higher level
4076         EncryptionParameters parms(scheme_type::bfv);
4077         parms.set_poly_modulus_degree(128);
4078         parms.set_plain_modulus(plain_modulus);
4079         parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 }));
4080 
4081         SEALContext context(parms, true, sec_level_type::none);
4082         KeyGenerator keygen(context);
4083         SecretKey secret_key = keygen.secret_key();
4084         PublicKey pk;
4085         keygen.create_public_key(pk);
4086 
4087         Encryptor encryptor(context, pk);
4088         Evaluator evaluator(context);
4089         Decryptor decryptor(context, keygen.secret_key());
4090         auto parms_id = context.first_parms_id();
4091 
4092         Ciphertext encrypted(context);
4093         Ciphertext encryptedRes;
4094         Plaintext plain;
4095 
4096         plain = 0;
4097         encryptor.encrypt(plain, encrypted);
4098         evaluator.mod_switch_to_next(encrypted, encryptedRes);
4099         decryptor.decrypt(encryptedRes, plain);
4100         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4101         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4102         ASSERT_TRUE(plain.to_string() == "0");
4103 
4104         evaluator.mod_switch_to_next_inplace(encryptedRes);
4105         decryptor.decrypt(encryptedRes, plain);
4106         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4107         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4108         ASSERT_TRUE(plain.to_string() == "0");
4109 
4110         parms_id = context.first_parms_id();
4111         plain = 1;
4112         encryptor.encrypt(plain, encrypted);
4113         evaluator.mod_switch_to_next(encrypted, encryptedRes);
4114         decryptor.decrypt(encryptedRes, plain);
4115         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4116         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4117         ASSERT_TRUE(plain.to_string() == "1");
4118 
4119         evaluator.mod_switch_to_next_inplace(encryptedRes);
4120         decryptor.decrypt(encryptedRes, plain);
4121         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4122         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4123         ASSERT_TRUE(plain.to_string() == "1");
4124 
4125         parms_id = context.first_parms_id();
4126         plain = "1x^127";
4127         encryptor.encrypt(plain, encrypted);
4128         evaluator.mod_switch_to_next(encrypted, encryptedRes);
4129         decryptor.decrypt(encryptedRes, plain);
4130         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4131         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4132         ASSERT_TRUE(plain.to_string() == "1x^127");
4133 
4134         evaluator.mod_switch_to_next_inplace(encryptedRes);
4135         decryptor.decrypt(encryptedRes, plain);
4136         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4137         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4138         ASSERT_TRUE(plain.to_string() == "1x^127");
4139 
4140         parms_id = context.first_parms_id();
4141         plain = "5x^64 + Ax^5";
4142         encryptor.encrypt(plain, encrypted);
4143         evaluator.mod_switch_to_next(encrypted, encryptedRes);
4144         decryptor.decrypt(encryptedRes, plain);
4145         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4146         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4147         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4148 
4149         evaluator.mod_switch_to_next_inplace(encryptedRes);
4150         decryptor.decrypt(encryptedRes, plain);
4151         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4152         ASSERT_TRUE(encryptedRes.parms_id() == parms_id);
4153         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4154     }
4155 
TEST(EvaluatorTest,BFVEncryptModSwitchToDecrypt)4156     TEST(EvaluatorTest, BFVEncryptModSwitchToDecrypt)
4157     {
4158         // The common parameters: the plaintext and the polynomial moduli
4159         Modulus plain_modulus(1 << 6);
4160 
4161         // The parameters and the context of the higher level
4162         EncryptionParameters parms(scheme_type::bfv);
4163         parms.set_poly_modulus_degree(128);
4164         parms.set_plain_modulus(plain_modulus);
4165         parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 }));
4166 
4167         SEALContext context(parms, true, sec_level_type::none);
4168         KeyGenerator keygen(context);
4169         SecretKey secret_key = keygen.secret_key();
4170         PublicKey pk;
4171         keygen.create_public_key(pk);
4172 
4173         Encryptor encryptor(context, pk);
4174         Evaluator evaluator(context);
4175         Decryptor decryptor(context, keygen.secret_key());
4176         auto parms_id = context.first_parms_id();
4177 
4178         Ciphertext encrypted(context);
4179         Plaintext plain;
4180 
4181         plain = 0;
4182         encryptor.encrypt(plain, encrypted);
4183         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4184         decryptor.decrypt(encrypted, plain);
4185         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4186         ASSERT_TRUE(plain.to_string() == "0");
4187 
4188         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4189         encryptor.encrypt(plain, encrypted);
4190         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4191         decryptor.decrypt(encrypted, plain);
4192         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4193         ASSERT_TRUE(plain.to_string() == "0");
4194 
4195         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4196         encryptor.encrypt(plain, encrypted);
4197         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4198         decryptor.decrypt(encrypted, plain);
4199         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4200         ASSERT_TRUE(plain.to_string() == "0");
4201 
4202         parms_id = context.first_parms_id();
4203         encryptor.encrypt(plain, encrypted);
4204         parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id();
4205         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4206         decryptor.decrypt(encrypted, plain);
4207         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4208         ASSERT_TRUE(plain.to_string() == "0");
4209 
4210         parms_id = context.first_parms_id();
4211         plain = 1;
4212         encryptor.encrypt(plain, encrypted);
4213         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4214         decryptor.decrypt(encrypted, plain);
4215         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4216         ASSERT_TRUE(plain.to_string() == "1");
4217 
4218         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4219         encryptor.encrypt(plain, encrypted);
4220         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4221         decryptor.decrypt(encrypted, plain);
4222         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4223         ASSERT_TRUE(plain.to_string() == "1");
4224 
4225         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4226         encryptor.encrypt(plain, encrypted);
4227         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4228         decryptor.decrypt(encrypted, plain);
4229         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4230         ASSERT_TRUE(plain.to_string() == "1");
4231 
4232         parms_id = context.first_parms_id();
4233         encryptor.encrypt(plain, encrypted);
4234         parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id();
4235         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4236         decryptor.decrypt(encrypted, plain);
4237         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4238         ASSERT_TRUE(plain.to_string() == "1");
4239 
4240         parms_id = context.first_parms_id();
4241         plain = "1x^127";
4242         encryptor.encrypt(plain, encrypted);
4243         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4244         decryptor.decrypt(encrypted, plain);
4245         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4246         ASSERT_TRUE(plain.to_string() == "1x^127");
4247 
4248         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4249         encryptor.encrypt(plain, encrypted);
4250         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4251         decryptor.decrypt(encrypted, plain);
4252         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4253         ASSERT_TRUE(plain.to_string() == "1x^127");
4254 
4255         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4256         encryptor.encrypt(plain, encrypted);
4257         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4258         decryptor.decrypt(encrypted, plain);
4259         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4260         ASSERT_TRUE(plain.to_string() == "1x^127");
4261 
4262         parms_id = context.first_parms_id();
4263         encryptor.encrypt(plain, encrypted);
4264         parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id();
4265         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4266         decryptor.decrypt(encrypted, plain);
4267         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4268         ASSERT_TRUE(plain.to_string() == "1x^127");
4269 
4270         parms_id = context.first_parms_id();
4271         plain = "5x^64 + Ax^5";
4272         encryptor.encrypt(plain, encrypted);
4273         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4274         decryptor.decrypt(encrypted, plain);
4275         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4276         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4277 
4278         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4279         encryptor.encrypt(plain, encrypted);
4280         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4281         decryptor.decrypt(encrypted, plain);
4282         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4283         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4284 
4285         parms_id = context.get_context_data(parms_id)->next_context_data()->parms_id();
4286         encryptor.encrypt(plain, encrypted);
4287         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4288         decryptor.decrypt(encrypted, plain);
4289         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4290         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4291 
4292         parms_id = context.first_parms_id();
4293         encryptor.encrypt(plain, encrypted);
4294         parms_id = context.get_context_data(parms_id)->next_context_data()->next_context_data()->parms_id();
4295         evaluator.mod_switch_to_inplace(encrypted, parms_id);
4296         decryptor.decrypt(encrypted, plain);
4297         ASSERT_TRUE(encrypted.parms_id() == parms_id);
4298         ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5");
4299     }
4300 } // namespace sealtest
4301