1 /* Copyright (C) 2019-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #include <numeric>
14 
15 #include <helib/Ptxt.h>
16 #include <helib/helib.h>
17 #include <helib/replicate.h>
18 #include <helib/NumbTh.h>
19 
20 #include "test_common.h"
21 #include "gtest/gtest.h"
22 
23 namespace {
24 
// Immutable (m, p, r) parameter triple for BGV tests.
// Streaming an instance renders it as "{m = M, p = P, r = R}".
struct BGVParameters
{
  // Stores the three values unchanged in the const members of the same name.
  BGVParameters(unsigned m, unsigned p, unsigned r) : m(m), p(p), r(r) {}

  const unsigned m;
  const unsigned p;
  const unsigned r;

  // Writes params to os in the form "{m = M, p = P, r = R}" and returns os.
  friend std::ostream& operator<<(std::ostream& os, const BGVParameters& params)
  {
    return os << "{"
              << "m = " << params.m << ", "
              << "p = " << params.p << ", "
              << "r = " << params.r << "}";
  }
};
41 
42 class TestPtxtCKKS : public ::testing::TestWithParam<unsigned>
43 {
44 protected:
TestPtxtCKKS()45   TestPtxtCKKS() :
46       // Only relevant parameter is m for a CKKS plaintext
47       m(GetParam()),
48       context(m, -1, 40)
49   // VJS_NOTE: I changed r=50 to r=40.
50   // I find setting r so large can cause problems,
51   // and the test was not passing.
52   // This may be somethng that needs further investigation,
53   // but later...
54   // This probably has something to do with the slightly
55   // different logic in the new encoding functions
56   {}
57 
58   const unsigned long m;
59 
60   helib::Context context;
61 
62   const double pre_encryption_epsilon = 1E-8;
63   const double post_encryption_epsilon = 1E-3;
64 };
65 
// Building a CKKS plaintext from the fixture's context must not throw.
TEST_P(TestPtxtCKKS, canBeConstructedWithCKKSContext)
{
  helib::Ptxt<helib::CKKS> from_context(context);
}
70 
// A CKKS plaintext can be created without supplying a context.
TEST_P(TestPtxtCKKS, canBeDefaultConstructed)
{
  helib::Ptxt<helib::CKKS> without_context;
}
72 
// Copy-constructing from a context-backed plaintext must not throw.
TEST_P(TestPtxtCKKS, canBeCopyConstructed)
{
  helib::Ptxt<helib::CKKS> source(context);
  helib::Ptxt<helib::CKKS> duplicate(source);
}
78 
// Copy-assignment onto an already constructed plaintext (same context) must
// succeed and leave the target equal to the source.
TEST_P(TestPtxtCKKS, canBeAssignedFromOtherPtxt)
{
  helib::Ptxt<helib::CKKS> ptxt(context, std::complex<double>{0.5, -0.25});
  helib::Ptxt<helib::CKKS> ptxt2(context);
  // A separate statement is required here: `Ptxt ptxt2 = ptxt;` would be
  // copy-initialisation and would never reach operator=.
  ptxt2 = ptxt;
  EXPECT_TRUE(ptxt2 == ptxt);
}
84 
// isValid() is false for a default-constructed plaintext and true for one
// constructed from a context.
TEST_P(TestPtxtCKKS, reportsWhetherItIsValid)
{
  helib::Ptxt<helib::CKKS> without_context;
  helib::Ptxt<helib::CKKS> with_context(context);

  EXPECT_FALSE(without_context.isValid());
  EXPECT_TRUE(with_context.isValid());
}
92 
// A plaintext built from a context has as many slots as context.ea reports.
TEST_P(TestPtxtCKKS, hasSameNumberOfSlotsAsContext)
{
  helib::Ptxt<helib::CKKS> ptxt(context);
  const auto slots_in_context = context.ea->size();
  EXPECT_EQ(ptxt.size(), slots_in_context);
}
98 
// A plaintext built from a full-length vector holds exactly that vector.
TEST_P(TestPtxtCKKS, preservesDataPassedIntoConstructor)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t idx = 0; idx < data.size(); ++idx) {
    data[idx] = idx / 10.0;
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  COMPARE_CXDOUBLE_VECS(ptxt, data);
}
108 
// Even when the input vector is one element short, the plaintext still has
// the full slot count reported by context.ea.
TEST_P(TestPtxtCKKS, hasSameNumberOfSlotsAsContextWhenCreatedWithData)
{
  std::vector<std::complex<double>> short_data(context.ea->size() - 1);
  helib::Ptxt<helib::CKKS> ptxt(context, short_data);
  const auto slots_in_context = context.ea->size();
  EXPECT_EQ(ptxt.size(), slots_in_context);
}
115 
// Constructing from one complex value copies that value into every slot.
TEST_P(TestPtxtCKKS, replicateValueWhenPassingASingleSlotTypeNumber)
{
  std::complex<double> num = {1. / 10.0, 1. / 10.0};
  helib::Ptxt<helib::CKKS> ptxt(context, num);

  const double want_real = num.real();
  const double want_imag = num.imag();
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), want_real);
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), want_imag);
  }
}
126 
// Constructing from one plain double puts it in the real part of every slot
// and leaves every imaginary part at zero.
TEST_P(TestPtxtCKKS, replicateValueWhenPassingASingleNonSlotTypeNumber)
{
  double num = 1. / 10.0;
  helib::Ptxt<helib::CKKS> ptxt(context, num);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), num);
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), 0.0);
  }
}
137 
// at() returns the stored value for in-range indices and throws
// helib::OutOfRangeError for the five indices on either side of the range.
TEST_P(TestPtxtCKKS, atMethodThrowsOrReturnsCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t idx = 0; idx < data.size(); ++idx) {
    data[idx] = idx / 10.0;
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  // Negative indices.
  for (long below = -5; below < 0; ++below) {
    EXPECT_THROW(ptxt.at(below), helib::OutOfRangeError);
  }
  // Valid indices.
  for (long pos = 0; pos < helib::lsize(data); ++pos) {
    EXPECT_DOUBLE_EQ(ptxt.at(pos).real(), data.at(pos).real());
    EXPECT_DOUBLE_EQ(ptxt.at(pos).imag(), data.at(pos).imag());
  }
  // Indices past the last slot.
  for (std::size_t above = data.size(); above < data.size() + 5; ++above) {
    EXPECT_THROW(ptxt.at(above), helib::OutOfRangeError);
  }
}
156 
// When the input is one element shorter than the slot count, the supplied
// values are kept and the remaining slot is zero.
TEST_P(TestPtxtCKKS, padsWithZerosWhenPassingInSmallDataVector)
{
  std::vector<std::complex<double>> data(context.ea->size() - 1);
  for (long k = 0; k < helib::lsize(data); ++k) {
    const double component = (k - 1) / 10.0;
    data[k] = {component, component};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::size_t slot = 0;
  // Slots covered by the input.
  for (; slot < data.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), data[slot].real());
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), data[slot].imag());
  }
  // Slots beyond the input.
  for (; slot < ptxt.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), 0.0);
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), 0.0);
  }
}
173 
// A vector of plain doubles is stored as the real parts; imaginary parts and
// any slot beyond the input are zero.
TEST_P(TestPtxtCKKS, preservesDataPassedIntoConstructorAsDouble)
{
  std::vector<double> data(context.ea->size() - 1);
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = (k - 1) / 10.0;
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::size_t slot = 0;
  // Slots covered by the input.
  for (; slot < data.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), data[slot]);
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), 0.0);
  }
  // Slots beyond the input.
  for (; slot < ptxt.size(); ++slot) {
    EXPECT_DOUBLE_EQ(ptxt[slot].real(), 0.0);
    EXPECT_DOUBLE_EQ(ptxt[slot].imag(), 0.0);
  }
}
190 
// operator<< writes "[[re, im], [re, im], ...]" using digits10 precision.
TEST_P(TestPtxtCKKS, writesDataCorrectlyToOstream)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  // Hand-build the expected text.
  std::stringstream wanted;
  wanted << "[";
  wanted << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      wanted << ", ";
    }
    wanted << "[" << data[k].real() << ", " << data[k].imag() << "]";
  }
  wanted << "]";
  std::string expected = wanted.str();

  std::ostringstream actual;
  actual << ptxt;

  EXPECT_EQ(actual.str(), expected);
}
214 
// operator>> reads a bracketed, comma-separated list whose elements were
// written by helib::serialize, recovering each value within
// pre_encryption_epsilon.
TEST_P(TestPtxtCKKS, readsDataCorrectlyFromIstream)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << "[";
  text << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      text << ", ";
    }
    helib::serialize(text, data[k]);
  }
  text << "]";

  std::istringstream is(text.str());
  is >> ptxt;

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_NEAR(std::abs(ptxt[slot] - data[slot]), 0, pre_encryption_epsilon);
  }
}
239 
// operator>> reads "[[re, im], ...]" text written by hand, recovering each
// value within pre_encryption_epsilon.
TEST_P(TestPtxtCKKS, readsSquareBracketsDataCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << "[";
  text << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      text << ", ";
    }
    text << "[" << data[k].real() << ", " << data[k].imag() << "]";
  }
  text << "]";

  std::istringstream is(text.str());
  is >> ptxt;

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_NEAR(std::abs(ptxt[slot] - data[slot]), 0, pre_encryption_epsilon);
  }
}
264 
// helib::serialize writes a std::complex<double> as "[re, im]".
TEST_P(TestPtxtCKKS, serializeFunctionSerializesStdComplexCorrectly)
{
  // TODO: This test may be removed from the fixture and put as standalone

  // Serialises one value into a fresh stream and returns the text produced.
  const auto as_text = [](const std::complex<double>& num) {
    std::stringstream ss;
    helib::serialize(ss, num);
    return ss.str();
  };

  EXPECT_EQ(as_text(std::complex<double>(0.0)), "[0, 0]");
  EXPECT_EQ(as_text(std::complex<double>(10.3)), "[10.3, 0]");
  EXPECT_EQ(as_text(std::complex<double>(0, 10.3)), "[0, 10.3]");
  EXPECT_EQ(as_text(std::complex<double>(3.3, 16.6)), "[3.3, 16.6]");
}
291 
// helib::deserialize fills a std::complex<double> from bracketed text.
// Starting from a value of zero, "[]" and "[0.0]" leave it at zero,
// "[re,im]" sets both parts, and three components raise helib::IOError.
TEST_P(TestPtxtCKKS, deserializeFunctionDeserializesStdComplexCorrectly)
{
  // TODO: This test may be removed from the fixture and put as standalone

  // Every case parses from its own stream. Reusing one stream and calling
  // str("") does not reset its state flags, so an error raised by one case
  // could silently break the writes of the following ones.
  const auto parse = [](const char* text) {
    std::stringstream ss;
    ss << text;
    std::complex<double> num = 0.0;
    helib::deserialize(ss, num);
    return num;
  };

  // Checks that text deserializes to expected within pre_encryption_epsilon.
  const auto expect_parses_to = [&](const char* text,
                                    const std::complex<double>& expected) {
    EXPECT_NEAR(std::abs(parse(text) - expected), 0.0, pre_encryption_epsilon)
        << "input: " << text;
  };

  EXPECT_THROW(parse("[1,2,3]"), helib::IOError);

  expect_parses_to("[]", std::complex<double>(0.0));
  expect_parses_to("[0.0]", std::complex<double>(0.0));
  expect_parses_to("[0.0,0.0]", std::complex<double>(0.0));
  expect_parses_to("[5.3,0]", std::complex<double>(5.3));
  expect_parses_to("[0,8.16]", std::complex<double>(0.0, 8.16));
  expect_parses_to("[3.4,9.99]", std::complex<double>(3.4, 9.99));
}
346 
// helib::serialize on a plaintext writes "[[re, im], [re, im], ...]" using
// digits10 precision.
TEST_P(TestPtxtCKKS, serializeFunctionSerializesCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  // Hand-build the expected text.
  std::stringstream wanted;
  wanted << "[";
  wanted << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      wanted << ", ";
    }
    wanted << "[" << data[k].real() << ", " << data[k].imag() << "]";
  }
  wanted << "]";

  std::stringstream serialized_ptxt;
  helib::serialize(serialized_ptxt, ptxt);

  EXPECT_EQ(serialized_ptxt.str(), wanted.str());
}
369 
// helib::deserialize reads "[[re, im], ...]" text into a plaintext,
// recovering each value within pre_encryption_epsilon.
TEST_P(TestPtxtCKKS, deserializeWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);
  std::stringstream ss;
  ss << "[";
  ss << std::setprecision(std::numeric_limits<double>::digits10);
  for (auto it = data.begin(); it != data.end(); ++it) {
    ss << "[" << it->real() << ", " << it->imag() << "]";
    if (it != data.end() - 1) {
      ss << ", ";
    }
  }
  ss << "]";
  std::istringstream is(ss.str());
  // Call the free function directly: operator>> is already covered by
  // readsSquareBracketsDataCorrectly, and using it here would leave the
  // success path of helib::deserialize untested.
  helib::deserialize(is, ptxt);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_NEAR(std::abs(ptxt[i] - data[i]), 0, pre_encryption_epsilon);
  }
}
394 
// helib::deserialize throws helib::IOError when the text holds one more
// element than the plaintext has slots.
TEST_P(TestPtxtCKKS, deserializeFunctionThrowsIfMoreElementsThanSlots)
{
  std::vector<std::complex<double>> data(context.ea->size() + 1);
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << "[";
  text << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      text << ", ";
    }
    text << "[" << data[k].real() << ", " << data[k].imag() << "]";
  }
  text << "]";

  std::istringstream is(text.str());
  EXPECT_THROW(helib::deserialize(is, ptxt), helib::IOError);
}
415 
// operator>> throws helib::IOError when the text holds one more element than
// the plaintext has slots.
TEST_P(TestPtxtCKKS, rightShiftOperatorThrowsIfMoreElementsThanSlots)
{
  std::vector<std::complex<double>> data(context.ea->size() + 1);
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << "[";
  text << std::setprecision(std::numeric_limits<double>::digits10);
  for (std::size_t k = 0; k < data.size(); ++k) {
    if (k != 0) {
      text << ", ";
    }
    text << "[" << data[k].real() << ", " << data[k].imag() << "]";
  }
  text << "]";

  std::istringstream is(text.str());
  EXPECT_THROW(is >> ptxt, helib::IOError);
}
436 
// Writing a plaintext with operator<< and reading it back with operator>>
// reproduces every slot within pre_encryption_epsilon.
TEST_P(TestPtxtCKKS, deserializeIsInverseOfSerialize)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::stringstream round_trip;
  round_trip << ptxt;

  helib::Ptxt<helib::CKKS> deserialized(context);
  round_trip >> deserialized;

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    const double distance = std::abs(ptxt[slot] - deserialized[slot]);
    EXPECT_NEAR(distance, 0, pre_encryption_epsilon);
  }
}
455 
// Three plaintexts written to one stream, one per line, can be read back in
// order with operator>>, each within pre_encryption_epsilon.
TEST_P(TestPtxtCKKS, readsManyPtxtsFromStream)
{
  std::vector<std::complex<double>> data1(context.ea->size());
  std::vector<std::complex<double>> data2(context.ea->size());
  std::vector<std::complex<double>> data3(context.ea->size());
  for (std::size_t i = 0; i < data1.size(); ++i) {
    data1[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
    data2[i] = data1[i] * 2.0;
    data3[i] = data1[i] * 3.5;
  }
  helib::Ptxt<helib::CKKS> ptxt1(context, data1);
  helib::Ptxt<helib::CKKS> ptxt2(context, data2);
  helib::Ptxt<helib::CKKS> ptxt3(context, data3);

  // Write all three, newline-terminated.
  std::stringstream ss;
  ss << ptxt1 << '\n';
  ss << ptxt2 << '\n';
  ss << ptxt3 << '\n';

  // Read them back in the same order.
  helib::Ptxt<helib::CKKS> deserialized1(context);
  helib::Ptxt<helib::CKKS> deserialized2(context);
  helib::Ptxt<helib::CKKS> deserialized3(context);
  ss >> deserialized1;
  ss >> deserialized2;
  ss >> deserialized3;

  for (std::size_t slot = 0; slot < ptxt1.size(); ++slot) {
    const double distance1 = std::abs(ptxt1[slot] - deserialized1[slot]);
    EXPECT_NEAR(distance1, 0, pre_encryption_epsilon);
    const double distance2 = std::abs(ptxt2[slot] - deserialized2[slot]);
    EXPECT_NEAR(distance2, 0, pre_encryption_epsilon);
    const double distance3 = std::abs(ptxt3[slot] - deserialized3[slot]);
    EXPECT_NEAR(distance3, 0, pre_encryption_epsilon);
  }
}
494 
// getSlotRepr() returns the input values followed by a zero for the slot
// the short input did not cover.
TEST_P(TestPtxtCKKS, getSlotReprReturnsData)
{
  std::vector<std::complex<double>> data(context.ea->size() - 1);
  for (long k = 0; k < helib::lsize(data); ++k) {
    const double component = (k - 1) / 10.0;
    data[k] = {component, component};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::vector<std::complex<double>> expected_repr(context.ea->size());
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    if (slot < data.size()) {
      expected_repr[slot] = data[slot];
    } else {
      expected_repr[slot] = 0;
    }
  }
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_repr);
}
508 
// runningSums() turns slot i into the sum of slots 0..i. With input
// (k, k^2) the closed forms are i(i+1)/2 and i(i+1)(2i+1)/6.
TEST_P(TestPtxtCKKS, runningSumsWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {k / 1.0, (k * k) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.runningSums();

  std::vector<std::complex<double>> expected_result(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const double sum_of_k = (k * (k + 1)) / 2.0;
    const double sum_of_squares = (k * (k + 1) * (2 * k + 1)) / 6.0;
    expected_result[k] = {sum_of_k, sum_of_squares};
  }

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
526 
// totalSums() puts the sum of all slots into every slot. With input
// (k, k^2) over n slots the closed forms are (n-1)n/2 and
// (n-1)n(2(n-1)+1)/6.
TEST_P(TestPtxtCKKS, totalSumsWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {k / 1.0, (k * k) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.totalSums();

  // The same total is expected in every slot, so compute it once.
  const std::size_t n = data.size();
  const std::complex<double> total = {
      ((n - 1) * n) / 2.0,
      ((n - 1) * n * (2 * (n - 1) + 1)) / 6.0};
  std::vector<std::complex<double>> expected_result(context.ea->size(), total);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
545 
// incrementalProduct() turns slot i into the product of slots 0..i.
TEST_P(TestPtxtCKKS, incrementalProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {(k + 1) / 5.0, (k * k + 1) / 10.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.incrementalProduct();

  // Prefix products, accumulated in place over a copy of the input.
  std::vector<std::complex<double>> expected_result(data);
  for (std::size_t k = 1; k < data.size(); ++k) {
    expected_result[k] *= expected_result[k - 1];
  }

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
562 
// totalProduct() puts the product of all slots into every slot.
TEST_P(TestPtxtCKKS, totalProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {(k + 1) / 10.0, (k * k + 1) / 10.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.totalProduct();

  // Left-to-right product starting from 1, multiplied in with operator*=.
  const std::complex<double> product = std::accumulate(
      data.begin(),
      data.end(),
      std::complex<double>{1.0, 0.0},
      [](std::complex<double> running, const std::complex<double>& factor) {
        running *= factor;
        return running;
      });
  std::vector<std::complex<double>> expected_result(context.ea->size(),
                                                    product);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
581 
// innerProduct() of a 2-element vector of ptxt and a 3-element vector of
// 2*ptxt is expected to equal 2 * (data * 2*data) in every slot.
TEST_P(TestPtxtCKKS, innerProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {k / 1.0, (k * k) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  std::vector<helib::Ptxt<helib::CKKS>> first_ptxt_vector(2, ptxt);
  ptxt += ptxt; // ptxt now holds 2 * data
  std::vector<helib::Ptxt<helib::CKKS>> second_ptxt_vector(3, ptxt);

  helib::Ptxt<helib::CKKS> result(context);
  innerProduct(result, first_ptxt_vector, second_ptxt_vector);

  std::vector<std::complex<double>> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const std::complex<double> single_term = (data[k] * (data[k] + data[k]));
    expected_result[k] = single_term;
    expected_result[k] += single_term; // two equal terms
  }

  COMPARE_CXDOUBLE_VECS(result.getSlotRepr(), expected_result);
}
606 
// mapTo01, as a member and as a free function, is expected to leave slot 0
// (input 0) at zero and set every other slot to 1.
TEST_P(TestPtxtCKKS, mapTo01MapsSlotsCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {k / 1.0, (k * k) / 1.0};
  }

  // Should exist as a free function and a member function
  helib::Ptxt<helib::CKKS> via_member(context, data);
  helib::Ptxt<helib::CKKS> via_free_function(context, data);
  via_member.mapTo01();
  mapTo01(*(context.ea), via_free_function);

  // Slot 0 keeps its value-initialised zero.
  std::vector<std::complex<double>> expected_result(context.ea->size());
  for (std::size_t k = 1; k < data.size(); ++k) {
    expected_result[k] = {1, 0};
  }

  COMPARE_CXDOUBLE_VECS(via_member.getSlotRepr(), expected_result);
  COMPARE_CXDOUBLE_VECS(via_free_function.getSlotRepr(), expected_result);
}
628 
// operator*= with another plaintext multiplies slot by slot.
TEST_P(TestPtxtCKKS, timesEqualsOtherPlaintextWorks)
{
  std::vector<std::complex<double>> product_data(context.ea->size(),
                                                 {-3.14, -1.0});
  std::vector<std::complex<double>> multiplier_data(context.ea->size());
  for (long k = 0; k < helib::lsize(multiplier_data); ++k) {
    multiplier_data[k] = {(k - 1) / 10.0, (k + 1) / 10.0};
  }

  std::vector<std::complex<double>> expected_result(product_data.size());
  for (std::size_t k = 0; k < product_data.size(); ++k) {
    expected_result[k] = product_data[k] * multiplier_data[k];
  }

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_data);
  product *= multiplier;

  COMPARE_CXDOUBLE_VECS(product.getSlotRepr(), expected_result);
}
650 
// operator-= with another plaintext subtracts slot by slot.
TEST_P(TestPtxtCKKS, minusEqualsOtherPlaintextWorks)
{
  std::vector<std::complex<double>> difference_data(context.ea->size(),
                                                    {2.718, -1.0});
  std::vector<std::complex<double>> subtrahend_data(context.ea->size());
  for (long k = 0; k < helib::lsize(subtrahend_data); ++k) {
    subtrahend_data[k] = {(k - 1) / 10.0, (k + 1) / 10.0};
  }

  std::vector<std::complex<double>> expected_result(difference_data.size());
  for (std::size_t k = 0; k < subtrahend_data.size(); ++k) {
    expected_result[k] = difference_data[k] - subtrahend_data[k];
  }

  helib::Ptxt<helib::CKKS> difference(context, difference_data);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_data);
  difference -= subtrahend;

  COMPARE_CXDOUBLE_VECS(difference.getSlotRepr(), expected_result);
}
672 
// operator-= with a complex scalar subtracts it from every slot.
TEST_P(TestPtxtCKKS, minusEqualsComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {(k * k - 1) / 10.0, (k * k + 1) / 10.0};
  }

  const std::complex<double> scalar = {2.5, -0.5};

  std::vector<std::complex<double>> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    expected_result[k] = data[k] - scalar;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt -= scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
691 
// operator-= accepts a double and an int scalar, subtracting each from
// every slot.
TEST_P(TestPtxtCKKS, minusEqualsNonComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {(k * k - 1) / 2.0, (k * k + 1) / 5.0};
  }

  const double scalar = 15.3;
  const int int_scalar = 2;

  // Subtract in the same order the plaintext does: double first, then int.
  std::vector<std::complex<double>> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    expected_result[k] = data[k] - scalar - static_cast<double>(int_scalar);
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt -= scalar;
  ptxt -= int_scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
713 
// operator+= with another plaintext adds slot by slot.
TEST_P(TestPtxtCKKS, plusEqualsOtherPlaintextWorks)
{
  std::vector<std::complex<double>> augend_data(context.ea->size());
  std::vector<std::complex<double>> addend_data(context.ea->size());
  for (long k = 0; k < helib::lsize(addend_data); ++k) {
    augend_data[k] = {k / 10.0, k * k / 10.0};
    addend_data[k] = {k / 20.0, k * k / 20.0};
  }

  std::vector<std::complex<double>> expected_result(context.ea->size());
  for (std::size_t k = 0; k < expected_result.size(); ++k) {
    expected_result[k] = augend_data[k] + addend_data[k];
  }

  helib::Ptxt<helib::CKKS> sum(context, augend_data);
  helib::Ptxt<helib::CKKS> addend(context, addend_data);
  sum += addend;

  COMPARE_CXDOUBLE_VECS(sum.getSlotRepr(), expected_result);
}
732 
// operator+= with a complex scalar adds it to every slot.
TEST_P(TestPtxtCKKS, plusEqualsComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {-i / 10.0, (3 - i) / 4.0};
  }

  // Must be a genuinely complex value with a non-zero imaginary part: a
  // plain double here would only repeat plusEqualsNonComplexScalarWorks.
  const std::complex<double> scalar = {3.14, -2.72};

  std::vector<std::complex<double>> expected_result(data);
  for (auto& num : expected_result)
    num += scalar;

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt += scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
751 
// operator+= accepts a double and an int scalar, adding each to every slot.
TEST_P(TestPtxtCKKS, plusEqualsNonComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {k + k / 5.0, k - k / 4.0};
  }

  const double scalar = 3.28;
  const int int_scalar = 13;

  // The expected value adds both scalars in one step.
  const double combined = scalar + static_cast<double>(int_scalar);
  std::vector<std::complex<double>> expected_result(data);
  for (auto& value : expected_result) {
    value += combined;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt += scalar;
  ptxt += int_scalar;

  const auto& slots = ptxt.getSlotRepr();
  for (std::size_t k = 0; k < data.size(); ++k) {
    EXPECT_DOUBLE_EQ(slots[k].real(), expected_result[k].real());
    EXPECT_DOUBLE_EQ(slots[k].imag(), expected_result[k].imag());
  }
}
777 
// operator*= accepts a double and an int scalar, scaling every slot by each.
TEST_P(TestPtxtCKKS, timesEqualsScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {k * k * k / 100.0, -k / 3.0};
  }

  const double scalar = 10.28;
  const int int_scalar = -2;

  // The expected value applies both factors in one step.
  const double combined = scalar * static_cast<double>(int_scalar);
  std::vector<std::complex<double>> expected_result(data);
  for (auto& value : expected_result) {
    value *= combined;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt *= scalar;
  ptxt *= int_scalar;

  const auto& slots = ptxt.getSlotRepr();
  for (std::size_t k = 0; k < data.size(); ++k) {
    EXPECT_DOUBLE_EQ(slots[k].real(), expected_result[k].real());
    EXPECT_DOUBLE_EQ(slots[k].imag(), expected_result[k].imag());
  }
}
803 
// operator== is true for two plaintexts built from the same context and
// data, and false against a default-constructed plaintext.
TEST_P(TestPtxtCKKS, equalityWithOtherPlaintextWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {k * 2.5, (k - 2) * 2.5};
  }
  helib::Ptxt<helib::CKKS> first(context, data);
  helib::Ptxt<helib::CKKS> second(context, data);
  helib::Ptxt<helib::CKKS> default_constructed;

  EXPECT_TRUE(first == second);
  EXPECT_FALSE(first == default_constructed);
}
815 
// operator!= is true for plaintexts holding different data and false when a
// plaintext is compared with itself.
TEST_P(TestPtxtCKKS, notEqualsOperatorWithOtherPlaintextWorks)
{
  std::vector<std::complex<double>> data1(context.ea->size());
  std::vector<std::complex<double>> data2(context.ea->size());
  for (long k = 0; k < helib::lsize(data1); ++k) {
    // k+1 makes the first element differ from (0,0)
    data1[k] = {(k + 1) * 2.5, -k * 2.5};
    data2[k] = {k * 2.5, k * 6.5};
  }
  helib::Ptxt<helib::CKKS> first(context, data1);
  helib::Ptxt<helib::CKKS> second(context, data2);

  EXPECT_TRUE(first != second);
  EXPECT_FALSE(first != first);
}
830 
// negate() is expected to match multiplying every slot by (-1, 0).
TEST_P(TestPtxtCKKS, negateNegatesCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  const double pi = std::acos(-1);
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Spiral with j -> j e^{2*i*pi*j/data.size()}
    const std::complex<double> radius{static_cast<double>(j), 0};
    const std::complex<double> angle{0, 2.0 * pi * j / data.size()};
    data[j] = radius * std::exp(angle);
  }

  const std::complex<double> minus_one{-1.0, 0};
  std::vector<std::complex<double>> expected_result(data);
  for (auto& value : expected_result) {
    value *= minus_one;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.negate();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
850 
// addConstantCKKS() accepts an int and a complex constant, can be chained,
// and adds each constant to every slot.
TEST_P(TestPtxtCKKS, addConstantWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long k = 0; k < helib::lsize(data); ++k) {
    data[k] = {k * 4.5, k / 2.0};
  }

  // Apply the two additions in the same order as the chained call below.
  std::vector<std::complex<double>> expected_result(data);
  for (auto& value : expected_result) {
    value += 5;
    value += std::complex<double>{0, 0.5};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.addConstantCKKS(5).addConstantCKKS(std::complex<double>{0, 0.5});

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
866 
// multiplyBy() with another plaintext multiplies slot by slot; the result
// is compared at float precision.
TEST_P(TestPtxtCKKS, multiplyByMultipliesCorrectly)
{
  std::vector<std::complex<double>> product_data(context.ea->size());
  std::vector<std::complex<double>> multiplier_data(context.ea->size());
  for (long k = 0; k < helib::lsize(multiplier_data); ++k) {
    product_data[k] = {(2 - k) / 10.0, (1 - k) / 10.0};
    multiplier_data[k] = {std::exp(k / 100.), std::cos(k) * 12};
  }

  std::vector<std::complex<double>> expected_result(product_data.size());
  for (std::size_t k = 0; k < product_data.size(); ++k) {
    expected_result[k] = product_data[k] * multiplier_data[k];
  }

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_data);
  product.multiplyBy(multiplier);

  // We use EXPECT_FLOAT_EQ as opposed to EXPECT_DOUBLE_EQ here to have a
  // higher default threshold for precision.
  COMPARE_CXFLOAT_VECS(product.getSlotRepr(), expected_result);
}
890 
// multiplyBy2() must multiply each slot by the matching slots of both
// multiplier plaintexts.
TEST_P(TestPtxtCKKS, multiplyBy2MultipliesCorrectly)
{
  std::vector<std::complex<double>> product_data(context.ea->size());
  std::vector<std::complex<double>> multiplier_data1(context.ea->size());
  std::vector<std::complex<double>> multiplier_data2(context.ea->size());
  for (long i = 0; i < helib::lsize(multiplier_data1); ++i) {
    product_data[i] = static_cast<double>(i) *
                      std::exp(std::complex<double>{0, static_cast<double>(i)});
    // The first multiplier used to be left untouched (the second one was
    // assigned twice), so it stayed all zeros and the test only ever checked
    // that the product was zero.
    multiplier_data1[i] =
        static_cast<double>(i) *
        std::exp(std::complex<double>{0, static_cast<double>(-i)});
    multiplier_data2[i] =
        5.0 * std::exp(std::complex<double>{0, static_cast<double>(i)});
  }

  // Reference product, multiplied in the order product * first * second.
  std::vector<std::complex<double>> expected_result(product_data);
  for (std::size_t i = 0; i < product_data.size(); ++i) {
    expected_result[i] =
        expected_result[i] * multiplier_data1[i] * multiplier_data2[i];
  }

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier1(context, multiplier_data1);
  helib::Ptxt<helib::CKKS> multiplier2(context, multiplier_data2);

  product.multiplyBy2(multiplier1, multiplier2);

  // Now that the products are non-zero, compare at float precision, as
  // multiplyByMultipliesCorrectly does for its products.
  // NOTE(review): tighten back to COMPARE_CXDOUBLE_VECS if the results prove
  // to match at double precision.
  COMPARE_CXFLOAT_VECS(product.getSlotRepr(), expected_result);
}
923 
// square() must replace each slot with its own square.
TEST_P(TestPtxtCKKS, squareSquaresCorrectly)
{
  std::vector<std::complex<double>> curve(context.ea->size());
  const long num_slots = helib::lsize(curve);
  for (long i = 0; i < num_slots; ++i) {
    // Polar curve: radius cos(2*theta) at angle theta.
    const double theta = 2. * std::acos(-1) * i / curve.size();
    const double radius = std::cos(2. * theta);
    curve[i] = radius * std::exp(std::complex<double>{0, theta});
  }

  // Reference values: each entry multiplied by itself.
  std::vector<std::complex<double>> squared(curve);
  for (std::size_t k = 0; k < squared.size(); ++k)
    squared[k] *= squared[k];

  helib::Ptxt<helib::CKKS> ptxt(context, curve);
  ptxt.square();
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), squared);
}
939 
// cube() must replace each slot with its own third power.
TEST_P(TestPtxtCKKS, cubeCubesCorrectly)
{
  std::vector<std::complex<double>> curve(context.ea->size());
  const long num_slots = helib::lsize(curve);
  for (long i = 0; i < num_slots; ++i) {
    // Cosh-shaped (catenary-like) curve: the real part is a centred,
    // normalised index and the imaginary part is the cosh of a similar value.
    // The two halvings differ on purpose: the first is integer division, the
    // second is floating-point, exactly as the data was originally defined.
    const double abscissa =
        static_cast<double>(1. * i - curve.size() / 2) / curve.size();
    const double ordinate = std::cosh((i - curve.size() / 2.) / curve.size());
    curve[i] = {abscissa, ordinate};
  }

  // Reference values: each entry multiplied by itself twice.
  std::vector<std::complex<double>> cubed(curve);
  for (std::size_t k = 0; k < cubed.size(); ++k)
    cubed[k] = cubed[k] * cubed[k] * cubed[k];

  helib::Ptxt<helib::CKKS> ptxt(context, curve);
  ptxt.cube();
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), cubed);
}
955 
// power(e) must agree with repeated multiplication for a range of positive
// exponents, and must throw helib::InvalidArgument for exponent 0.
TEST_P(TestPtxtCKKS, powerCorrectlyRaisesToPowers)
{
  using Slot = std::complex<double>;

  // Spiral inside the unit disk: modulus j/size, argument 2*pi*j/size.
  const double pi = std::acos(-1);
  std::vector<Slot> spiral(context.ea->size());
  const long num_slots = helib::lsize(spiral);
  for (long j = 0; j < num_slots; ++j) {
    const Slot modulus{j / static_cast<double>(spiral.size())};
    const Slot direction = std::exp(Slot{0, 2.0 * pi * j / spiral.size()});
    spiral[j] = modulus * direction;
  }
  const std::vector<long> exponents{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1500};

  // Reference implementation: multiply the base into itself exponent-1
  // times; exponent 0 gives 1.
  const auto repeated_product = [](Slot base, unsigned long exponent) {
    if (exponent == 0)
      return Slot{1.0};
    Slot accumulated = base;
    for (unsigned long step = 1; step < exponent; ++step)
      accumulated *= base;
    return accumulated;
  };

  for (const long exponent : exponents) {
    std::vector<Slot> expected_slots;
    expected_slots.reserve(spiral.size());
    for (const Slot& base : spiral)
      expected_slots.push_back(repeated_product(base, exponent));

    // Start from fresh data for every exponent.
    helib::Ptxt<helib::CKKS> ptxt(context, spiral);
    ptxt.power(exponent);
    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      const double squared_error = std::norm(ptxt[i] - expected_slots[i]);
      EXPECT_NEAR(squared_error, 0, pre_encryption_epsilon);
    }
  }

  // Make sure raising to 0 throws
  helib::Ptxt<helib::CKKS> zero_power_ptxt(context, spiral);
  EXPECT_THROW(zero_power_ptxt.power(0l), helib::InvalidArgument);
}
994 
// shift(3) on the ramp (i, i) must give a ramp delayed by three slots, with
// zeros in the leading positions.
TEST_P(TestPtxtCKKS, shiftShiftsRightCorrectly)
{
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<std::complex<double>> ramp(context.ea->size());
  std::vector<std::complex<double>> expected_slots(context.ea->size());
  const long num_slots = helib::lsize(ramp);
  for (int i = 0; i < num_slots; ++i) {
    const double level = static_cast<double>(i);
    ramp[i] = {level, level};

    // Slots 0..3 of the expectation keep their zero initial value (slot 3
    // would receive the original slot 0, which is zero as well).
    if (i <= 3)
      continue;
    const double source = static_cast<double>(wrap(i - 3, ramp.size()));
    expected_slots[i] = {source, source};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, ramp);

  ptxt.shift(3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_slots);
}
1015 
// shift(-3) on the ramp (i, i) must give a ramp advanced by three slots, with
// zeros in the trailing positions.
TEST_P(TestPtxtCKKS, shiftShiftsLeftCorrectly)
{
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<std::complex<double>> ramp(context.ea->size());
  std::vector<std::complex<double>> expected_slots(context.ea->size());
  const long num_slots = helib::lsize(ramp);
  // Only slots before this index have a source three places to their right.
  const long first_vacated = static_cast<long>(ramp.size()) - 3;
  const bool has_survivors = ramp.size() > 3;
  for (int i = 0; i < num_slots; ++i) {
    const double level = static_cast<double>(i);
    ramp[i] = {level, level};

    if (has_survivors && i < first_vacated) {
      const double source = static_cast<double>(wrap(i + 3, ramp.size()));
      expected_slots[i] = {source, source};
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, ramp);

  ptxt.shift(-3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_slots);
}
1036 
// shift1D(0, 3) on the ramp (i, i) must give a ramp delayed by three slots,
// with zeros in the leading positions.
TEST_P(TestPtxtCKKS, shift1DShiftsRightCorrectly)
{
  using Slot = std::complex<double>;
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<Slot> ramp(context.ea->size());
  std::vector<Slot> expected_slots(context.ea->size());
  const long num_slots = helib::lsize(ramp);
  for (int i = 0; i < num_slots; ++i) {
    ramp[i] = Slot(static_cast<double>(i), static_cast<double>(i));

    // Entries 0..3 of the expectation are left at zero.
    if (i > 3) {
      const double source = static_cast<double>(wrap(i - 3, ramp.size()));
      expected_slots[i] = Slot(source, source);
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, ramp);

  ptxt.shift1D(0, 3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_slots);
}
1057 
// shift1D(0, -3) on the ramp (i, i) must give a ramp advanced by three
// slots, with zeros in the trailing positions.
TEST_P(TestPtxtCKKS, shift1DShiftsLeftCorrectly)
{
  using Slot = std::complex<double>;
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<Slot> ramp(context.ea->size());
  std::vector<Slot> expected_slots(context.ea->size());
  const long num_slots = helib::lsize(ramp);
  // Entries from this index onwards have no source and stay at zero.
  const long first_vacated = static_cast<long>(ramp.size()) - 3;
  const bool has_survivors = ramp.size() > 3;
  for (int i = 0; i < num_slots; ++i) {
    ramp[i] = Slot(static_cast<double>(i), static_cast<double>(i));

    if (has_survivors && i < first_vacated) {
      const double source = static_cast<double>(wrap(i + 3, ramp.size()));
      expected_slots[i] = Slot(source, source);
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, ramp);

  ptxt.shift1D(0, -3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_slots);
}
1078 
1079 // These tests are disabled since the methods are private.
1080 // These can be useful if tweaking the logic of this area.
1081 // TEST(TestPtxtBGV, coord_To_Index_Works)
1082 // {
1083 //   helib::Context context(32109, 4999, 1);
1084 //   helib::Ptxt<helib::BGV> ptxt(context);
1085 //   std::vector<long> indices;
1086 //   for(long i=0; i<6; ++i)
1087 //     for(long j=0; j<2; ++j)
1088 //       for(long k=0; k<2; ++k)
1089 //         indices.push_back(ptxt.coordToIndex({i,j,k}));
1090 //   std::vector<long> expected_indices(context.ea->size());
1091 //   std::iota(expected_indices.begin(), expected_indices.end(), 0);
1092 //   EXPECT_EQ(expected_indices, indices);
1093 // }
1094 //
1095 // TEST(TestPtxtBGV, index_To_Coord_Works)
1096 // {
1097 //   helib::Context context(32109, 4999, 1);
1098 //   helib::Ptxt<helib::BGV> ptxt(context);
1099 //   std::vector<std::vector<long>> coords;
1100 //   for(std::size_t i=0; i<ptxt.size(); ++i)
1101 //     coords.push_back(ptxt.indexToCoord(i));
1102 //   std::vector<std::vector<long>> expected_coords;
1103 //   for(long i=0; i<6; ++i)
1104 //     for(long j=0; j<2; ++j)
1105 //       for(long k=0; k<2; ++k)
1106 //         expected_coords.push_back({i,j,k});
1107 //   EXPECT_EQ(expected_coords, coords);
1108 // }
1109 
// rotate1D(0, -3) must turn a ramp that starts three slots late into the
// plain ramp (i, i); rotating back by +3 must restore the input.
TEST_P(TestPtxtCKKS, rotate1DRotatesCorrectly)
{
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<std::complex<double>> delayed_ramp(context.ea->size());
  std::vector<std::complex<double>> plain_ramp(context.ea->size());
  const long num_slots = helib::lsize(delayed_ramp);
  for (int i = 0; i < num_slots; ++i) {
    const double delayed =
        static_cast<double>(wrap(i - 3, delayed_ramp.size()));
    const double level = static_cast<double>(i);
    delayed_ramp[i] = {delayed, delayed};
    plain_ramp[i] = {level, level};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, delayed_ramp);

  ptxt.rotate1D(0, -3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), plain_ramp);

  // Rotating back and forth gives the original data back
  ptxt.rotate1D(0, 3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), delayed_ramp);
}
1130 
// rotate(-3) must turn a ramp that starts three slots late into the plain
// ramp (i, i); rotating back by +3 must restore the input.
TEST_P(TestPtxtCKKS, rotateRotatesCorrectly)
{
  using Slot = std::complex<double>;
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  std::vector<Slot> delayed_ramp(context.ea->size());
  std::vector<Slot> plain_ramp(context.ea->size());
  const long num_slots = helib::lsize(delayed_ramp);
  for (int i = 0; i < num_slots; ++i) {
    const double delayed =
        static_cast<double>(wrap(i - 3, delayed_ramp.size()));
    delayed_ramp[i] = Slot(delayed, delayed);
    plain_ramp[i] = Slot(static_cast<double>(i), static_cast<double>(i));
  }
  helib::Ptxt<helib::CKKS> ptxt(context, delayed_ramp);

  ptxt.rotate(-3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), plain_ramp);

  // Rotating back and forth gives the original data back
  ptxt.rotate(3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), delayed_ramp);
}
1151 
// automorph() with exponents taken from zMStar.ith_rep must match rotate():
// ith_rep(1) is compared against rotate(1), and ith_rep(size - 1) against a
// subsequent rotate(-1).
TEST_P(TestPtxtCKKS, automorphWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  const auto non_neg_mod = [](int x, int mod) {
    return ((x % mod) + mod) % mod;
  };
  for (int i = 0; i < helib::lsize(data); ++i) {
    data[i] = {static_cast<double>(non_neg_mod(i - 3, data.size())),
               static_cast<double>(non_neg_mod(i - 3, data.size()))};
  }
  // Two plaintexts over the same data: one is automorphed, the other is
  // rotated to provide the reference.
  helib::Ptxt<helib::CKKS> ptxt(context, data);
  helib::Ptxt<helib::CKKS> expected_result(context, data);

  // Fall back to k = 1 when ith_rep(1) evaluates to zero.
  long k = context.zMStar.ith_rep(1) ? context.zMStar.ith_rep(1) : 1;
  ptxt.automorph(k);
  expected_result.rotate(1);
  COMPARE_CXDOUBLE_VECS(ptxt, expected_result);

  ptxt.automorph(context.zMStar.ith_rep(context.ea->size() - 1));
  expected_result.rotate(-1);
  COMPARE_CXDOUBLE_VECS(ptxt, expected_result);
}
1175 
// helib::replicate() on the last slot index must fill every slot with the
// value that was in the last slot.
TEST_P(TestPtxtCKKS, replicateReplicatesCorrectly)
{
  using Slot = std::complex<double>;

  std::vector<Slot> values(context.ea->size());
  const long num_slots = helib::lsize(values);
  for (long i = 0; i < num_slots; ++i)
    values[i] = Slot(i / 10.0, -i / 20.0);

  helib::Ptxt<helib::CKKS> ptxt(context, values);
  helib::replicate(*context.ea, ptxt, values.size() - 1);

  // Expectation: the final entry repeated across all slots.
  const std::vector<Slot> all_last(context.ea->size(),
                                   values[values.size() - 1]);
  COMPARE_CXDOUBLE_VECS(ptxt, all_last);
}
1188 
// replicateAll() returns a vector of plaintexts, indexed here by slot number;
// entry i must hold the original slot i in every position.
TEST_P(TestPtxtCKKS, replicateAllWorksCorrectly)
{
  using Slot = std::complex<double>;

  std::vector<Slot> values(context.ea->size());
  const long num_slots = helib::lsize(values);
  for (long i = 0; i < num_slots; ++i)
    values[i] = Slot(i / 10.0, -i / 20.0);

  helib::Ptxt<helib::CKKS> ptxt(context, values);
  std::vector<helib::Ptxt<helib::CKKS>> copies = ptxt.replicateAll();
  for (long i = 0; i < num_slots; ++i) {
    const Slot& wanted = values[i];
    for (const auto& slot : copies[i].getSlotRepr()) {
      EXPECT_DOUBLE_EQ(wanted.real(), slot.real());
      EXPECT_DOUBLE_EQ(wanted.imag(), slot.imag());
    }
  }
}
1204 
// random() is called on five copies of one plaintext; the test fails only if
// all five still compare equal afterwards.
TEST_P(TestPtxtCKKS, randomSetsDataRandomly)
{
  helib::Ptxt<helib::CKKS> seed_ptxt(context);
  seed_ptxt.random();

  std::vector<helib::Ptxt<helib::CKKS>> samples(5, seed_ptxt);
  for (std::size_t k = 0; k < samples.size(); ++k)
    samples[k].random();

  // Walk neighbouring pairs and stop at the first pair that differs.
  bool found_difference = false;
  for (std::size_t k = 1; k < samples.size() && !found_difference; ++k) {
    if (samples[k - 1] != samples[k])
      found_difference = true;
  }

  const bool all_equal = !found_difference;
  EXPECT_FALSE(all_equal) << "5 random ptxts are all equal - likely that"
                             " random() is not actually randomising!";
}
1222 
// complexConj() must replace every slot with its complex conjugate.
TEST_P(TestPtxtCKKS, complexConjCorrectlyConjugates)
{
  using Slot = std::complex<double>;

  // Line segment starting at 1 - i with gradient 2
  const Slot start{1, -1};
  std::vector<Slot> segment(context.ea->size());
  const long num_slots = helib::lsize(segment);
  for (long j = 0; j < num_slots; ++j)
    segment[j] = start + Slot{j / 4.0, j / 2.0};

  // Reference values computed with std::conj on the raw data.
  std::vector<Slot> conjugated(segment);
  for (std::size_t k = 0; k < conjugated.size(); ++k)
    conjugated[k] = std::conj(conjugated[k]);

  helib::Ptxt<helib::CKKS> ptxt(context, segment);
  ptxt.complexConj();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), conjugated);
}
1242 
// EncryptedArrayCx::extractRealPart() applied to a plaintext must leave only
// the real part of each slot.
TEST_P(TestPtxtCKKS, extractRealPartIsCorrect)
{
  using Slot = std::complex<double>;

  // Line segment starting at 1 - i with gradient 2
  const Slot start{1, -1};
  std::vector<Slot> segment(context.ea->size());
  const long num_slots = helib::lsize(segment);
  for (long j = 0; j < num_slots; ++j)
    segment[j] = start + Slot{j / 4.0, j / 2.0};

  // Assigning a double to a complex keeps it as the real part and zeroes the
  // imaginary part.
  std::vector<Slot> real_only(segment);
  for (std::size_t k = 0; k < real_only.size(); ++k)
    real_only[k] = std::real(real_only[k]);

  helib::Ptxt<helib::CKKS> ptxt(context, segment);
  context.ea->getCx().extractRealPart(ptxt);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), real_only);
}
1261 
// EncryptedArrayCx::extractImPart() applied to a plaintext must leave the
// imaginary part of each slot, stored as a real number.
TEST_P(TestPtxtCKKS, extractImPartIsCorrect)
{
  using Slot = std::complex<double>;

  // Line segment starting at 1 - i with gradient 2
  const Slot start{1, -1};
  std::vector<Slot> segment(context.ea->size());
  const long num_slots = helib::lsize(segment);
  for (long j = 0; j < num_slots; ++j)
    segment[j] = start + Slot{j / 4.0, j / 2.0};

  // Assigning a double to a complex stores it as the real part and zeroes
  // the imaginary part.
  std::vector<Slot> imag_only(segment);
  for (std::size_t k = 0; k < imag_only.size(); ++k)
    imag_only[k] = std::imag(imag_only[k]);

  helib::Ptxt<helib::CKKS> ptxt(context, segment);
  context.ea->getCx().extractImPart(ptxt);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), imag_only);
}
1280 
// Ptxt::real() must return a plaintext holding only the real part of each
// slot.
TEST_P(TestPtxtCKKS, realExtractsRealPart)
{
  using Slot = std::complex<double>;

  // Line segment starting at 1 - i with gradient 2
  const Slot start{1, -1};
  std::vector<Slot> segment(context.ea->size());
  const long num_slots = helib::lsize(segment);
  for (long j = 0; j < num_slots; ++j)
    segment[j] = start + Slot{j / 4.0, j / 2.0};

  // Reference: real part kept, imaginary part zeroed by the assignment.
  std::vector<Slot> real_only(segment);
  for (std::size_t k = 0; k < real_only.size(); ++k)
    real_only[k] = std::real(real_only[k]);

  helib::Ptxt<helib::CKKS> ptxt(context, segment);
  ptxt = ptxt.real();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), real_only);
}
1299 
// Ptxt::imag() must return a plaintext holding the imaginary part of each
// slot, stored as a real number.
TEST_P(TestPtxtCKKS, imagExtractsImaginaryPart)
{
  using Slot = std::complex<double>;

  // Line segment starting at 1 - i with gradient 2
  const Slot start{1, -1};
  std::vector<Slot> segment(context.ea->size());
  const long num_slots = helib::lsize(segment);
  for (long j = 0; j < num_slots; ++j)
    segment[j] = start + Slot{j / 4.0, j / 2.0};

  // Reference: imaginary part moved into the real part by the assignment.
  std::vector<Slot> imag_only(segment);
  for (std::size_t k = 0; k < imag_only.size(); ++k)
    imag_only[k] = std::imag(imag_only[k]);

  helib::Ptxt<helib::CKKS> ptxt(context, segment);
  ptxt = ptxt.imag();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), imag_only);
}
1318 
// Round trip of a complex-valued plaintext through PubKey::Encrypt and
// SecKey::Decrypt; each slot must come back within post_encryption_epsilon
// (measured on the squared magnitude of the difference).
TEST_P(TestPtxtCKKS, canEncryptAndDecryptComplexPtxtsWithKeys)
{
  // Build the modulus chain and generate the key pair.
  helib::buildModChain(context, 100, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  std::vector<std::complex<double>> values(context.ea->size());
  const long num_slots = helib::lsize(values);
  for (long i = 0; i < num_slots; ++i)
    values[i] = std::complex<double>((i - 3) / 5.0, i + 10.0);

  helib::Ptxt<helib::CKKS> pre_encryption(context, values);
  helib::Ctxt ctxt(public_key);
  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  secret_key.Decrypt(post_decryption, ctxt);

  EXPECT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    const double squared_error =
        std::norm(pre_encryption[i] - post_decryption[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1344 
// Round trip of a real-valued plaintext through PubKey::Encrypt and
// SecKey::Decrypt; real and imaginary parts of each slot must each come back
// within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, canEncryptAndDecryptRealPtxtsWithKeys)
{
  // Build the modulus chain and generate the key pair.
  helib::buildModChain(context, 100, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  std::vector<double> values(context.ea->size());
  const long num_slots = helib::lsize(values);
  for (long i = 0; i < num_slots; ++i)
    values[i] = (i - 3) / 10.0;

  helib::Ptxt<helib::CKKS> pre_encryption(context, values);
  helib::Ctxt ctxt(public_key);
  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  secret_key.Decrypt(post_decryption, ctxt);

  EXPECT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    // Check the two components separately.
    EXPECT_NEAR(pre_encryption[i].real(),
                post_decryption[i].real(),
                post_encryption_epsilon);
    EXPECT_NEAR(pre_encryption[i].imag(),
                post_decryption[i].imag(),
                post_encryption_epsilon);
  }
}
1373 
// Round trip of a complex-valued plaintext where decryption goes through the
// EncryptedArrayCx interface; each slot must come back within
// post_encryption_epsilon (measured on the squared magnitude of the
// difference).
TEST_P(TestPtxtCKKS, canEncryptAndDecryptComplexPtxtsWithEa)
{
  helib::buildModChain(context, 100, 2);
  const helib::EncryptedArrayCx& ea = context.ea->getCx();
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {(i - 3) / 10.0, i + 5.0};
  }
  helib::Ptxt<helib::CKKS> pre_encryption(context, data);
  helib::Ctxt ctxt(public_key);

  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  // Decrypt via the encrypted array, as canEncryptAndDecryptRealPtxtsWithEa
  // does. This used to call secret_key.Decrypt, which made the test an exact
  // duplicate of the "WithKeys" variant and left this path untested for
  // complex data.
  ea.decrypt(ctxt, secret_key, post_decryption);
  EXPECT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    EXPECT_NEAR(std::norm(pre_encryption[i] - post_decryption[i]),
                0,
                post_encryption_epsilon);
  }
}
1399 
// Round trip of a real-valued plaintext where decryption goes through
// EncryptedArrayCx::decrypt; real and imaginary parts of each slot must each
// come back within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, canEncryptAndDecryptRealPtxtsWithEa)
{
  // Build the modulus chain, fetch the CKKS encrypted array and generate the
  // key pair.
  helib::buildModChain(context, 100, 2);
  const helib::EncryptedArrayCx& ea = context.ea->getCx();
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  std::vector<double> values(context.ea->size());
  const long num_slots = helib::lsize(values);
  for (long i = 0; i < num_slots; ++i)
    values[i] = (i - 3) / 10.0;

  helib::Ptxt<helib::CKKS> pre_encryption(context, values);
  helib::Ctxt ctxt(public_key);
  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  ea.decrypt(ctxt, secret_key, post_decryption);

  EXPECT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    // Check the two components separately.
    EXPECT_NEAR(pre_encryption[i].real(),
                post_decryption[i].real(),
                post_encryption_epsilon);
    EXPECT_NEAR(pre_encryption[i].imag(),
                post_decryption[i].imag(),
                post_encryption_epsilon);
  }
}
1429 
// Ctxt += Ptxt must decrypt to the same slots as Ptxt += Ptxt on the
// unencrypted data, within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, plusEqualsWithCiphertextWorks)
{
  using Slot = std::complex<double>;

  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // The augend gets encrypted; the addend stays a plaintext.
  std::vector<Slot> augend_slots(context.ea->size());
  std::vector<Slot> addend_slots(context.ea->size());
  const long num_slots = helib::lsize(augend_slots);
  for (long i = 0; i < num_slots; ++i) {
    augend_slots[i] = Slot(i / 10.0, -i * i / 63.0);
    addend_slots[i] = Slot(-i / 20.0, i * i * i * 2.6);
  }
  helib::Ptxt<helib::CKKS> reference(context, augend_slots);
  helib::Ptxt<helib::CKKS> addend(context, addend_slots);
  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, reference);

  // Same operation on the ciphertext and on the plaintext reference.
  encrypted += addend;
  reference += addend;

  helib::Ptxt<helib::CKKS> decrypted(context);
  secret_key.Decrypt(decrypted, encrypted);
  EXPECT_EQ(decrypted.size(), reference.size());
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    const double squared_error = std::norm(decrypted[i] - reference[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1462 
// Ctxt::addConstantCKKS(Ptxt) must decrypt to the same slots as Ptxt += Ptxt
// on the unencrypted data, within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, addConstantCKKSWithCiphertextWorks)
{
  using Slot = std::complex<double>;

  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // The augend gets encrypted; the addend stays a plaintext.
  std::vector<Slot> augend_slots(context.ea->size());
  std::vector<Slot> addend_slots(context.ea->size());
  const long num_slots = helib::lsize(augend_slots);
  for (long i = 0; i < num_slots; ++i) {
    augend_slots[i] = Slot(i / 70.0, -i * 10.5);
    addend_slots[i] = Slot(-i / 10.0, i * 0.8);
  }
  helib::Ptxt<helib::CKKS> reference(context, augend_slots);
  helib::Ptxt<helib::CKKS> addend(context, addend_slots);
  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, reference);

  // Ciphertext side uses addConstantCKKS, plaintext side uses +=.
  encrypted.addConstantCKKS(addend);
  reference += addend;

  helib::Ptxt<helib::CKKS> decrypted(context);
  secret_key.Decrypt(decrypted, encrypted);
  EXPECT_EQ(decrypted.size(), reference.size());
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    const double squared_error = std::norm(decrypted[i] - reference[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1495 
// Ctxt -= Ptxt must decrypt to the same slots as Ptxt -= Ptxt on the
// unencrypted data, within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, minusEqualsWithCiphertextWorks)
{
  using Slot = std::complex<double>;

  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // The minuend gets encrypted; the subtrahend stays a plaintext.
  std::vector<Slot> minuend_slots(context.ea->size());
  std::vector<Slot> subtrahend_slots(context.ea->size());
  const long num_slots = helib::lsize(minuend_slots);
  for (long i = 0; i < num_slots; ++i) {
    minuend_slots[i] = Slot(i * i / 30.0, i * i / 4.5);
    subtrahend_slots[i] = Slot((i + 3) / 4.0, -i * i / 1.3);
  }
  helib::Ptxt<helib::CKKS> reference(context, minuend_slots);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_slots);
  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, reference);

  // Same operation on the ciphertext and on the plaintext reference.
  encrypted -= subtrahend;
  reference -= subtrahend;

  helib::Ptxt<helib::CKKS> decrypted(context);
  secret_key.Decrypt(decrypted, encrypted);
  EXPECT_EQ(decrypted.size(), reference.size());
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    const double squared_error = std::norm(decrypted[i] - reference[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1528 
// Ctxt::multByConstantCKKS(Ptxt) must decrypt to the same slots as
// Ptxt *= Ptxt on the unencrypted data, within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, multByConstantCKKSFromCiphertextWorks)
{
  using Slot = std::complex<double>;

  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // The multiplier gets encrypted; the multiplicand stays a plaintext.
  std::vector<Slot> multiplier_slots(context.ea->size());
  std::vector<Slot> multiplicand_slots(context.ea->size());
  const long num_slots = helib::lsize(multiplier_slots);
  for (long i = 0; i < num_slots; ++i) {
    multiplier_slots[i] = Slot(i * 4.5, -i * i / 12.5);
    multiplicand_slots[i] = Slot((i - 2.5) / 3.5, i * 4.2);
  }
  helib::Ptxt<helib::CKKS> reference(context, multiplier_slots);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_slots);

  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, reference);

  // Ciphertext side uses multByConstantCKKS, plaintext side uses *=.
  encrypted.multByConstantCKKS(multiplicand);
  reference *= multiplicand;

  helib::Ptxt<helib::CKKS> decrypted(context);
  secret_key.Decrypt(decrypted, encrypted);
  EXPECT_EQ(decrypted.size(), reference.size());
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    const double squared_error = std::norm(decrypted[i] - reference[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1562 
// Ctxt *= Ptxt must decrypt to the same slots as Ptxt *= Ptxt on the
// unencrypted data, within post_encryption_epsilon.
TEST_P(TestPtxtCKKS, timesEqualsFromCiphertextWorks)
{
  using Slot = std::complex<double>;

  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // The multiplier gets encrypted; the multiplicand stays a plaintext.
  std::vector<Slot> multiplier_slots(context.ea->size());
  std::vector<Slot> multiplicand_slots(context.ea->size());
  const long num_slots = helib::lsize(multiplier_slots);
  for (long i = 0; i < num_slots; ++i) {
    multiplier_slots[i] = Slot(i * 4.5, -i * i / 3.3);
    multiplicand_slots[i] = Slot((i - 2.5) / 3.5, i * i / 12.4);
  }
  helib::Ptxt<helib::CKKS> reference(context, multiplier_slots);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_slots);

  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, reference);

  // Same operation on the ciphertext and on the plaintext reference.
  encrypted *= multiplicand;
  reference *= multiplicand;

  helib::Ptxt<helib::CKKS> decrypted(context);
  secret_key.Decrypt(decrypted, encrypted);
  EXPECT_EQ(decrypted.size(), reference.size());
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    const double squared_error = std::norm(decrypted[i] - reference[i]);
    EXPECT_NEAR(squared_error, 0, post_encryption_epsilon);
  }
}
1596 
// Binary operator+ on two plaintexts, assigned into a default-constructed
// plaintext, must equal the slot-wise sum of the raw data.
TEST_P(TestPtxtCKKS, plusOperatorWithOtherPtxtWorks)
{
  using Slot = std::complex<double>;

  std::vector<Slot> augend_slots(context.ea->size());
  std::vector<Slot> addend_slots(context.ea->size());
  const long num_slots = helib::lsize(augend_slots);
  for (long i = 0; i < num_slots; ++i) {
    augend_slots[i] = Slot(i / 10.0, -i * i / 3.0);
    addend_slots[i] = Slot(-i / 20.0, i * i * i * 42.6);
  }

  // Reference sum computed on the raw vectors.
  std::vector<Slot> expected_slots(context.ea->size());
  for (long i = 0; i < num_slots; ++i)
    expected_slots[i] = augend_slots[i] + addend_slots[i];

  helib::Ptxt<helib::CKKS> augend(context, augend_slots);
  helib::Ptxt<helib::CKKS> addend(context, addend_slots);
  helib::Ptxt<helib::CKKS> sum;

  sum = augend + addend;

  COMPARE_CXDOUBLE_VECS(expected_slots, sum);
}
1615 
// Binary operator- on two plaintexts, assigned into a default-constructed
// plaintext, must equal the slot-wise difference of the raw data.
TEST_P(TestPtxtCKKS, minusOperatorWithOtherPtxtWorks)
{
  using Slot = std::complex<double>;

  std::vector<Slot> minuend_slots(context.ea->size());
  std::vector<Slot> subtrahend_slots(context.ea->size());
  const long num_slots = helib::lsize(minuend_slots);
  for (long i = 0; i < num_slots; ++i) {
    minuend_slots[i] = Slot(i / 10.0, -i * i / 3.0);
    subtrahend_slots[i] = Slot(-i / 20.0, i * i * i * 42.6);
  }

  // Reference difference computed on the raw vectors.
  std::vector<Slot> expected_slots(context.ea->size());
  for (long i = 0; i < num_slots; ++i)
    expected_slots[i] = minuend_slots[i] - subtrahend_slots[i];

  helib::Ptxt<helib::CKKS> minuend(context, minuend_slots);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_slots);
  helib::Ptxt<helib::CKKS> diff;

  diff = minuend - subtrahend;

  COMPARE_CXDOUBLE_VECS(expected_slots, diff);
}
1634 
// Binary operator* on two plaintexts, assigned into a default-constructed
// plaintext, must equal the slot-wise product of the raw data.
TEST_P(TestPtxtCKKS, timesOperatorWithOtherPtxtWorks)
{
  using Slot = std::complex<double>;

  std::vector<Slot> multiplier_slots(context.ea->size());
  std::vector<Slot> multiplicand_slots(context.ea->size());
  const long num_slots = helib::lsize(multiplier_slots);
  for (long i = 0; i < num_slots; ++i) {
    multiplier_slots[i] = Slot(i / 10.0, -i * i / 3.0);
    multiplicand_slots[i] = Slot(-i / 20.0, i * i * i * 42.6);
  }

  // Reference product computed on the raw vectors.
  std::vector<Slot> expected_slots(context.ea->size());
  for (long i = 0; i < num_slots; ++i)
    expected_slots[i] = multiplier_slots[i] * multiplicand_slots[i];

  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_slots);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_slots);
  helib::Ptxt<helib::CKKS> product;

  product = multiplier * multiplicand;

  COMPARE_CXDOUBLE_VECS(expected_slots, product);
}
1653 
1654 class TestPtxtBGV : public ::testing::TestWithParam<BGVParameters>
1655 {
1656 protected:
TestPtxtBGV()1657   TestPtxtBGV() :
1658       m(GetParam().m),
1659       p(GetParam().p),
1660       r(GetParam().r),
1661       ppowr(power(p, r)),
1662       context(m, p, r)
1663   {}
1664 
power(long base,unsigned long exponent)1665   static long power(long base, unsigned long exponent)
1666   {
1667     long result = base;
1668     while (--exponent)
1669       result *= base;
1670     return result;
1671   }
1672 
1673   const unsigned long m;
1674   const unsigned long p;
1675   const unsigned long r;
1676   const unsigned long ppowr;
1677 
1678   helib::Context context;
1679 };
1680 
// Smoke test: a BGV plaintext can be built from the fixture's context.
TEST_P(TestPtxtBGV, canBeConstructedWithBGVContext)
{
  helib::Ptxt<helib::BGV> from_context(context);
}
1685 
// Smoke test: a BGV plaintext can be built without any context.
TEST_P(TestPtxtBGV, canBeDefaultConstructed)
{
  helib::Ptxt<helib::BGV> without_context;
}
1687 
// Smoke test: the copy constructor accepts a context-backed plaintext.
TEST_P(TestPtxtBGV, canBeCopyConstructed)
{
  helib::Ptxt<helib::BGV> original(context);
  helib::Ptxt<helib::BGV> duplicate(original);
}
1693 
// Copy assignment between two plaintexts of the same context must compile,
// run, and leave the target equal to the source.
TEST_P(TestPtxtBGV, canBeAssignedFromOtherPtxt)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, data);
  helib::Ptxt<helib::BGV> ptxt2(context);

  // `Ptxt ptxt2 = ptxt;` would be copy-initialisation and only call the copy
  // constructor, so assign to an already-constructed object instead.
  ptxt2 = ptxt;

  EXPECT_EQ(ptxt2, ptxt);
}
1699 
// isValid() is expected to be false for a default-constructed plaintext and
// true for one constructed from a context.
TEST_P(TestPtxtBGV, reportsWhetherItIsValid)
{
  helib::Ptxt<helib::BGV> without_context;
  EXPECT_FALSE(without_context.isValid());

  helib::Ptxt<helib::BGV> with_context(context);
  EXPECT_TRUE(with_context.isValid());
}
1707 
// A plaintext built from a full vector of longs (0, 1, 2, ...) should report
// the same size and hold the same value in every slot.
TEST_P(TestPtxtBGV, preservesLongDataPassedIntoConstructor)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);

  EXPECT_EQ(ptxt.size(), slot_values.size());
  const std::size_t count = slot_values.size();
  for (std::size_t idx = 0; idx < count; ++idx) {
    EXPECT_EQ(ptxt[idx], slot_values[idx]);
  }
}
1718 
// A plaintext built from per-slot coefficient vectors (each slot is {1})
// should report the same size and hold the same value in every slot.
TEST_P(TestPtxtBGV, preservesCoefficientVectorDataPassedIntoConstructor)
{
  std::vector<std::vector<long>> slot_coeffs(context.ea->size());
  for (auto& coeffs : slot_coeffs) {
    coeffs = {1};
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_coeffs);

  EXPECT_EQ(ptxt.size(), slot_coeffs.size());
  const std::size_t count = slot_coeffs.size();
  for (std::size_t idx = 0; idx < count; ++idx) {
    EXPECT_EQ(ptxt[idx], slot_coeffs[idx]);
  }
}
1731 
// A plaintext built from NTL::ZZX slot values (0, 1, 2, ...) should report the
// same size and hold the same value in every slot.
TEST_P(TestPtxtBGV, preservesZzxDataPassedIntoConstructor)
{
  std::vector<NTL::ZZX> slot_polys(context.ea->size());
  std::iota(slot_polys.begin(), slot_polys.end(), 0);

  helib::Ptxt<helib::BGV> ptxt(context, slot_polys);

  EXPECT_EQ(ptxt.size(), slot_polys.size());
  const std::size_t count = slot_polys.size();
  for (std::size_t idx = 0; idx < count; ++idx) {
    EXPECT_EQ(ptxt[idx], slot_polys[idx]);
  }
}
1742 
// operator<< on a plaintext is expected to print "[s0, s1, ..., sn]" where
// each s_i is the stream form of the corresponding PolyMod slot.
TEST_P(TestPtxtBGV, writesDataCorrectlyToOstream)
{
  const long modulus = context.slotRing->p2r;
  const long ord_p = context.zMStar.getOrdP();
  helib::PolyMod blank(context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size(), blank);
  const long last = helib::lsize(slots) - 1;

  std::stringstream expected_stream;
  expected_stream << "[";
  for (long i = 0; i <= last; ++i) {
    NTL::ZZX coeffs;
    NTL::SetCoeff(coeffs, 0, i % modulus);
    // A degree-1 term is only added when the order of p is not 1.
    if (ord_p != 1) {
      NTL::SetCoeff(coeffs, 1, (i + 2) % modulus);
    }
    slots[i] = coeffs;
    // Serialisation of slots[i] (i.e. PolyMod) is tested in `TestPolyMod.cpp`
    expected_stream << slots[i];
    if (i != last) {
      expected_stream << ", ";
    }
  }
  expected_stream << "]";

  helib::Ptxt<helib::BGV> ptxt(context, slots);
  std::ostringstream actual_stream;
  actual_stream << ptxt;

  EXPECT_EQ(actual_stream.str(), expected_stream.str());
}
1769 
// operator>> is expected to fill a plaintext from the "[s0, s1, ...]" text
// form so that every slot matches the PolyMod that was written.
TEST_P(TestPtxtBGV, readsDataCorrectlyFromIstream)
{
  helib::PolyMod blank(context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size(), blank);
  const long slot_count = helib::lsize(slots);
  for (long i = 0; i < slot_count; ++i) {
    slots[i] = {i, i + 2};
  }

  // Hand-build the text form, separating the slots with ", ".
  std::stringstream input;
  input << "[";
  for (long i = 0; i < slot_count; ++i) {
    if (i != 0) {
      input << ", ";
    }
    input << slots[i];
  }
  input << "]";

  helib::Ptxt<helib::BGV> ptxt(context);
  input >> ptxt;

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], slots[idx]);
  }
}
1794 
// Writing a plaintext with operator<< and reading it back with operator>>
// must yield an equal plaintext.
TEST_P(TestPtxtBGV, deserializeIsInverseOfSerialize)
{
  helib::PolyMod poly(context.slotRing);
  std::vector<helib::PolyMod> data(context.ea->size(), poly);
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {i, i + 2};
  }
  // Build the plaintext from the prepared data. Constructing it from the
  // context alone would only round-trip a plaintext that never saw `data`.
  helib::Ptxt<helib::BGV> ptxt(context, data);

  std::stringstream str;
  str << ptxt;

  helib::Ptxt<helib::BGV> deserialized(context);
  str >> deserialized;

  EXPECT_EQ(ptxt, deserialized);
}
1812 
// helib::serialize is expected to emit "[[v0], [v1], ...]" for constant slots,
// with each value 2*i reduced via mcMod against p^r.
TEST_P(TestPtxtBGV, serializeFunctionSerializesCorrectly)
{
  helib::PolyMod blank(context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size(), blank);
  const long slot_count = helib::lsize(slots);

  std::stringstream expected_stream;
  expected_stream << "[";
  for (long i = 0; i < slot_count; ++i) {
    if (i > 0) {
      expected_stream << ", ";
    }
    slots[i] = 2 * i;
    expected_stream << "[" << helib::mcMod(2 * i, ppowr) << "]";
  }
  expected_stream << "]";

  helib::Ptxt<helib::BGV> ptxt(context, slots);
  std::stringstream actual_stream;
  helib::serialize(actual_stream, ptxt);

  EXPECT_EQ(actual_stream.str(), expected_stream.str());
}
1833 
// helib::deserialize is expected to rebuild a plaintext from hand-written text
// "[[c0,c1,...], [c0,c1,...], ...]" where coefficient j of every slot is j*j.
TEST_P(TestPtxtBGV, deserializeFunctionDeserializesCorrectly)
{
  helib::PolyMod blank(context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size(), blank);
  const long slot_count = helib::lsize(slots);

  std::stringstream input;
  input << "[";
  for (long i = 0; i < slot_count; ++i) {
    if (i > 0) {
      input << ", ";
    }
    NTL::ZZX coeffs;
    input << "[";
    // One coefficient per power below the order of p.
    for (long j = 0; j < context.zMStar.getOrdP(); ++j) {
      if (j > 0) {
        input << ",";
      }
      NTL::SetCoeff(coeffs, j, j * j);
      input << j * j;
    }
    slots[i] = coeffs;
    input << "]";
  }
  input << "]";

  helib::Ptxt<helib::BGV> expected_ptxt(context, slots);
  helib::Ptxt<helib::BGV> actual_ptxt(context);
  helib::deserialize(input, actual_ptxt);

  EXPECT_EQ(expected_ptxt, actual_ptxt);
}
1862 
// Feeding helib::deserialize one more element than the context has slots is
// expected to raise helib::IOError.
TEST_P(TestPtxtBGV, deserializeFunctionThrowsIfMoreElementsThanSlots)
{
  helib::PolyMod blank(context.slotRing);
  // Deliberately one element too many.
  std::vector<helib::PolyMod> slots(context.ea->size() + 1, blank);
  const long element_count = helib::lsize(slots);

  std::stringstream input;
  input << "[";
  for (long i = 0; i < element_count; ++i) {
    if (i > 0) {
      input << ", ";
    }
    slots[i] = {i, i * i};
    input << "[" << i << ", " << i * i << "]";
  }
  input << "]";

  helib::Ptxt<helib::BGV> target(context);

  EXPECT_THROW(helib::deserialize(input, target), helib::IOError);
}
1882 
// Reading one more element than the context has slots through operator>> is
// expected to raise helib::IOError.
TEST_P(TestPtxtBGV, rightShiftOperatorThrowsIfMoreElementsThanSlots)
{
  helib::PolyMod blank(context.slotRing);
  // Deliberately one element too many.
  std::vector<helib::PolyMod> slots(context.ea->size() + 1, blank);
  const long element_count = helib::lsize(slots);

  std::stringstream input;
  input << "[";
  for (long i = 0; i < element_count; ++i) {
    if (i > 0) {
      input << ", ";
    }
    slots[i] = {i, i * i};
    input << "[" << i << ", " << i * i << "]";
  }
  input << "]";

  helib::Ptxt<helib::BGV> target(context);
  EXPECT_THROW(input >> target, helib::IOError);
}
1900 
// Three plaintexts written to one stream, one per line, are expected to be
// read back in the same order with successive operator>> calls.
TEST_P(TestPtxtBGV, readsManyPtxtsFromStream)
{
  helib::PolyMod blank(context.slotRing);
  const long slot_count = context.ea->size();
  std::vector<helib::PolyMod> first_slots(slot_count, blank);
  std::vector<helib::PolyMod> second_slots(slot_count, blank);
  std::vector<helib::PolyMod> third_slots(slot_count, blank);
  for (long i = 0; i < helib::lsize(first_slots); ++i) {
    first_slots[i] = {i, i + 2};
    second_slots[i] = {2 * i, 2 * (i + 2)};
    third_slots[i] = {3 * i, 3 * (i + 2)};
  }
  helib::Ptxt<helib::BGV> first(context, first_slots);
  helib::Ptxt<helib::BGV> second(context, second_slots);
  helib::Ptxt<helib::BGV> third(context, third_slots);

  // A newline follows each plaintext; flushing a stringstream is a no-op.
  std::stringstream stream;
  stream << first << '\n';
  stream << second << '\n';
  stream << third << '\n';

  helib::Ptxt<helib::BGV> first_read(context);
  helib::Ptxt<helib::BGV> second_read(context);
  helib::Ptxt<helib::BGV> third_read(context);
  stream >> first_read;
  stream >> second_read;
  stream >> third_read;

  EXPECT_EQ(first, first_read);
  EXPECT_EQ(second, second_read);
  EXPECT_EQ(third, third_read);
}
1932 
// A plaintext built from a full vector of PolyMods should report the same
// size and hold the same value in every slot.
TEST_P(TestPtxtBGV, preservesPolyModDataPassedIntoConstructor)
{
  helib::PolyMod blank(context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size(), blank);

  helib::Ptxt<helib::BGV> ptxt(context, slots);

  EXPECT_EQ(ptxt.size(), slots.size());
  const std::size_t count = slots.size();
  for (std::size_t idx = 0; idx < count; ++idx) {
    EXPECT_EQ(ptxt[idx], slots[idx]);
  }
}
1943 
// Constructing a plaintext from PolyMods whose ring (p^r or G) differs from
// the context's slot ring is expected to raise helib::RuntimeError.
TEST_P(TestPtxtBGV, throwsIfp2rAndGDoNotMatchThoseFromContext)
{
  const NTL::ZZX ring_G = context.slotRing->G;
  const long ring_p = context.slotRing->p;
  const long ring_r = context.slotRing->r;

  // Ring with a non-matching p^r
  std::shared_ptr<helib::PolyModRing> wrong_modulus_ring =
      std::make_shared<helib::PolyModRing>(ring_p + 1, ring_r, ring_G);
  helib::PolyMod wrong_modulus_poly(wrong_modulus_ring);
  // Ring with a non-matching G
  std::shared_ptr<helib::PolyModRing> wrong_G_ring =
      std::make_shared<helib::PolyModRing>(ring_p, ring_r, ring_G + 1);
  helib::PolyMod wrong_G_poly(wrong_G_ring);
  // Ring built from the context's own p, r and G
  std::shared_ptr<helib::PolyModRing> matching_ring =
      std::make_shared<helib::PolyModRing>(ring_p, ring_r, ring_G);
  helib::PolyMod matching_poly(matching_ring);

  std::vector<helib::PolyMod> data(context.ea->size(), matching_poly);

  // A single bad element at the end of an otherwise good vector must still
  // be noticed.
  data.back() = wrong_modulus_poly;
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, data),
               helib::RuntimeError);
  data.back() = wrong_G_poly;
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, data),
               helib::RuntimeError);

  // The single-PolyMod constructor must complain as well.
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, wrong_modulus_poly),
               helib::RuntimeError);
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, wrong_G_poly),
               helib::RuntimeError);
}
1978 
// lsize() must report the number of slots as a signed long.
TEST_P(TestPtxtBGV, lsizeReportsCorrectSize)
{
  std::vector<long> data(context.ea->size());
  helib::Ptxt<helib::BGV> ptxt(context, data);
  // Compare signed with signed: data.size() is unsigned, so use helib::lsize
  // to avoid a mixed-sign comparison inside EXPECT_EQ.
  EXPECT_EQ(ptxt.lsize(), helib::lsize(data));
}
1985 
// size() should match the length of the full-size vector the plaintext was
// built from.
TEST_P(TestPtxtBGV, sizeReportsCorrectSize)
{
  std::vector<long> slot_values(context.ea->size());

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);

  EXPECT_EQ(ptxt.size(), slot_values.size());
}
1992 
// When given one value fewer than there are slots, the plaintext is expected
// to keep the supplied values and hold zero in the remaining slot.
TEST_P(TestPtxtBGV, padsWithZerosWhenPassingInSmallDataVector)
{
  std::vector<long> short_values(context.ea->size() - 1);
  std::iota(short_values.begin(), short_values.end(), 0);

  helib::Ptxt<helib::BGV> ptxt(context, short_values);

  // Supplied prefix, compared against the values reduced via mcMod and p^r.
  const std::size_t supplied = short_values.size();
  for (std::size_t idx = 0; idx < supplied; ++idx) {
    EXPECT_EQ(ptxt[idx], helib::mcMod(short_values[idx], ppowr));
  }
  // Everything after the supplied prefix must be zero.
  for (std::size_t idx = supplied; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], 0l);
  }
}
2005 
// A plaintext built from a context must have as many slots as the context's
// encrypted array reports.
TEST_P(TestPtxtBGV, hasSameNumberOfSlotsAsContext)
{
  helib::Ptxt<helib::BGV> ptxt(context);
  // context.ea->size() is signed while Ptxt::size() is unsigned; cast
  // explicitly so EXPECT_EQ does not perform a mixed-sign comparison.
  EXPECT_EQ(static_cast<std::size_t>(context.ea->size()), ptxt.size());
}
2011 
// random() is called on five copies of a plaintext; the test fails only if
// all five still compare equal to their neighbours afterwards.
TEST_P(TestPtxtBGV, randomSetsDataRandomly)
{
  helib::Ptxt<helib::BGV> seed_ptxt(context);
  seed_ptxt.random();
  std::vector<helib::Ptxt<helib::BGV>> samples(5, seed_ptxt);
  for (std::size_t idx = 0; idx < samples.size(); ++idx) {
    samples[idx].random();
  }

  // Walk neighbouring pairs and stop at the first pair that differs.
  bool all_equal = true;
  for (std::size_t next = 1; all_equal && next < samples.size(); ++next) {
    if (samples[next - 1] != samples[next]) {
      all_equal = false;
    }
  }
  EXPECT_FALSE(all_equal) << "5 random ptxts are all equal - likely that"
                             " random() is not actually randomising!";
}
2029 
// With slots holding 1, 2, 3, ..., runningSums() is expected to leave
// 1 + 2 + ... + (i + 1) in slot i.
TEST_P(TestPtxtBGV, runningSumsWorksCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 1);
  // Prefix sums of 1..n, i.e. (i + 1) * (i + 2) / 2 at index i.
  std::vector<long> prefix_sums(slot_values.size());
  std::partial_sum(slot_values.begin(), slot_values.end(), prefix_sums.begin());

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt.runningSums();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], prefix_sums[idx]);
  }
}
2045 
// With slots holding 1, 2, ..., n, totalSums() is expected to leave
// n * (n + 1) / 2 in every slot.
TEST_P(TestPtxtBGV, totalSumsWorksCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 1);
  const std::size_t n = slot_values.size();
  const long grand_total = (n * (n + 1)) / 2;
  std::vector<long> expected(n, grand_total);

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt.totalSums();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2061 
// With slots holding 1, 2, 3, ..., incrementalProduct() is expected to leave
// the product of slots 0..i (reduced mod p^r) in slot i.
TEST_P(TestPtxtBGV, incrementalProductWorksCorrectly)
{
  std::vector<long> factors(context.ea->size());
  std::iota(factors.begin(), factors.end(), 1);

  const long modulus = context.slotRing->p2r;
  std::vector<long> expected(factors);
  for (std::size_t idx = 1; idx < expected.size(); ++idx) {
    const long previous = expected[idx - 1];
    expected[idx] = (expected[idx] * previous) % modulus;
  }

  helib::Ptxt<helib::BGV> ptxt(context, factors);
  ptxt.incrementalProduct();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2078 
// With slots holding 1, 2, ..., n, totalProduct() is expected to leave the
// product of all slots (reduced mod p^r) in every slot.
TEST_P(TestPtxtBGV, totalProductWorksCorrectly)
{
  std::vector<long> factors(context.ea->size());
  std::iota(factors.begin(), factors.end(), 1);

  // Reduce after every multiplication so the running value stays below p^r.
  const long modulus = context.slotRing->p2r;
  long running = 1;
  for (const long factor : factors) {
    running = (running * factor) % modulus;
  }
  std::vector<long> expected(factors.size(), running);

  helib::Ptxt<helib::BGV> ptxt(context, factors);
  ptxt.totalProduct();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2097 
// innerProduct of four copies of x with four copies of 2x is expected to give
// 4 * (x * 2x) in every slot.
TEST_P(TestPtxtBGV, innerProductWorksCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  helib::Ptxt<helib::BGV> operand(context, slot_values);
  std::vector<helib::Ptxt<helib::BGV>> lhs_vector(4, operand);
  // Double the operand in place before building the right-hand vector.
  operand += operand;
  std::vector<helib::Ptxt<helib::BGV>> rhs_vector(4, operand);

  helib::Ptxt<helib::BGV> result(context);
  innerProduct(result, lhs_vector, rhs_vector);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    const long x = slot_values[idx];
    expected[idx] = 4 * (x * (2 * x));
  }

  for (std::size_t idx = 0; idx < result.size(); ++idx) {
    EXPECT_EQ(result[idx], expected[idx]);
  }
}
2119 
// With slots holding 0, 1, 2, ..., mapTo01 is expected to give 0 where the
// slot index is a multiple of p and 1 everywhere else.
TEST_P(TestPtxtBGV, mapTo01MapsSlotsCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < expected.size(); ++idx) {
    expected[idx] = (idx % p == 0) ? 0 : 1;
  }

  // Should exist as a free function and a member function
  helib::Ptxt<helib::BGV> via_member(context, slot_values);
  helib::Ptxt<helib::BGV> via_free_function(context, slot_values);
  via_member.mapTo01();
  mapTo01(*(context.ea), via_free_function);

  for (std::size_t idx = 0; idx < via_member.size(); ++idx) {
    EXPECT_EQ(via_member[idx], expected[idx]);
    EXPECT_EQ(via_free_function[idx], expected[idx]);
  }
}
2140 
// Non-parameterised test: builds its own context (m = 45, p = 19, r = 1 with
// explicit generators and orders) because the expected slots are hard-coded.
// Slot i starts as i * (X + 1) and automorph(2) is applied.
TEST(TestPtxtBGV, automorphWorksCorrectly)
{
  std::vector<long> gens = {11, 2};
  std::vector<long> ords = {6, 2};
  const helib::Context context(45, 19, 1, gens, ords);

  std::vector<NTL::ZZX> slot_polys(context.ea->size());
  for (std::size_t idx = 0; idx < slot_polys.size(); ++idx) {
    NTL::SetX(slot_polys[idx]);
    slot_polys[idx] += 1;
    slot_polys[idx] *= idx;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_polys);

  // Hard-coded expected slots, listed as coefficient pairs.
  helib::Ptxt<helib::BGV> expected(ptxt);
  expected[0] = {13, 10};
  expected[1] = {0, 0};
  expected[2] = {18, 8};
  expected[3] = {2, 2};
  expected[4] = {12, 18};
  expected[5] = {4, 4};
  expected[6] = {17, 16};
  expected[7] = {6, 6};
  expected[8] = {3, 14};
  expected[9] = {8, 8};
  expected[10] = {8, 12};
  expected[11] = {10, 10};

  ptxt.automorph(2);

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2173 
// Constant slots (0, 1, 2, ...) are expected to be left unchanged by
// frobeniusAutomorph(i) for every i from 0 up to and including the order of p.
TEST_P(TestPtxtBGV, frobeniusAutomorphWithConstantsWorksCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  const std::vector<long> expected(slot_values);
  const helib::Ptxt<helib::BGV> pristine(context, slot_values);

  const long ord_p = context.zMStar.getOrdP();
  for (long i = 0; i <= ord_p; ++i) {
    // Work on a fresh copy for each power.
    auto ptxtUnderTest = pristine;
    ptxtUnderTest.frobeniusAutomorph(i);
    for (std::size_t j = 0; j < ptxtUnderTest.size(); ++j) {
      ASSERT_EQ(ptxtUnderTest[j], expected[j])
          << "i = " << i << " j = " << j << std::endl;
    }
  }
}
2189 
// Non-parameterised test: builds its own context (m = 45, p = 19, r = 1)
// because the expected slots are hard-coded. Slot i starts as i * (X + 1) and
// frobeniusAutomorph(1) is applied.
TEST(TestPtxtBGV, frobeniusAutomorphWithPolynomialsWorksCorrectly)
{
  const helib::Context context(45, 19, 1);

  std::vector<NTL::ZZX> slot_polys(context.ea->size());
  for (std::size_t idx = 0; idx < slot_polys.size(); ++idx) {
    NTL::SetX(slot_polys[idx]);
    slot_polys[idx] += 1;
    slot_polys[idx] *= idx;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_polys);

  // Hard-coded expected slots, listed as coefficient pairs.
  helib::Ptxt<helib::BGV> expected(ptxt);
  expected[0] = {0, 0};
  expected[1] = {12, 18};
  expected[2] = {5, 17};
  expected[3] = {17, 16};
  expected[4] = {10, 15};
  expected[5] = {3, 14};
  expected[6] = {15, 13};
  expected[7] = {8, 12};
  expected[8] = {1, 11};
  expected[9] = {13, 10};
  expected[10] = {6, 9};
  expected[11] = {18, 8};

  ptxt.frobeniusAutomorph(1);

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2220 
// operator*= with another plaintext is expected to multiply slot by slot:
// a plaintext of all 3s times (0, 1, 2, ...).
TEST_P(TestPtxtBGV, timesEqualsOtherPlaintextWorks)
{
  const long slot_count = context.ea->size();
  std::vector<long> lhs_values(slot_count, 3);
  std::vector<long> rhs_values(slot_count);
  std::iota(rhs_values.begin(), rhs_values.end(), 0);

  std::vector<long> expected(lhs_values.size());
  for (std::size_t idx = 0; idx < lhs_values.size(); ++idx) {
    expected[idx] = lhs_values[idx] * rhs_values[idx];
  }

  helib::Ptxt<helib::BGV> lhs(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs(context, rhs_values);

  lhs *= rhs;

  for (std::size_t idx = 0; idx < lhs.size(); ++idx) {
    EXPECT_EQ(lhs[idx], expected[idx]);
  }
}
2241 
// operator-= with another plaintext is expected to subtract slot by slot:
// a plaintext of all 1s minus (0, 1, 2, ...).
TEST_P(TestPtxtBGV, minusEqualsOtherPlaintextWorks)
{
  const long slot_count = context.ea->size();
  std::vector<long> lhs_values(slot_count, 1);
  std::vector<long> rhs_values(slot_count);
  std::iota(rhs_values.begin(), rhs_values.end(), 0);

  std::vector<long> expected(lhs_values.size());
  for (std::size_t idx = 0; idx < rhs_values.size(); ++idx) {
    expected[idx] = lhs_values[idx] - rhs_values[idx];
  }

  helib::Ptxt<helib::BGV> lhs(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs(context, rhs_values);

  lhs -= rhs;

  for (std::size_t idx = 0; idx < lhs.size(); ++idx) {
    EXPECT_EQ(lhs[idx], expected[idx]);
  }
}
2262 
// operator-= with a scalar is expected to subtract it from every slot.
TEST_P(TestPtxtBGV, minusEqualsScalarWorks)
{
  const long scalar = 3;

  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    expected[idx] = slot_values[idx] - scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt -= scalar;

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2281 
// operator+= with another plaintext is expected to add slot by slot:
// (0, 1, 2, ...) plus mcMod(2 * i + 1, p) in slot i.
TEST_P(TestPtxtBGV, plusEqualsOtherPlaintextWorks)
{
  const long slot_count = context.ea->size();
  std::vector<long> lhs_values(slot_count);
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(slot_count);
  std::vector<long> expected(slot_count);
  for (long i = 0; i < helib::lsize(rhs_values); ++i) {
    rhs_values[i] = helib::mcMod(2 * i + 1, p);
    expected[i] = lhs_values[i] + rhs_values[i];
  }

  helib::Ptxt<helib::BGV> lhs(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs(context, rhs_values);
  lhs += rhs;

  for (std::size_t idx = 0; idx < lhs.size(); ++idx) {
    EXPECT_EQ(lhs[idx], expected[idx]);
  }
}
2301 
// operator+= with a scalar is expected to add it to every slot.
TEST_P(TestPtxtBGV, plusEqualsScalarWorks)
{
  const long scalar = 3;

  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    expected[idx] = slot_values[idx] + scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt += scalar;

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2320 
// operator*= with a scalar is expected to multiply every slot by it.
TEST_P(TestPtxtBGV, timesEqualsScalarWorks)
{
  const long scalar = 3;

  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    expected[idx] = slot_values[idx] * scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt *= scalar;

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2339 
// Two plaintexts built from the same context and the same data are expected
// to compare equal with operator==.
TEST_P(TestPtxtBGV, equalityWithOtherPlaintextWorks)
{
  std::vector<long> shared_values(context.ea->size());
  std::iota(shared_values.begin(), shared_values.end(), 0);

  helib::Ptxt<helib::BGV> first(context, shared_values);
  helib::Ptxt<helib::BGV> second(context, shared_values);

  EXPECT_TRUE(first == second);
}
2349 
// Plaintexts built from (0, 1, 2, ...) and (1, 2, 3, ...) are expected to
// compare unequal with operator!=.
TEST_P(TestPtxtBGV, notEqualsOperatorWithOtherPlaintextWorks)
{
  const long slot_count = context.ea->size();
  std::vector<long> from_zero(slot_count);
  std::vector<long> from_one(slot_count);
  std::iota(from_zero.begin(), from_zero.end(), 0);
  std::iota(from_one.begin(), from_one.end(), 1);

  helib::Ptxt<helib::BGV> first(context, from_zero);
  helib::Ptxt<helib::BGV> second(context, from_one);

  EXPECT_TRUE(first != second);
}
2361 
// negate() is expected to replace every slot value x with -x.
TEST_P(TestPtxtBGV, negateNegatesCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    expected[idx] = -slot_values[idx];
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt.negate();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2378 
// addConstant is expected to add its argument to every slot, accept both a
// ZZX polynomial and a long, and be chainable.
TEST_P(TestPtxtBGV, addConstantWorksCorrectly)
{
  // The polynomial X + 2.
  NTL::ZZX addend_poly;
  NTL::SetCoeff(addend_poly, 0, 2);
  NTL::SetCoeff(addend_poly, 1, 1);

  helib::PolyMod base(addend_poly, context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size());
  for (std::size_t i = 0; i < slots.size(); ++i)
    slots[i] = base + i;

  // Expected: each slot plus the polynomial, plus 3.
  std::vector<helib::PolyMod> expected(slots);
  for (auto& slot : expected) {
    slot += addend_poly;
    slot += 3L;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slots);
  ptxt.addConstant(addend_poly).addConstant(3l);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], expected[i]);
  }
}
2401 
// multiplyBy is expected to multiply slot by slot: a plaintext of all 3s
// times (0, 1, 2, ...).
TEST_P(TestPtxtBGV, multiplyByMultipliesCorrectly)
{
  const long slot_count = context.ea->size();
  std::vector<long> lhs_values(slot_count, 3);
  std::vector<long> rhs_values(slot_count);
  std::iota(rhs_values.begin(), rhs_values.end(), 0);

  std::vector<long> expected(lhs_values.size());
  for (std::size_t idx = 0; idx < lhs_values.size(); ++idx) {
    expected[idx] = lhs_values[idx] * rhs_values[idx];
  }

  helib::Ptxt<helib::BGV> lhs(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs(context, rhs_values);

  lhs.multiplyBy(rhs);

  for (std::size_t idx = 0; idx < lhs.size(); ++idx) {
    EXPECT_EQ(lhs[idx], expected[idx]);
  }
}
2422 
// multiplyBy2 is expected to multiply by both arguments slot by slot:
// a plaintext of all 3s times (0, 1, 2, ...) times (0, 1, 2, ...).
TEST_P(TestPtxtBGV, multiplyBy2MultipliesCorrectly)
{
  const long slot_count = context.ea->size();
  std::vector<long> lhs_values(slot_count, 3);
  std::vector<long> first_factor(slot_count);
  std::vector<long> second_factor(slot_count);
  std::iota(first_factor.begin(), first_factor.end(), 0);
  std::iota(second_factor.begin(), second_factor.end(), 0);

  std::vector<long> expected(lhs_values.size());
  for (std::size_t idx = 0; idx < lhs_values.size(); ++idx) {
    expected[idx] = lhs_values[idx] * first_factor[idx] * second_factor[idx];
  }

  helib::Ptxt<helib::BGV> lhs(context, lhs_values);
  helib::Ptxt<helib::BGV> first(context, first_factor);
  helib::Ptxt<helib::BGV> second(context, second_factor);

  lhs.multiplyBy2(first, second);

  for (std::size_t idx = 0; idx < lhs.size(); ++idx) {
    EXPECT_EQ(lhs[idx], expected[idx]);
  }
}
2447 
// square() is expected to replace every slot value x with x * x.
TEST_P(TestPtxtBGV, squareSquaresCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    const long x = slot_values[idx];
    expected[idx] = x * x;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt.square();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2461 
// cube() is expected to replace every slot value x with x * x * x.
TEST_P(TestPtxtBGV, cubeCubesCorrectly)
{
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);

  std::vector<long> expected(slot_values.size());
  for (std::size_t idx = 0; idx < slot_values.size(); ++idx) {
    const long x = slot_values[idx];
    expected[idx] = x * x * x;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slot_values);
  ptxt.cube();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], expected[idx]);
  }
}
2475 
// power(e) must raise every slot to the e-th power (mod p^r) for a range of
// exponents, and must reject an exponent of 0 with helib::InvalidArgument.
TEST_P(TestPtxtBGV, powerCorrectlyRaisesToPowers)
{
  std::vector<long> data(context.ea->size());
  // Start below zero so that negative bases are covered as well.
  std::iota(data.begin(),
            data.end(),
            -(static_cast<long>(context.ea->size()) / 2));
  std::vector<long> exponents{1, 3, 4, 5, 300};

  // Reference implementation: repeated multiplication, reducing with
  // helib::mcMod after every step.
  const auto naive_powermod =
      [](long base, unsigned long exponent, unsigned long mod) {
        if (exponent == 0)
          return 1l;

        long result = base;
        while (--exponent)
          result = helib::mcMod(result * base, mod);
        return result;
      };

  for (const auto& exponent : exponents) {
    std::vector<long> expected_result(data);
    for (auto& num : expected_result)
      num = naive_powermod(num, exponent, ppowr);
    helib::Ptxt<helib::BGV> ptxt(context, data);
    ptxt.power(exponent);
    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_EQ(ptxt[i], expected_result[i]);
    }
  }

  // Make sure raising to 0 throws. This is a BGV test with a BGV context, so
  // the plaintext under test must be a Ptxt<BGV>, not a Ptxt<CKKS>.
  helib::Ptxt<helib::BGV> ptxt(context, data);
  EXPECT_THROW(ptxt.power(0l), helib::InvalidArgument);
}
2510 
TEST_P(TestPtxtBGV, shiftShiftsRightCorrectly)
{
  // Input: slot i holds i. Expected vector starts out all zero.
  std::vector<long> input(context.ea->size());
  std::vector<long> expected(context.ea->size());
  std::iota(input.begin(), input.end(), 0);

  // Remainder that is never negative.
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  // Only slots past index 3 are filled in; slots 0..3 keep their zero.
  for (std::size_t slot = 4; slot < input.size(); ++slot) {
    expected[slot] = wrap(slot - 3, input.size());
  }

  helib::Ptxt<helib::BGV> ptxt(context, input);
  ptxt.shift(3);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], expected[slot]);
  }
}
2531 
TEST_P(TestPtxtBGV, shiftShiftsLeftCorrectly)
{
  // Input: slot i holds i. Expected vector starts out all zero.
  std::vector<long> input(context.ea->size());
  std::vector<long> expected(context.ea->size());
  std::iota(input.begin(), input.end(), 0);

  // Remainder that is never negative.
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  // With more than three slots, all but the last three receive the value
  // three positions further along; everything else keeps its zero.
  if (input.size() > 3) {
    for (std::size_t slot = 0; slot < input.size() - 3; ++slot) {
      expected[slot] = wrap(slot + 3, input.size());
    }
  }

  helib::Ptxt<helib::BGV> ptxt(context, input);
  ptxt.shift(-3);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], expected[slot]);
  }
}
2552 
TEST(TestPtxtBGV, shift1DShiftsRightCorrectly)
{
  const long amount = 1;
  const helib::Context context(45, 19, 1);

  // Reference shift along dimension 0: element i moves to i + 2 * amount and
  // is dropped when that target lies outside [0, 12). The constants 2 and 12
  // are hard-coded for the context built above.
  const auto shift_first_dim = [](long amount, std::vector<long>& slots) {
    std::vector<long> shifted(slots.size(), 0l);
    const long offset = 2 * amount;
    for (long src = 0; src < helib::lsize(slots); ++src) {
      const long dst = src + offset;
      if (dst < 0 || dst >= 12)
        continue;
      shifted[dst] = slots[src];
    }
    slots = std::move(shifted);
  };

  // Reference shift along dimension 1. Supported amounts are 1, 0 and -1;
  // any other amount yields an all-zero vector.
  const auto shift_second_dim = [](long amount, std::vector<long>& slots) {
    const long count = helib::lsize(slots);
    std::vector<long> shifted(slots.size(), 0l);
    if (amount == 1l) {
      // Even-indexed sources move one place up; odd-indexed ones write zero.
      for (long src = 0; src < count - 1; ++src)
        shifted[src + 1] = (src & 1) ? 0 : slots[src];
    } else if (amount == 0l) {
      shifted = slots;
    } else if (amount == -1l) {
      // Odd-indexed sources move one place down; even-indexed ones write zero.
      for (long src = 1; src < count; ++src)
        shifted[src - 1] = (src & 1) ? slots[src] : 0;
    }
    slots = std::move(shifted);
  };

  // Input: slot i holds i.
  std::vector<long> input(context.ea->size());
  std::iota(input.begin(), input.end(), 0);

  // Dimension 0.
  {
    helib::Ptxt<helib::BGV> ptxt(context, input);

    std::vector<long> expected = input;
    shift_first_dim(amount, expected);
    ptxt.shift1D(0, amount);

    for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
      EXPECT_EQ(ptxt[slot], expected[slot]);
    }
  }

  // Dimension 1.
  {
    helib::Ptxt<helib::BGV> ptxt(context, input);

    std::vector<long> expected = input;
    shift_second_dim(amount, expected);
    ptxt.shift1D(1, amount);

    for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
      EXPECT_EQ(ptxt[slot], expected[slot]);
    }
  }
}
2613 
TEST(TestPtxtBGV, shift1DShiftsLeftCorrectly)
{
  const long amount = -1;
  const helib::Context context(45, 19, 1);

  // Reference shift along dimension 0: element i moves to i + 2 * amount and
  // is dropped when that target lies outside [0, 12). The constants 2 and 12
  // are hard-coded for the context built above.
  const auto shift_first_dim = [](long amount, std::vector<long>& slots) {
    std::vector<long> shifted(slots.size(), 0l);
    const long offset = 2 * amount;
    for (long src = 0; src < helib::lsize(slots); ++src) {
      const long dst = src + offset;
      if (dst < 0 || dst >= 12)
        continue;
      shifted[dst] = slots[src];
    }
    slots = std::move(shifted);
  };

  // Reference shift along dimension 1. Supported amounts are 1, 0 and -1;
  // any other amount yields an all-zero vector.
  const auto shift_second_dim = [](long amount, std::vector<long>& slots) {
    const long count = helib::lsize(slots);
    std::vector<long> shifted(slots.size(), 0l);
    if (amount == 1l) {
      // Even-indexed sources move one place up; odd-indexed ones write zero.
      for (long src = 0; src < count - 1; ++src)
        shifted[src + 1] = (src & 1) ? 0 : slots[src];
    } else if (amount == 0l) {
      shifted = slots;
    } else if (amount == -1l) {
      // Odd-indexed sources move one place down; even-indexed ones write zero.
      for (long src = 1; src < count; ++src)
        shifted[src - 1] = (src & 1) ? slots[src] : 0;
    }
    slots = std::move(shifted);
  };

  // Input: slot i holds i.
  std::vector<long> input(context.ea->size());
  std::iota(input.begin(), input.end(), 0);

  // Dimension 0.
  {
    helib::Ptxt<helib::BGV> ptxt(context, input);

    std::vector<long> expected = input;
    shift_first_dim(amount, expected);
    ptxt.shift1D(0, amount);

    for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
      EXPECT_EQ(ptxt[slot], expected[slot]);
    }
  }

  // Dimension 1.
  {
    helib::Ptxt<helib::BGV> ptxt(context, input);

    std::vector<long> expected = input;
    shift_second_dim(amount, expected);
    ptxt.shift1D(1, amount);

    for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
      EXPECT_EQ(ptxt[slot], expected[slot]);
    }
  }
}
2674 
TEST(TestPtxtBGV, rotate1DRotatesCorrectly)
{
  long amount = 1;
  const helib::Context context(45, 19, 1);
  std::vector<long> data(context.ea->size());

  // Reference rotation along dimension 0: after reducing amount with
  // helib::mcMod(amount, 12), element i moves to (i + 2 * amount) % 12.
  // The constants 2 and 12 are hard-coded for the context built above.
  const auto rotate_first_dim = [](long amount, std::vector<long>& data) {
    amount = helib::mcMod(amount, 12);
    std::vector<long> new_data(data);
    for (long i = 0; i < helib::lsize(data); ++i)
      new_data[(i + 2 * amount) % 12] = data[i];
    data = std::move(new_data);
  };
  // Reference rotation along dimension 1: for an odd amount each pair of
  // neighbours (even index, following odd index) is swapped; an even amount
  // leaves the data untouched.
  const auto rotate_second_dim = [](long amount, std::vector<long>& data) {
    std::vector<long> new_data(data);
    for (long i = 0; i < helib::lsize(data); ++i)
      if (amount % 2)
        new_data[i + (i & 1 ? -1 : 1)] = data[i];
    data = std::move(new_data);
  };
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = i;
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);

  // Rotate in first dimension (Good Dimension)
  rotate_first_dim(-amount, data);
  ptxt.rotate1D(0, -amount);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  rotate_first_dim(amount, data);
  ptxt.rotate1D(0, amount);
  // Rotating back and forth gives the original data back
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  // Rotate in second dimension (Bad Dimension)
  rotate_second_dim(-amount, data);
  ptxt.rotate1D(1, -amount);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  rotate_second_dim(amount, data);
  ptxt.rotate1D(1, amount);
  // Rotating back and forth gives the original data back
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }
}
2730 
TEST_P(TestPtxtBGV, rotateRotatesCorrectly)
{
  std::vector<long> input(context.ea->size());
  std::vector<long> expected(context.ea->size());

  // Remainder that is never negative.
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };

  // The expected result has slot i holding i; the input holds the same
  // sequence offset by three positions, wrapping around.
  std::iota(expected.begin(), expected.end(), 0);
  for (int slot = 0; slot < helib::lsize(input); ++slot) {
    input[slot] = wrap(slot - 3, input.size());
  }

  helib::Ptxt<helib::BGV> ptxt(context, input);

  ptxt.rotate(-3);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], expected[slot]);
  }

  // Undoing the rotation must bring the input back.
  ptxt.rotate(3);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], input[slot]);
  }
}
2754 
TEST_P(TestPtxtBGV, replicateReplicatesCorrectly)
{
  // Slot i holds i.
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, slot_values);

  // Replicate the last slot.
  helib::replicate(*context.ea, ptxt, slot_values.size() - 1);

  // Every slot is expected to carry the value that was in the last slot.
  const std::vector<long> expected(context.ea->size(), slot_values.back());
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], expected[slot]);
  }
}
2766 
TEST_P(TestPtxtBGV, replicateAllWorksCorrectly)
{
  // Slot i holds i.
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, slot_values);

  std::vector<helib::Ptxt<helib::BGV>> replicas = ptxt.replicateAll();

  // Replica number pos must hold slot_values[pos] in every one of its slots.
  for (long pos = 0; pos < helib::lsize(slot_values); ++pos) {
    const auto& replica_slots = replicas[pos].getSlotRepr();
    for (std::size_t k = 0; k < replica_slots.size(); ++k) {
      EXPECT_EQ(replica_slots[k], slot_values[pos]);
    }
  }
}
2779 
TEST_P(TestPtxtBGV, clearZeroesAllSlots)
{
  // Start from a plaintext whose slot i holds i.
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, slot_values);

  ptxt.clear();

  // After clearing, each slot must compare equal to zero.
  const std::vector<long> zeroes(context.ea->size(), 0);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], zeroes[slot]);
  }
}
2791 
TEST_P(TestPtxtBGV, defaultConstructedPtxtThrowsWhenOperatedOn)
{
  // Neither plaintext is attached to a context.
  helib::Ptxt<helib::BGV> unset;
  helib::Ptxt<helib::BGV> other_unset;
  helib::PolyMod slot_value;

  // Data access.
  EXPECT_THROW(unset.getSlotRepr(), helib::RuntimeError);
  EXPECT_THROW(unset.setData(slot_value), helib::RuntimeError);
  EXPECT_THROW(unset[0], helib::RuntimeError);

  // Arithmetic with another plaintext or with a scalar.
  EXPECT_THROW(unset *= other_unset, helib::RuntimeError);
  EXPECT_THROW(unset += other_unset, helib::RuntimeError);
  EXPECT_THROW(unset -= other_unset, helib::RuntimeError);
  EXPECT_THROW(unset += 1l, helib::RuntimeError);
  EXPECT_THROW(unset *= 3l, helib::RuntimeError);
  EXPECT_THROW(unset.negate(), helib::RuntimeError);
  EXPECT_THROW(unset.multiplyBy(other_unset), helib::RuntimeError);
  EXPECT_THROW(unset.multiplyBy2(unset, other_unset), helib::RuntimeError);
  EXPECT_THROW(unset.square(), helib::RuntimeError);
  EXPECT_THROW(unset.cube(), helib::RuntimeError);
  EXPECT_THROW(unset.power(4l), helib::RuntimeError);

  // Size query and slot movement.
  EXPECT_THROW(unset.size(), helib::RuntimeError);
  EXPECT_THROW(unset.rotate(1), helib::RuntimeError);
  EXPECT_THROW(unset.rotate1D(0, 1), helib::RuntimeError);
  EXPECT_THROW(unset.shift(1), helib::RuntimeError);

  // lsize() is the one call here expected to raise LogicError.
  EXPECT_THROW(unset.lsize(), helib::LogicError);
}
2817 
TEST_P(TestPtxtBGV, defaultConstructedContextCannotBeRightOperand)
{
  // One plaintext bound to the fixture context, one default-constructed.
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  helib::Ptxt<helib::BGV> populated(context, slot_values);
  helib::Ptxt<helib::BGV> unset;

  // Compound assignment with the default-constructed operand on the right.
  EXPECT_THROW(populated *= unset, helib::RuntimeError);
  EXPECT_THROW(populated += unset, helib::RuntimeError);
  EXPECT_THROW(populated -= unset, helib::RuntimeError);

  // Named multiplication, with the default-constructed operand in each
  // argument position.
  EXPECT_THROW(populated.multiplyBy(unset), helib::RuntimeError);
  EXPECT_THROW(populated.multiplyBy2(unset, populated), helib::RuntimeError);
  EXPECT_THROW(populated.multiplyBy2(populated, unset), helib::RuntimeError);
}
2834 
TEST_P(TestPtxtBGV, cannotOperateBetweenPtxtsWithDifferentContexts)
{
  // Second context: same m and p as the fixture, but twice the r.
  helib::Context other_context(m, p, 2 * r);

  // Both plaintexts are built from the same all-ones data.
  std::vector<long> ones(context.ea->size(), 1);
  helib::Ptxt<helib::BGV> lhs(context, ones);
  helib::Ptxt<helib::BGV> rhs(other_context, ones);

  // Every binary operation mixing the two contexts must be rejected.
  EXPECT_THROW(lhs *= rhs, helib::LogicError);
  EXPECT_THROW(lhs -= rhs, helib::LogicError);
  EXPECT_THROW(lhs += rhs, helib::LogicError);
  EXPECT_THROW(lhs.multiplyBy(rhs), helib::LogicError);
  EXPECT_THROW(lhs.multiplyBy2(lhs, rhs), helib::LogicError);
}
2847 
TEST_P(TestPtxtBGV, preservesDataPassedAsZZX)
{
  // Build the polynomial x + 1 (coefficients 0 and 1 both set to 1).
  NTL::ZZX x_plus_one;
  SetCoeff(x_plus_one, 0, 1);
  SetCoeff(x_plus_one, 1, 1);

  helib::Ptxt<helib::BGV> ptxt(context, x_plus_one);

  // Each slot must compare equal to the polynomial that was passed in.
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], x_plus_one);
  }
}
2861 
TEST_P(TestPtxtBGV, setDataWorksWithZZXSameOrderAsPhiMX)
{
  // Obtain the polynomial returned by getPhimXMod() as an NTL::ZZX; how it is
  // converted depends on the tag of the context's algebra.
  NTL::ZZX phi_mx;
  const auto tag = context.alMod.getTag();
  if (tag == helib::PA_GF2_tag) {
    phi_mx = NTL::conv<NTL::ZZX>(
        context.alMod.getDerived(helib::PA_GF2()).getPhimXMod());
  } else if (tag == helib::PA_zz_p_tag) {
    helib::convert(phi_mx,
                   context.alMod.getDerived(helib::PA_zz_p()).getPhimXMod());
  } else if (tag == helib::PA_cx_tag) {
    // CKKS: nothing to fetch, phi_mx stays as constructed.
  } else {
    throw helib::LogicError("No valid tag found in EncryptedArray");
  }

  // The data under test is phi_mx + 1.
  NTL::ZZX phi_plus_one(phi_mx + 1);

  helib::Ptxt<helib::BGV> ptxt(context, phi_plus_one);

  // Every slot is expected to compare equal to 1.
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], 1l);
  }
}
2889 
TEST_P(TestPtxtBGV, decodeSetDataWorks)
{
  // Build the polynomial x + 1 (coefficients 0 and 1 both set to 1).
  NTL::ZZX x_plus_one;
  SetCoeff(x_plus_one, 0, 1);
  SetCoeff(x_plus_one, 1, 1);

  // Reference: decode the polynomial directly through the encrypted array.
  std::vector<NTL::ZZX> reference_slots;
  context.ea->decode(reference_slots, x_plus_one);

  // Under test: let the plaintext decode the same polynomial itself.
  helib::Ptxt<helib::BGV> ptxt(context);
  ptxt.decodeSetData(x_plus_one);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    EXPECT_EQ(ptxt[slot], reference_slots[slot]);
  }
}
2908 
TEST_P(TestPtxtBGV, decodeSetDataWorksWithZZXSameOrderAsPhiMX)
{
  // Obtain the polynomial returned by getPhimXMod() as an NTL::ZZX; how it is
  // converted depends on the tag of the context's algebra.
  NTL::ZZX phi_mx;
  switch (context.alMod.getTag()) {
  case helib::PA_GF2_tag:
    phi_mx = NTL::conv<NTL::ZZX>(
        context.alMod.getDerived(helib::PA_GF2()).getPhimXMod());
    break;
  case helib::PA_zz_p_tag:
    helib::convert(phi_mx,
                   context.alMod.getDerived(helib::PA_zz_p()).getPhimXMod());
    break;
  case helib::PA_cx_tag:
    // CKKS: do nothing
    break;
  default:
    throw helib::LogicError("No valid tag found in EncryptedArray");
  }
  // Put phi_mx + 1 as data
  NTL::ZZX input_polynomial(phi_mx + 1);

  // Reference: decode the polynomial directly through the encrypted array.
  std::vector<NTL::ZZX> test_decoded;
  context.ea->decode(test_decoded, input_polynomial);

  // Exercise decodeSetData itself, as in decodeSetDataWorks. Constructing the
  // Ptxt from the polynomial would bypass decodeSetData entirely and leave
  // the function named by this test uncovered.
  helib::Ptxt<helib::BGV> ptxt(context);
  ptxt.decodeSetData(input_polynomial);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], test_decoded[i]);
  }
}
2939 
TEST_P(TestPtxtBGV, canEncryptAndDecryptPtxts)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Plaintext whose slot i holds i.
  std::vector<long> slot_values(context.ea->size());
  std::iota(slot_values.begin(), slot_values.end(), 0);
  helib::Ptxt<helib::BGV> original(context, slot_values);

  // Encrypt, then decrypt into a fresh plaintext.
  helib::Ctxt encrypted(public_key);
  public_key.Encrypt(encrypted, original);
  helib::Ptxt<helib::BGV> recovered(context);
  secret_key.Decrypt(recovered, encrypted);

  // The round trip must reproduce every slot.
  for (std::size_t slot = 0; slot < original.size(); ++slot) {
    EXPECT_EQ(original[slot], recovered[slot]);
  }
}
2958 
TEST_P(TestPtxtBGV, plusEqualsWithCiphertextWorks)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Left operand counts up from 0 and gets encrypted; right operand counts up
  // from 7 and stays a plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(context.ea->size());
  std::iota(rhs_values.begin(), rhs_values.end(), 7);

  helib::Ptxt<helib::BGV> lhs_plain(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs_plain(context, rhs_values);
  helib::Ctxt lhs_encrypted(public_key);
  public_key.Encrypt(lhs_encrypted, lhs_plain);

  // Perform the same addition on the ciphertext and on the plaintext.
  lhs_encrypted += rhs_plain;
  lhs_plain += rhs_plain;

  // Decrypting the ciphertext must give the plaintext-side result.
  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, lhs_encrypted);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], lhs_plain[slot]);
  }
}
2986 
TEST_P(TestPtxtBGV, addConstantFromCiphertextWorks)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Left operand counts up from 0 and gets encrypted; right operand counts up
  // from 7 and stays a plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(context.ea->size());
  std::iota(rhs_values.begin(), rhs_values.end(), 7);

  helib::Ptxt<helib::BGV> lhs_plain(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs_plain(context, rhs_values);
  helib::Ctxt lhs_encrypted(public_key);
  public_key.Encrypt(lhs_encrypted, lhs_plain);

  // Ciphertext side goes through addConstant; plaintext side uses +=.
  lhs_encrypted.addConstant(rhs_plain);
  lhs_plain += rhs_plain;

  // Decrypting the ciphertext must give the plaintext-side result.
  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, lhs_encrypted);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], lhs_plain[slot]);
  }
}
3014 
TEST_P(TestPtxtBGV, minusEqualsWithCiphertextWorks)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Left operand counts up from 0 and gets encrypted; right operand counts up
  // from 7 and stays a plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(context.ea->size());
  std::iota(rhs_values.begin(), rhs_values.end(), 7);

  helib::Ptxt<helib::BGV> lhs_plain(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs_plain(context, rhs_values);
  helib::Ctxt lhs_encrypted(public_key);
  public_key.Encrypt(lhs_encrypted, lhs_plain);

  // Perform the same subtraction on the ciphertext and on the plaintext.
  lhs_encrypted -= rhs_plain;
  lhs_plain -= rhs_plain;

  // Decrypting the ciphertext must give the plaintext-side result.
  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, lhs_encrypted);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], lhs_plain[slot]);
  }
}
3042 
TEST_P(TestPtxtBGV, timesEqualsWithCiphertextWorks)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Left operand counts up from 0 and gets encrypted; right operand counts up
  // from 7 and stays a plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(context.ea->size());
  std::iota(rhs_values.begin(), rhs_values.end(), 7);

  helib::Ptxt<helib::BGV> lhs_plain(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs_plain(context, rhs_values);
  helib::Ctxt lhs_encrypted(public_key);
  public_key.Encrypt(lhs_encrypted, lhs_plain);

  // Perform the same multiplication on the ciphertext and on the plaintext.
  lhs_encrypted *= rhs_plain;
  lhs_plain *= rhs_plain;

  // Decrypting the ciphertext must give the plaintext-side result.
  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, lhs_encrypted);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], lhs_plain[slot]);
  }
}
3070 
TEST_P(TestPtxtBGV, multByConstantFromCiphertextWorks)
{
  // Key setup: build the modulus chain, generate a secret key, and bind a
  // PubKey reference to it.
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  // Left operand counts up from 0 and gets encrypted; right operand counts up
  // from 7 and stays a plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::vector<long> rhs_values(context.ea->size());
  std::iota(rhs_values.begin(), rhs_values.end(), 7);

  helib::Ptxt<helib::BGV> lhs_plain(context, lhs_values);
  helib::Ptxt<helib::BGV> rhs_plain(context, rhs_values);
  helib::Ctxt lhs_encrypted(public_key);
  public_key.Encrypt(lhs_encrypted, lhs_plain);

  // Ciphertext side goes through multByConstant; plaintext side uses *=.
  lhs_encrypted.multByConstant(rhs_plain);
  lhs_plain *= rhs_plain;

  // Decrypting the ciphertext must give the plaintext-side result.
  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, lhs_encrypted);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], lhs_plain[slot]);
  }
}
3098 
3099 // Useful for testing non-power of 2 for CKKS
3100 // INSTANTIATE_TEST_SUITE_P(various_parameters, TestPtxtCKKS,
3101 //        ::testing::Values( 17, 168, 126, 78, 33, 50, 64)
3102 // );
3103 
3104 INSTANTIATE_TEST_SUITE_P(
3105     various_Parameters,
3106     TestPtxtCKKS,
3107     ::testing::Values(2 << 1, 2 << 2, 2 << 3, 2 << 4, 2 << 5, 2 << 6, 2 << 7));
3108 
// Run the BGV suite once per parameter set listed below. The arguments of
// each BGVParameters(...) are given in the order (m, p, r), matching the
// BGVParameters constructor.
INSTANTIATE_TEST_SUITE_P(
    various_Parameters,
    TestPtxtBGV,
    ::testing::Values(BGVParameters(17, 2, 1),
                      BGVParameters(17, 2, 3),
                      BGVParameters(168, 13, 1),
                      BGVParameters(126, 127, 1),
                      BGVParameters(78, 79, 1),
                      BGVParameters(33, 19, 2),
                      // NOTE: This was used because it has 3 good dimensions
                      // BGVParameters(10005, 37, 1),
                      BGVParameters(50, 53, 1)));
3121 
3122 } // namespace
3123