1 /* Copyright (C) 2019-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12
13 #include <numeric>
14
15 #include <helib/Ptxt.h>
16 #include <helib/helib.h>
17 #include <helib/replicate.h>
18 #include <helib/NumbTh.h>
19
20 #include "test_common.h"
21 #include "gtest/gtest.h"
22
23 namespace {
24
// Bundle of the m, p and r values used to describe a BGV parameter set.
// Streams as "{m = <m>, p = <p>, r = <r>}".
struct BGVParameters
{
  BGVParameters(unsigned m, unsigned p, unsigned r) : m{m}, p{p}, r{r} {}

  const unsigned m;
  const unsigned p;
  const unsigned r;

  // Writes the three values in brace-delimited "name = value" form.
  friend std::ostream& operator<<(std::ostream& os, const BGVParameters& params)
  {
    os << "{";
    os << "m = " << params.m << ", ";
    os << "p = " << params.p << ", ";
    os << "r = " << params.r;
    return os << "}";
  }
};
41
42 class TestPtxtCKKS : public ::testing::TestWithParam<unsigned>
43 {
44 protected:
TestPtxtCKKS()45 TestPtxtCKKS() :
46 // Only relevant parameter is m for a CKKS plaintext
47 m(GetParam()),
48 context(m, -1, 40)
49 // VJS_NOTE: I changed r=50 to r=40.
50 // I find setting r so large can cause problems,
51 // and the test was not passing.
52 // This may be somethng that needs further investigation,
53 // but later...
54 // This probably has something to do with the slightly
55 // different logic in the new encoding functions
56 {}
57
58 const unsigned long m;
59
60 helib::Context context;
61
62 const double pre_encryption_epsilon = 1E-8;
63 const double post_encryption_epsilon = 1E-3;
64 };
65
// Building a CKKS plaintext from the fixture's context must compile and run.
TEST_P(TestPtxtCKKS, canBeConstructedWithCKKSContext)
{
  helib::Ptxt<helib::CKKS> from_context(context);
}
70
// A CKKS plaintext can be created without any context.
TEST_P(TestPtxtCKKS, canBeDefaultConstructed)
{
  helib::Ptxt<helib::CKKS> defaulted;
}
72
// Copy construction from a context-bound plaintext.
TEST_P(TestPtxtCKKS, canBeCopyConstructed)
{
  helib::Ptxt<helib::CKKS> source(context);
  helib::Ptxt<helib::CKKS> duplicate(source);
}
78
// Copy assignment (operator=) from another, context-bound plaintext.
TEST_P(TestPtxtCKKS, canBeAssignedFromOtherPtxt)
{
  helib::Ptxt<helib::CKKS> ptxt(context);
  // Declare first and assign afterwards. `Ptxt ptxt2 = ptxt;` would be copy
  // initialisation, which calls the copy constructor rather than operator=.
  helib::Ptxt<helib::CKKS> ptxt2;
  ptxt2 = ptxt;
  // The assigned-to plaintext should now be bound to a context as well.
  EXPECT_TRUE(ptxt2.isValid());
}
84
// isValid() is false for a default-constructed plaintext and true for one
// built from a context.
TEST_P(TestPtxtCKKS, reportsWhetherItIsValid)
{
  helib::Ptxt<helib::CKKS> without_context;
  helib::Ptxt<helib::CKKS> with_context(context);
  EXPECT_FALSE(without_context.isValid());
  EXPECT_TRUE(with_context.isValid());
}
92
// A plaintext built from a context has exactly as many slots as the context.
TEST_P(TestPtxtCKKS, hasSameNumberOfSlotsAsContext)
{
  const auto expected_slots = context.ea->size();
  helib::Ptxt<helib::CKKS> ptxt(context);
  EXPECT_EQ(ptxt.size(), expected_slots);
}
98
// Full-length complex input is stored unchanged in the slots.
TEST_P(TestPtxtCKKS, preservesDataPassedIntoConstructor)
{
  std::vector<std::complex<double>> values(context.ea->size());
  std::size_t index = 0;
  for (auto& value : values) {
    value = index / 10.0;
    ++index;
  }
  helib::Ptxt<helib::CKKS> ptxt(context, values);

  COMPARE_CXDOUBLE_VECS(ptxt, values);
}
108
// Input shorter than the slot count still yields a full-size plaintext.
TEST_P(TestPtxtCKKS, hasSameNumberOfSlotsAsContextWhenCreatedWithData)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> short_input(slot_count - 1);
  helib::Ptxt<helib::CKKS> ptxt(context, short_input);
  EXPECT_EQ(ptxt.size(), slot_count);
}
115
// A single std::complex scalar is copied into every slot.
TEST_P(TestPtxtCKKS, replicateValueWhenPassingASingleSlotTypeNumber)
{
  const double part = 1. / 10.0;
  std::complex<double> num = {part, part};

  helib::Ptxt<helib::CKKS> ptxt(context, num);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    const std::complex<double> actual = ptxt[slot];
    EXPECT_DOUBLE_EQ(actual.real(), num.real());
    EXPECT_DOUBLE_EQ(actual.imag(), num.imag());
  }
}
126
// A real (double) scalar is copied into every slot with a zero imaginary part.
TEST_P(TestPtxtCKKS, replicateValueWhenPassingASingleNonSlotTypeNumber)
{
  const double num = 1. / 10.0;

  helib::Ptxt<helib::CKKS> ptxt(context, num);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot) {
    const std::complex<double> actual = ptxt[slot];
    EXPECT_DOUBLE_EQ(actual.real(), num);
    EXPECT_DOUBLE_EQ(actual.imag(), 0.0);
  }
}
137
// at() throws OutOfRangeError for negative or past-the-end indices and returns
// the stored slot value otherwise.
TEST_P(TestPtxtCKKS, atMethodThrowsOrReturnsCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i)
    data[i] = i / 10.0;
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  // The five indices just below zero.
  for (long negative = -5; negative != 0; ++negative) {
    EXPECT_THROW(ptxt.at(negative), helib::OutOfRangeError);
  }

  // Every valid index.
  for (long idx = 0; idx < helib::lsize(data); ++idx) {
    EXPECT_DOUBLE_EQ(ptxt.at(idx).real(), data.at(idx).real());
    EXPECT_DOUBLE_EQ(ptxt.at(idx).imag(), data.at(idx).imag());
  }

  // The five indices just past the end.
  const std::size_t past_end = data.size();
  for (std::size_t idx = past_end; idx < past_end + 5; ++idx) {
    EXPECT_THROW(ptxt.at(idx), helib::OutOfRangeError);
  }
}
156
// Short complex input fills the leading slots; the remaining slots are zero.
TEST_P(TestPtxtCKKS, padsWithZerosWhenPassingInSmallDataVector)
{
  std::vector<std::complex<double>> data(context.ea->size() - 1);
  for (long i = 0; i < helib::lsize(data); ++i) {
    const double value = (i - 1) / 10.0;
    data[i] = {value, value};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  const std::size_t filled = data.size();
  // Slots covered by the input keep their values...
  for (std::size_t i = 0; i < filled; ++i) {
    EXPECT_DOUBLE_EQ(ptxt[i].real(), data[i].real());
    EXPECT_DOUBLE_EQ(ptxt[i].imag(), data[i].imag());
  }
  // ...and the tail is zero-padded.
  for (std::size_t i = filled; i < ptxt.size(); ++i) {
    EXPECT_DOUBLE_EQ(ptxt[i].real(), 0.0);
    EXPECT_DOUBLE_EQ(ptxt[i].imag(), 0.0);
  }
}
173
// Short real (double) input becomes real-valued leading slots; the remaining
// slots are zero.
TEST_P(TestPtxtCKKS, preservesDataPassedIntoConstructorAsDouble)
{
  std::vector<double> data(context.ea->size() - 1);
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = (i - 1) / 10.0;
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  const std::size_t filled = data.size();
  for (std::size_t i = 0; i < filled; ++i) {
    EXPECT_DOUBLE_EQ(ptxt[i].real(), data[i]);
    EXPECT_DOUBLE_EQ(ptxt[i].imag(), 0.0);
  }
  for (std::size_t i = filled; i < ptxt.size(); ++i) {
    EXPECT_DOUBLE_EQ(ptxt[i].real(), 0.0);
    EXPECT_DOUBLE_EQ(ptxt[i].imag(), 0.0);
  }
}
190
// operator<< prints "[[re, im], [re, im], ...]" using digits10 precision.
TEST_P(TestPtxtCKKS, writesDataCorrectlyToOstream)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::stringstream expected_stream;
  expected_stream << std::setprecision(std::numeric_limits<double>::digits10);
  expected_stream << "[";
  const char* separator = "";
  for (const auto& value : data) {
    expected_stream << separator << "[" << value.real() << ", " << value.imag()
                    << "]";
    separator = ", ";
  }
  expected_stream << "]";

  std::ostringstream actual_stream;
  actual_stream << ptxt;

  EXPECT_EQ(actual_stream.str(), expected_stream.str());
}
214
// operator>> parses a list of helib::serialize-formatted complex numbers.
TEST_P(TestPtxtCKKS, readsDataCorrectlyFromIstream)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  for (std::size_t i = 0; i < data.size(); ++i) {
    if (i > 0)
      text << ", ";
    helib::serialize(text, data[i]);
  }
  text << "]";

  std::istringstream input(text.str());
  input >> ptxt;

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_NEAR(std::abs(ptxt[i] - data[i]), 0, pre_encryption_epsilon);
  }
}
239
// operator>> accepts hand-written "[[re, im], ...]" text.
TEST_P(TestPtxtCKKS, readsSquareBracketsDataCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  bool first = true;
  for (const auto& value : data) {
    if (!first)
      text << ", ";
    first = false;
    text << "[" << value.real() << ", " << value.imag() << "]";
  }
  text << "]";

  std::istringstream input(text.str());
  input >> ptxt;

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_NEAR(std::abs(ptxt[i] - data[i]), 0, pre_encryption_epsilon);
  }
}
264
// helib::serialize writes a std::complex as "[re, im]".
TEST_P(TestPtxtCKKS, serializeFunctionSerializesStdComplexCorrectly)
{
  // TODO: This test may be removed from the fixture and put as standalone
  // Serialises one value into a fresh stream and returns the text.
  const auto serialized = [](const std::complex<double>& value) {
    std::stringstream out;
    helib::serialize(out, value);
    return out.str();
  };

  EXPECT_EQ(serialized(0.0), "[0, 0]");
  EXPECT_EQ(serialized(10.3), "[10.3, 0]");
  EXPECT_EQ(serialized({0, 10.3}), "[0, 10.3]");
  EXPECT_EQ(serialized({3.3, 16.6}), "[3.3, 16.6]");
}
291
// helib::deserialize reads "[re, im]" (or a shorter list) into a std::complex,
// and throws IOError when given more than two components.
TEST_P(TestPtxtCKKS, deserializeFunctionDeserializesStdComplexCorrectly)
{
  // TODO: This test may be removed from the fixture and put as standalone
  // Each case parses from its own fresh stream. The previous version reused
  // one stringstream and reset only its buffer (str("")), not its state
  // flags. An eof/fail bit left by one case would then silently break every
  // later write and read.
  const auto deserializeFrom = [](const std::string& text) {
    std::stringstream ss(text);
    std::complex<double> num = 0.0;
    helib::deserialize(ss, num);
    return num;
  };

  {
    std::stringstream ss("[1,2,3]");
    std::complex<double> num = 0.0;
    EXPECT_THROW(helib::deserialize(ss, num), helib::IOError);
  }

  EXPECT_NEAR(std::abs(deserializeFrom("[]") - std::complex<double>{0.0, 0.0}),
              0.0,
              pre_encryption_epsilon);
  EXPECT_NEAR(
      std::abs(deserializeFrom("[0.0]") - std::complex<double>{0.0, 0.0}),
      0.0,
      pre_encryption_epsilon);
  EXPECT_NEAR(
      std::abs(deserializeFrom("[0.0,0.0]") - std::complex<double>{0.0, 0.0}),
      0.0,
      pre_encryption_epsilon);
  EXPECT_NEAR(
      std::abs(deserializeFrom("[5.3,0]") - std::complex<double>{5.3, 0.0}),
      0.0,
      pre_encryption_epsilon);
  EXPECT_NEAR(
      std::abs(deserializeFrom("[0,8.16]") - std::complex<double>{0.0, 8.16}),
      0.0,
      pre_encryption_epsilon);
  EXPECT_NEAR(
      std::abs(deserializeFrom("[3.4,9.99]") - std::complex<double>{3.4, 9.99}),
      0.0,
      pre_encryption_epsilon);
}
346
// helib::serialize on a plaintext produces "[[re, im], ...]" with digits10
// precision.
TEST_P(TestPtxtCKKS, serializeFunctionSerializesCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  std::stringstream expected;
  expected << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  for (std::size_t i = 0; i < data.size(); ++i) {
    if (i != 0)
      expected << ", ";
    expected << "[" << data[i].real() << ", " << data[i].imag() << "]";
  }
  expected << "]";

  std::stringstream serialized_ptxt;
  helib::serialize(serialized_ptxt, ptxt);

  EXPECT_EQ(serialized_ptxt.str(), expected.str());
}
369
// Reading "[[re, im], ...]" text through operator>> recovers the data.
TEST_P(TestPtxtCKKS, deserializeWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  const char* separator = "";
  for (const auto& value : data) {
    text << separator << "[" << value.real() << ", " << value.imag() << "]";
    separator = ", ";
  }
  text << "]";

  std::istringstream input(text.str());
  input >> ptxt;

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_NEAR(std::abs(ptxt[i] - data[i]), 0, pre_encryption_epsilon);
  }
}
394
// helib::deserialize rejects input holding one element more than there are
// slots.
TEST_P(TestPtxtCKKS, deserializeFunctionThrowsIfMoreElementsThanSlots)
{
  std::vector<std::complex<double>> oversized(context.ea->size() + 1);
  for (std::size_t i = 0; i < oversized.size(); ++i) {
    oversized[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  for (std::size_t i = 0; i < oversized.size(); ++i) {
    if (i > 0)
      text << ", ";
    text << "[" << oversized[i].real() << ", " << oversized[i].imag() << "]";
  }
  text << "]";

  std::istringstream input(text.str());
  EXPECT_THROW(helib::deserialize(input, ptxt), helib::IOError);
}
415
// operator>> rejects input holding one element more than there are slots.
TEST_P(TestPtxtCKKS, rightShiftOperatorThrowsIfMoreElementsThanSlots)
{
  std::vector<std::complex<double>> oversized(context.ea->size() + 1);
  for (std::size_t i = 0; i < oversized.size(); ++i) {
    oversized[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> ptxt(context);

  std::stringstream text;
  text << std::setprecision(std::numeric_limits<double>::digits10) << "[";
  const char* separator = "";
  for (const auto& value : oversized) {
    text << separator << "[" << value.real() << ", " << value.imag() << "]";
    separator = ", ";
  }
  text << "]";

  std::istringstream input(text.str());
  EXPECT_THROW(input >> ptxt, helib::IOError);
}
436
// Writing a plaintext with operator<< and reading it back with operator>>
// round-trips the slot values.
TEST_P(TestPtxtCKKS, deserializeIsInverseOfSerialize)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i * i) / 10.0, (i * i * i) / 7.5};
  }
  helib::Ptxt<helib::CKKS> original(context, data);

  std::stringstream buffer;
  buffer << original;

  helib::Ptxt<helib::CKKS> round_tripped(context);
  buffer >> round_tripped;

  for (std::size_t i = 0; i < original.size(); ++i) {
    EXPECT_NEAR(std::abs(original[i] - round_tripped[i]),
                0,
                pre_encryption_epsilon);
  }
}
455
// Three plaintexts written back-to-back (newline separated) are read back
// in order from the same stream.
TEST_P(TestPtxtCKKS, readsManyPtxtsFromStream)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> data1(slot_count);
  std::vector<std::complex<double>> data2(slot_count);
  std::vector<std::complex<double>> data3(slot_count);
  for (std::size_t i = 0; i < data1.size(); ++i) {
    const std::complex<double> value = {(i * i) / 10.0, (i * i * i) / 7.5};
    data1[i] = value;
    data2[i] = value * 2.0;
    data3[i] = value * 3.5;
  }

  std::vector<helib::Ptxt<helib::CKKS>> written;
  written.emplace_back(context, data1);
  written.emplace_back(context, data2);
  written.emplace_back(context, data3);

  std::stringstream ss;
  for (const auto& ptxt : written)
    ss << ptxt << std::endl;

  std::vector<helib::Ptxt<helib::CKKS>> read(written.size(),
                                             helib::Ptxt<helib::CKKS>(context));
  for (auto& ptxt : read)
    ss >> ptxt;

  for (std::size_t i = 0; i < written[0].size(); ++i) {
    for (std::size_t k = 0; k < written.size(); ++k) {
      EXPECT_NEAR(std::abs(written[k][i] - read[k][i]),
                  0,
                  pre_encryption_epsilon);
    }
  }
}
494
// getSlotRepr() returns the input data followed by zero padding.
TEST_P(TestPtxtCKKS, getSlotReprReturnsData)
{
  std::vector<std::complex<double>> data(context.ea->size() - 1);
  for (long i = 0; i < helib::lsize(data); ++i) {
    const double value = (i - 1) / 10.0;
    data[i] = {value, value};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  // Start from all-zero slots and copy the input over the leading ones.
  std::vector<std::complex<double>> expected_repr(context.ea->size());
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    if (i < data.size())
      expected_repr[i] = data[i];
    else
      expected_repr[i] = 0;
  }

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_repr);
}
508
// runningSums() replaces slot k with the sum of slots 0..k. Slot k holds
// {k, k^2}, so the expected prefix sums follow the closed forms below.
TEST_P(TestPtxtCKKS, runningSumsWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = {k / 1.0, (k * k) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.runningSums();

  std::vector<std::complex<double>> expected_result(context.ea->size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const double sum_of_k = (k * (k + 1)) / 2.0;
    const double sum_of_k_squared = (k * (k + 1) * (2 * k + 1)) / 6.0;
    expected_result[k] = {sum_of_k, sum_of_k_squared};
  }

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
526
// totalSums() writes the sum of all slots into every slot. With slot i set to
// {i, i^2}, that sum is {sum i, sum i^2} over i = 0..n-1.
TEST_P(TestPtxtCKKS, totalSumsWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {i / 1.0, (i * i) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.totalSums();

  const std::size_t n = data.size();
  const std::complex<double> total = {
      ((n - 1) * n) / 2.0, ((n - 1) * n * (2 * (n - 1) + 1)) / 6.0};
  std::vector<std::complex<double>> expected_result(context.ea->size(), total);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
545
// incrementalProduct() replaces slot i with the product of slots 0..i.
TEST_P(TestPtxtCKKS, incrementalProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i + 1) / 5.0, (i * i + 1) / 10.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.incrementalProduct();

  // Prefix products, each slot multiplied by the previous prefix.
  std::vector<std::complex<double>> expected_result(data);
  for (std::size_t i = 1; i < expected_result.size(); ++i) {
    const std::complex<double> previous = expected_result[i - 1];
    expected_result[i] *= previous;
  }

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
562
// totalProduct() writes the product of all slots into every slot.
TEST_P(TestPtxtCKKS, totalProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {(i + 1) / 10.0, (i * i + 1) / 10.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.totalProduct();

  std::complex<double> product = {1.0, 0.0};
  for (const auto& factor : data)
    product *= factor;
  const std::vector<std::complex<double>> expected_result(context.ea->size(),
                                                          product);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
581
// innerProduct() of [p, p] with [2p, 2p, 2p] is expected to be 2 * (p * 2p)
// slot-wise, i.e. only as many terms as the shorter vector are summed.
TEST_P(TestPtxtCKKS, innerProductWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {i / 1.0, (i * i) / 1.0};
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  const std::vector<helib::Ptxt<helib::CKKS>> first_ptxt_vector(2, ptxt);
  ptxt += ptxt;
  const std::vector<helib::Ptxt<helib::CKKS>> second_ptxt_vector(3, ptxt);

  helib::Ptxt<helib::CKKS> result(context);
  innerProduct(result, first_ptxt_vector, second_ptxt_vector);

  std::vector<std::complex<double>> expected_result(data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    const std::complex<double> term = data[i] * (data[i] + data[i]);
    expected_result[i] = term + term;
  }

  COMPARE_CXDOUBLE_VECS(result.getSlotRepr(), expected_result);
}
606
// mapTo01 (member and free function) is expected to send the zero slot to 0
// and every non-zero slot to 1.
TEST_P(TestPtxtCKKS, mapTo01MapsSlotsCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {i / 1.0, (i * i) / 1.0};
  }

  helib::Ptxt<helib::CKKS> via_member(context, data);
  helib::Ptxt<helib::CKKS> via_free_function(context, data);
  via_member.mapTo01();
  mapTo01(*(context.ea), via_free_function);

  // Slot 0 holds {0, 0}; all other slots are non-zero.
  std::vector<std::complex<double>> expected_result(context.ea->size());
  for (std::size_t i = 1; i < data.size(); ++i)
    expected_result[i] = std::complex<double>{1, 0};

  COMPARE_CXDOUBLE_VECS(via_member.getSlotRepr(), expected_result);
  COMPARE_CXDOUBLE_VECS(via_free_function.getSlotRepr(), expected_result);
}
628
// operator*= with another plaintext multiplies slot-wise.
TEST_P(TestPtxtCKKS, timesEqualsOtherPlaintextWorks)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> product_data(slot_count, {-3.14, -1.0});
  std::vector<std::complex<double>> multiplier_data(slot_count);
  for (long i = 0; i < helib::lsize(multiplier_data); ++i)
    multiplier_data[i] = {(i - 1) / 10.0, (i + 1) / 10.0};

  std::vector<std::complex<double>> expected_result(product_data.size());
  for (std::size_t i = 0; i < expected_result.size(); ++i)
    expected_result[i] = product_data[i] * multiplier_data[i];

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_data);
  product *= multiplier;

  COMPARE_CXDOUBLE_VECS(product.getSlotRepr(), expected_result);
}
650
// operator-= with another plaintext subtracts slot-wise.
TEST_P(TestPtxtCKKS, minusEqualsOtherPlaintextWorks)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> difference_data(slot_count, {2.718, -1.0});
  std::vector<std::complex<double>> subtrahend_data(slot_count);
  for (long i = 0; i < helib::lsize(subtrahend_data); ++i)
    subtrahend_data[i] = {(i - 1) / 10.0, (i + 1) / 10.0};

  std::vector<std::complex<double>> expected_result(difference_data.size());
  for (std::size_t i = 0; i < subtrahend_data.size(); ++i)
    expected_result[i] = difference_data[i] - subtrahend_data[i];

  helib::Ptxt<helib::CKKS> difference(context, difference_data);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_data);
  difference -= subtrahend;

  COMPARE_CXDOUBLE_VECS(difference.getSlotRepr(), expected_result);
}
672
// operator-= with a std::complex scalar subtracts it from every slot.
TEST_P(TestPtxtCKKS, minusEqualsComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {(i * i - 1) / 10.0, (i * i + 1) / 10.0};

  const std::complex<double> scalar = {2.5, -0.5};

  std::vector<std::complex<double>> expected_result;
  expected_result.reserve(data.size());
  for (const auto& value : data)
    expected_result.push_back(value - scalar);

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt -= scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
691
// operator-= accepts double and int scalars, subtracting from every slot.
TEST_P(TestPtxtCKKS, minusEqualsNonComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {(i * i - 1) / 2.0, (i * i + 1) / 5.0};

  const double scalar = 15.3;
  const int int_scalar = 2;

  // Subtract in the same order as the plaintext operations below.
  std::vector<std::complex<double>> expected_result;
  expected_result.reserve(data.size());
  for (const auto& value : data)
    expected_result.push_back(value - scalar - static_cast<double>(int_scalar));

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt -= scalar;
  ptxt -= int_scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
713
// operator+= with another plaintext adds slot-wise.
TEST_P(TestPtxtCKKS, plusEqualsOtherPlaintextWorks)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> augend_data(slot_count);
  std::vector<std::complex<double>> addend_data(slot_count);
  std::vector<std::complex<double>> expected_result(slot_count);
  for (long i = 0; i < helib::lsize(addend_data); ++i) {
    augend_data[i] = {i / 10.0, i * i / 10.0};
    addend_data[i] = {i / 20.0, i * i / 20.0};
  }
  for (std::size_t i = 0; i < expected_result.size(); ++i)
    expected_result[i] = augend_data[i] + addend_data[i];

  helib::Ptxt<helib::CKKS> sum(context, augend_data);
  helib::Ptxt<helib::CKKS> addend(context, addend_data);
  sum += addend;

  COMPARE_CXDOUBLE_VECS(sum.getSlotRepr(), expected_result);
}
732
// operator+= with a std::complex scalar adds it to every slot.
TEST_P(TestPtxtCKKS, plusEqualsComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {-i / 10.0, (3 - i) / 4.0};
  }

  // Use a genuinely complex scalar (non-zero imaginary part) so that the
  // complex/SlotType overload of operator+= is exercised, as the test name
  // states. A plain double here would only cover the real-scalar path, which
  // plusEqualsNonComplexScalarWorks already tests.
  const std::complex<double> scalar = {3.14, -1.5};

  std::vector<std::complex<double>> expected_result(data);
  for (auto& num : expected_result)
    num += scalar;

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt += scalar;

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
751
// operator+= accepts double and int scalars, adding them to every slot.
TEST_P(TestPtxtCKKS, plusEqualsNonComplexScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {i + i / 5.0, i - i / 4.0};
  }

  const double scalar = 3.28;
  const int int_scalar = 13;

  // Add the two scalars one at a time, in the same order as the plaintext
  // operations below. Pre-summing them (num += scalar + int_scalar) rounds
  // differently, so the comparison would no longer be exact.
  std::vector<std::complex<double>> expected_result(data);
  for (auto& num : expected_result) {
    num += scalar;
    num += static_cast<double>(int_scalar);
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt += scalar;
  ptxt += int_scalar;

  const auto& slots = ptxt.getSlotRepr();
  for (std::size_t i = 0; i < data.size(); ++i) {
    EXPECT_DOUBLE_EQ(slots[i].real(), expected_result[i].real());
    EXPECT_DOUBLE_EQ(slots[i].imag(), expected_result[i].imag());
  }
}
777
// operator*= accepts double and int scalars, scaling every slot.
TEST_P(TestPtxtCKKS, timesEqualsScalarWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {i * i * i / 100.0, -i / 3.0};

  const double scalar = 10.28;
  const int int_scalar = -2;
  const double combined = scalar * static_cast<double>(int_scalar);

  std::vector<std::complex<double>> expected_result;
  expected_result.reserve(data.size());
  for (auto value : data) {
    value *= combined;
    expected_result.push_back(value);
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt *= scalar;
  ptxt *= int_scalar;

  const auto& slots = ptxt.getSlotRepr();
  for (std::size_t i = 0; i < data.size(); ++i) {
    EXPECT_DOUBLE_EQ(slots[i].real(), expected_result[i].real());
    EXPECT_DOUBLE_EQ(slots[i].imag(), expected_result[i].imag());
  }
}
803
// Plaintexts built from identical data compare equal; a default-constructed
// plaintext does not.
TEST_P(TestPtxtCKKS, equalityWithOtherPlaintextWorks)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {i * 2.5, (i - 2) * 2.5};

  helib::Ptxt<helib::CKKS> lhs(context, data);
  helib::Ptxt<helib::CKKS> rhs(context, data);
  EXPECT_TRUE(lhs == rhs);
  EXPECT_FALSE(lhs == helib::Ptxt<helib::CKKS>());
}
815
// operator!= is true for differing data and false against itself.
TEST_P(TestPtxtCKKS, notEqualsOperatorWithOtherPlaintextWorks)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> data1(slot_count);
  std::vector<std::complex<double>> data2(slot_count);
  for (long i = 0; i < helib::lsize(data1); ++i) {
    // (i + 1) makes the first element of data1 differ from (0, 0).
    data1[i] = {(i + 1) * 2.5, -i * 2.5};
    data2[i] = {i * 2.5, i * 6.5};
  }

  helib::Ptxt<helib::CKKS> first(context, data1);
  helib::Ptxt<helib::CKKS> second(context, data2);
  EXPECT_TRUE(first != second);
  EXPECT_FALSE(first != first);
}
830
// negate() multiplies every slot by -1.
TEST_P(TestPtxtCKKS, negateNegatesCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  const double pi = std::acos(-1);
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Spiral: j -> j * e^{2*i*pi*j/data.size()}
    const std::complex<double> radius{static_cast<double>(j), 0};
    const std::complex<double> phase =
        std::exp(std::complex<double>{0, 2.0 * pi * j / data.size()});
    data[j] = radius * phase;
  }

  const std::complex<double> minus_one{-1.0, 0};
  std::vector<std::complex<double>> expected_result;
  expected_result.reserve(data.size());
  for (auto value : data) {
    value *= minus_one;
    expected_result.push_back(value);
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.negate();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
850
// addConstantCKKS accepts real and complex constants and can be chained.
TEST_P(TestPtxtCKKS, addConstantWorksCorrectly)
{
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {i * 4.5, i / 2.0};

  const std::complex<double> imaginary_half{0, 0.5};
  std::vector<std::complex<double>> expected_result(data);
  for (auto& num : expected_result) {
    num += 5;
    num += imaginary_half;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.addConstantCKKS(5).addConstantCKKS(imaginary_half);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
866
// multiplyBy() multiplies slot-wise by another plaintext.
TEST_P(TestPtxtCKKS, multiplyByMultipliesCorrectly)
{
  const auto slot_count = context.ea->size();
  std::vector<std::complex<double>> product_data(slot_count);
  std::vector<std::complex<double>> multiplier_data(slot_count);
  for (long i = 0; i < helib::lsize(multiplier_data); ++i) {
    product_data[i] = {(2 - i) / 10.0, (1 - i) / 10.0};
    multiplier_data[i] = {std::exp(i / 100.), std::cos(i) * 12};
  }

  std::vector<std::complex<double>> expected_result(product_data.size());
  for (std::size_t i = 0; i < expected_result.size(); ++i)
    expected_result[i] = product_data[i] * multiplier_data[i];

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_data);
  product.multiplyBy(multiplier);

  // Compare at float rather than double precision to allow a looser
  // tolerance.
  COMPARE_CXFLOAT_VECS(product.getSlotRepr(), expected_result);
}
890
// multiplyBy2() multiplies slot-wise by two other plaintexts.
TEST_P(TestPtxtCKKS, multiplyBy2MultipliesCorrectly)
{
  std::vector<std::complex<double>> product_data(context.ea->size());
  std::vector<std::complex<double>> multiplier_data1(context.ea->size());
  std::vector<std::complex<double>> multiplier_data2(context.ea->size());
  for (long i = 0; i < helib::lsize(multiplier_data1); ++i) {
    const double x = static_cast<double>(i);
    product_data[i] = x * std::exp(std::complex<double>{0, x});
    // Both multipliers must be populated. If multiplier_data1 were left at
    // zero, every expected and actual slot would be zero and the test would
    // pass no matter what multiplyBy2 did.
    multiplier_data1[i] = x * std::exp(std::complex<double>{0, -x});
    multiplier_data2[i] = 5.0 * std::exp(std::complex<double>{0, x});
  }

  std::vector<std::complex<double>> expected_result(product_data);
  for (std::size_t i = 0; i < product_data.size(); ++i) {
    expected_result[i] =
        expected_result[i] * multiplier_data1[i] * multiplier_data2[i];
  }

  helib::Ptxt<helib::CKKS> product(context, product_data);
  helib::Ptxt<helib::CKKS> multiplier1(context, multiplier_data1);
  helib::Ptxt<helib::CKKS> multiplier2(context, multiplier_data2);

  product.multiplyBy2(multiplier1, multiplier2);

  // Now that the values are non-trivial, two chained complex products may
  // round differently depending on evaluation order. Compare at float
  // precision, as multiplyByMultipliesCorrectly does.
  COMPARE_CXFLOAT_VECS(product.getSlotRepr(), expected_result);
}
923
TEST_P(TestPtxtCKKS, squareSquaresCorrectly)
{
  // Slot i holds cos(2 * theta) * e^{i * theta}, with theta sweeping a full
  // turn across the slots. square() must match squaring each slot directly.
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    const double theta = 2. * std::acos(-1) * i / data.size();
    std::complex<double> point =
        std::cos(2. * theta) * std::exp(std::complex<double>{0, theta});
    data[i] = point;
    point *= point;
    expected_result[i] = point;
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.square();
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
939
TEST_P(TestPtxtCKKS, cubeCubesCorrectly)
{
  // Slots sample a catenary: real part x, imaginary part cosh(x), where x
  // runs across roughly [-1/2, 1/2). cube() must match z * z * z per slot.
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    // Note: the real part halves data.size() with integer division.
    const double re =
        static_cast<double>(1. * i - data.size() / 2) / data.size();
    const double im = std::cosh((i - data.size() / 2.) / data.size());
    data[i] = {re, im};
    expected_result[i] = data[i] * data[i] * data[i];
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.cube();
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
955
TEST_P(TestPtxtCKKS, powerCorrectlyRaisesToPowers)
{
  // Slot j lies on a spiral inside the unit disk. For each exponent, power()
  // is compared with repeated multiplication. Closeness is measured on the
  // squared magnitude of the difference.
  const double pi = std::acos(-1);
  std::vector<std::complex<double>> data(context.ea->size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    const std::complex<double> radius{j / static_cast<double>(data.size())};
    data[j] =
        radius * std::exp(std::complex<double>{0, 2.0 * pi * j / data.size()});
  }

  // Reference: base multiplied by itself (exponent - 1) times; 1 for 0.
  const auto naive_power = [](std::complex<double> base,
                              unsigned long exponent) {
    if (exponent == 0)
      return std::complex<double>{1.0};
    std::complex<double> result = base;
    for (unsigned long step = 1; step < exponent; ++step)
      result *= base;
    return result;
  };

  const std::vector<long> exponents{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1500};
  for (long exponent : exponents) {
    std::vector<std::complex<double>> expected_result(data.size());
    for (std::size_t k = 0; k < data.size(); ++k)
      expected_result[k] = naive_power(data[k], exponent);

    helib::Ptxt<helib::CKKS> ptxt(context, data);
    ptxt.power(exponent);
    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_NEAR(std::norm(ptxt[i] - expected_result[i]),
                  0,
                  pre_encryption_epsilon);
    }
  }

  // An exponent of zero is rejected.
  helib::Ptxt<helib::CKKS> ptxt(context, data);
  EXPECT_THROW(ptxt.power(0L), helib::InvalidArgument);
}
994
TEST_P(TestPtxtCKKS, shiftShiftsRightCorrectly)
{
  // shift(3) moves slot j into slot j + 3 and zero-fills the first 3 slots.
  // Slot i starts as (i, i), so after the shift slot i >= 3 must hold
  // (i - 3, i - 3). The boundary is i >= 3: slot 3 receives slot 0.
  const long shift_amount = 3;
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> right_shifted_data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {static_cast<double>(i), static_cast<double>(i)};
    if (i >= shift_amount) {
      const double source = static_cast<double>(i - shift_amount);
      right_shifted_data[i] = {source, source};
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.shift(shift_amount);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), right_shifted_data);
}
1015
TEST_P(TestPtxtCKKS, shiftShiftsLeftCorrectly)
{
  // shift(-3) moves slot j + 3 into slot j. The last 3 slots (or all of
  // them, if there are no more than 3) are expected to be zero.
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> left_shifted_data(context.ea->size());
  const long n_slots = helib::lsize(data);
  for (long i = 0; i < n_slots; ++i) {
    const double value = static_cast<double>(i);
    data[i] = {value, value};
    if (i + 3 < n_slots) {
      const double source = static_cast<double>(i + 3);
      left_shifted_data[i] = {source, source};
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.shift(-3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), left_shifted_data);
}
1036
TEST_P(TestPtxtCKKS, shift1DShiftsRightCorrectly)
{
  // shift1D(0, 3) moves slot j into slot j + 3 along dimension 0 and
  // zero-fills the first 3 slots. Slot i starts as (i, i), so afterwards slot
  // i >= 3 must hold (i - 3, i - 3). The boundary is i >= 3: slot 3
  // receives slot 0.
  const long shift_amount = 3;
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> right_shifted_data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {static_cast<double>(i), static_cast<double>(i)};
    if (i >= shift_amount) {
      const double source = static_cast<double>(i - shift_amount);
      right_shifted_data[i] = {source, source};
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.shift1D(0, shift_amount);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), right_shifted_data);
}
1057
TEST_P(TestPtxtCKKS, shift1DShiftsLeftCorrectly)
{
  // shift1D(0, -3) moves slot j + 3 into slot j along dimension 0. The last
  // 3 slots (or all of them, if there are no more than 3) are expected to be
  // zero.
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> left_shifted_data(context.ea->size());
  const long n_slots = helib::lsize(data);
  for (long i = 0; i < n_slots; ++i) {
    const double value = static_cast<double>(i);
    data[i] = {value, value};
    if (i + 3 < n_slots) {
      const double source = static_cast<double>(i + 3);
      left_shifted_data[i] = {source, source};
    }
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.shift1D(0, -3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), left_shifted_data);
}
1078
1079 // These tests are disabled since the methods are private.
1080 // These can be useful if tweaking the logic of this area.
1081 // TEST(TestPtxtBGV, coord_To_Index_Works)
1082 // {
1083 // helib::Context context(32109, 4999, 1);
1084 // helib::Ptxt<helib::BGV> ptxt(context);
1085 // std::vector<long> indices;
1086 // for(long i=0; i<6; ++i)
1087 // for(long j=0; j<2; ++j)
1088 // for(long k=0; k<2; ++k)
1089 // indices.push_back(ptxt.coordToIndex({i,j,k}));
1090 // std::vector<long> expected_indices(context.ea->size());
1091 // std::iota(expected_indices.begin(), expected_indices.end(), 0);
1092 // EXPECT_EQ(expected_indices, indices);
1093 // }
1094 //
1095 // TEST(TestPtxtBGV, index_To_Coord_Works)
1096 // {
1097 // helib::Context context(32109, 4999, 1);
1098 // helib::Ptxt<helib::BGV> ptxt(context);
1099 // std::vector<std::vector<long>> coords;
1100 // for(std::size_t i=0; i<ptxt.size(); ++i)
1101 // coords.push_back(ptxt.indexToCoord(i));
1102 // std::vector<std::vector<long>> expected_coords;
1103 // for(long i=0; i<6; ++i)
1104 // for(long j=0; j<2; ++j)
1105 // for(long k=0; k<2; ++k)
1106 // expected_coords.push_back({i,j,k});
1107 // EXPECT_EQ(expected_coords, coords);
1108 // }
1109
TEST_P(TestPtxtCKKS, rotate1DRotatesCorrectly)
{
  // Slot i starts as ((i - 3) mod n) on both axes. Rotating by -3 along
  // dimension 0 must leave slot i holding (i, i); rotating by +3 must then
  // restore the original data.
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> left_rotated_data(context.ea->size());
  const int n_slots = static_cast<int>(data.size());
  for (int i = 0; i < helib::lsize(data); ++i) {
    const double wrapped =
        static_cast<double>(((i - 3) % n_slots + n_slots) % n_slots);
    data[i] = {wrapped, wrapped};
    left_rotated_data[i] = {static_cast<double>(i), static_cast<double>(i)};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.rotate1D(0, -3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), left_rotated_data);

  // Rotating back and forth gives the original data back
  ptxt.rotate1D(0, 3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), data);
}
1130
TEST_P(TestPtxtCKKS, rotateRotatesCorrectly)
{
  // Slot i starts as ((i - 3) mod n) on both axes. rotate(-3) must leave
  // slot i holding (i, i); rotate(3) must then restore the original data.
  const auto wrap = [](int value, int modulus) {
    return ((value % modulus) + modulus) % modulus;
  };
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> left_rotated_data(context.ea->size());
  for (int i = 0; i < helib::lsize(data); ++i) {
    const double wrapped = static_cast<double>(wrap(i - 3, data.size()));
    data[i] = {wrapped, wrapped};
    const double index = static_cast<double>(i);
    left_rotated_data[i] = {index, index};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);

  ptxt.rotate(-3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), left_rotated_data);

  // Rotating back and forth gives the original data back
  ptxt.rotate(3);
  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), data);
}
1151
TEST_P(TestPtxtCKKS, automorphWorksCorrectly)
{
  // automorph by zMStar.ith_rep(1) is expected to act like rotate(1), and
  // automorph by ith_rep(size - 1) like rotate(-1).
  std::vector<std::complex<double>> data(context.ea->size());
  const auto non_neg_mod = [](int x, int mod) {
    return ((x % mod) + mod) % mod;
  };
  for (int i = 0; i < helib::lsize(data); ++i) {
    data[i] = {static_cast<double>(non_neg_mod(i - 3, data.size())),
               static_cast<double>(non_neg_mod(i - 3, data.size()))};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);
  helib::Ptxt<helib::CKKS> expected_result(context, data);

  // NOTE(review): falls back to k = 1 when ith_rep(1) is 0; the reason is
  // not visible from this file.
  long k = context.zMStar.ith_rep(1) ? context.zMStar.ith_rep(1) : 1;
  ptxt.automorph(k);
  expected_result.rotate(1);
  COMPARE_CXDOUBLE_VECS(ptxt, expected_result);

  ptxt.automorph(context.zMStar.ith_rep(context.ea->size() - 1));
  expected_result.rotate(-1);
  COMPARE_CXDOUBLE_VECS(ptxt, expected_result);
}
1175
TEST_P(TestPtxtCKKS, replicateReplicatesCorrectly)
{
  // Replicating the last slot must copy that slot's value into every slot.
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = {i / 10.0, -i / 20.0};

  const auto last_index = data.size() - 1;
  helib::Ptxt<helib::CKKS> ptxt(context, data);
  helib::replicate(*context.ea, ptxt, last_index);

  std::vector<std::complex<double>> replicated_data(context.ea->size(),
                                                    data[last_index]);
  COMPARE_CXDOUBLE_VECS(ptxt, replicated_data);
}
1188
TEST_P(TestPtxtCKKS, replicateAllWorksCorrectly)
{
  // replicateAll() is expected to return one ptxt per slot, where ptxt i
  // holds data[i] in every slot.
  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {i / 10.0, -i / 20.0};
  }
  helib::Ptxt<helib::CKKS> ptxt(context, data);
  std::vector<helib::Ptxt<helib::CKKS>> replicated_ptxts = ptxt.replicateAll();

  // Stop the test on a count mismatch; otherwise the loop below would index
  // replicated_ptxts out of range.
  ASSERT_EQ(replicated_ptxts.size(), data.size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    for (const auto& slot : replicated_ptxts[i].getSlotRepr()) {
      EXPECT_DOUBLE_EQ(data[i].real(), slot.real());
      EXPECT_DOUBLE_EQ(data[i].imag(), slot.imag());
    }
  }
}
1204
TEST_P(TestPtxtCKKS, randomSetsDataRandomly)
{
  // Five independently randomised ptxts should not all be identical.
  helib::Ptxt<helib::CKKS> ptxt(context);
  ptxt.random();
  std::vector<helib::Ptxt<helib::CKKS>> ptxts(5, ptxt);
  for (auto& candidate : ptxts)
    candidate.random();

  // Scan neighbouring pairs, stopping at the first one that differs.
  bool all_equal = true;
  for (std::size_t i = 1; all_equal && i < ptxts.size(); ++i) {
    if (ptxts[i - 1] != ptxts[i])
      all_equal = false;
  }
  EXPECT_FALSE(all_equal) << "5 random ptxts are all equal - likely that"
                             " random() is not actually randomising!";
}
1222
TEST_P(TestPtxtCKKS, complexConjCorrectlyConjugates)
{
  // complexConj() must replace every slot with its complex conjugate.
  const std::complex<double> start{1, -1};
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Line segment starting at 1 - i with gradient 2
    data[j] = start + std::complex<double>{j / 4.0, j / 2.0};
    expected_result[j] = std::conj(data[j]);
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt.complexConj();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
1242
TEST_P(TestPtxtCKKS, extractRealPartIsCorrect)
{
  // EncryptedArrayCx::extractRealPart on a ptxt must leave each slot holding
  // only its real part (imaginary part zero).
  const std::complex<double> start{1, -1};
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Line segment starting at 1 - i with gradient 2
    data[j] = start + std::complex<double>{j / 4.0, j / 2.0};
    expected_result[j] = data[j].real();
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  context.ea->getCx().extractRealPart(ptxt);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
1261
TEST_P(TestPtxtCKKS, extractImPartIsCorrect)
{
  // EncryptedArrayCx::extractImPart on a ptxt must leave each slot holding
  // its imaginary part as a real number (imaginary part zero).
  const std::complex<double> start{1, -1};
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Line segment starting at 1 - i with gradient 2
    data[j] = start + std::complex<double>{j / 4.0, j / 2.0};
    expected_result[j] = data[j].imag();
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  context.ea->getCx().extractImPart(ptxt);

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
1280
TEST_P(TestPtxtCKKS, realExtractsRealPart)
{
  // Ptxt::real() must return a ptxt whose slots are the real parts of the
  // original slots (imaginary parts zero).
  const std::complex<double> start{1, -1};
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Line segment starting at 1 - i with gradient 2
    data[j] = start + std::complex<double>{j / 4.0, j / 2.0};
    expected_result[j] = data[j].real();
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt = ptxt.real();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
1299
TEST_P(TestPtxtCKKS, imagExtractsImaginaryPart)
{
  // Ptxt::imag() must return a ptxt whose slots are the imaginary parts of
  // the original slots, stored as real numbers.
  const std::complex<double> start{1, -1};
  std::vector<std::complex<double>> data(context.ea->size());
  std::vector<std::complex<double>> expected_result(data.size());
  for (long j = 0; j < helib::lsize(data); ++j) {
    // Line segment starting at 1 - i with gradient 2
    data[j] = start + std::complex<double>{j / 4.0, j / 2.0};
    expected_result[j] = data[j].imag();
  }

  helib::Ptxt<helib::CKKS> ptxt(context, data);
  ptxt = ptxt.imag();

  COMPARE_CXDOUBLE_VECS(ptxt.getSlotRepr(), expected_result);
}
1318
TEST_P(TestPtxtCKKS, canEncryptAndDecryptComplexPtxtsWithKeys)
{
  // Round-trip complex data through PubKey::Encrypt and SecKey::Decrypt.
  // Each slot must come back within post_encryption_epsilon, measured on the
  // squared magnitude of the error.
  helib::buildModChain(context, 100, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {(i - 3) / 5.0, i + 10.0};
  }
  helib::Ptxt<helib::CKKS> pre_encryption(context, data);
  helib::Ctxt ctxt(public_key);

  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  secret_key.Decrypt(post_decryption, ctxt);
  // ASSERT (not EXPECT): the loop below indexes post_decryption with
  // pre_encryption's indices, so a size mismatch must end the test here.
  ASSERT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    EXPECT_NEAR(std::norm(pre_encryption[i] - post_decryption[i]),
                0,
                post_encryption_epsilon);
  }
}
1344
TEST_P(TestPtxtCKKS, canEncryptAndDecryptRealPtxtsWithKeys)
{
  // Round-trip real data through PubKey::Encrypt and SecKey::Decrypt. Real
  // and imaginary parts are each checked within post_encryption_epsilon.
  helib::buildModChain(context, 100, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  std::vector<double> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = (i - 3) / 10.0;
  }
  helib::Ptxt<helib::CKKS> pre_encryption(context, data);
  helib::Ctxt ctxt(public_key);

  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  secret_key.Decrypt(post_decryption, ctxt);
  // ASSERT (not EXPECT): the loop below indexes post_decryption with
  // pre_encryption's indices, so a size mismatch must end the test here.
  ASSERT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    EXPECT_NEAR(pre_encryption[i].real(),
                post_decryption[i].real(),
                post_encryption_epsilon);
    EXPECT_NEAR(pre_encryption[i].imag(),
                post_decryption[i].imag(),
                post_encryption_epsilon);
  }
}
1373
TEST_P(TestPtxtCKKS, canEncryptAndDecryptComplexPtxtsWithEa)
{
  // Encrypt complex data with the public key and decrypt it through the
  // EncryptedArrayCx interface. Each slot must come back within
  // post_encryption_epsilon, measured on the squared magnitude of the error.
  helib::buildModChain(context, 100, 2);
  const helib::EncryptedArrayCx& ea = context.ea->getCx();
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  std::vector<std::complex<double>> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {(i - 3) / 10.0, i + 5.0};
  }
  helib::Ptxt<helib::CKKS> pre_encryption(context, data);
  helib::Ctxt ctxt(public_key);

  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  // Decrypt through the EncryptedArray rather than SecKey::Decrypt, so this
  // test covers a different path from the "WithKeys" variant.
  ea.decrypt(ctxt, secret_key, post_decryption);
  // ASSERT (not EXPECT): the loop below indexes post_decryption with
  // pre_encryption's indices, so a size mismatch must end the test here.
  ASSERT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    EXPECT_NEAR(std::norm(pre_encryption[i] - post_decryption[i]),
                0,
                post_encryption_epsilon);
  }
}
1399
TEST_P(TestPtxtCKKS, canEncryptAndDecryptRealPtxtsWithEa)
{
  // Encrypt real data with the public key and decrypt it through the
  // EncryptedArrayCx interface. Real and imaginary parts are each checked
  // within post_encryption_epsilon.
  helib::buildModChain(context, 100, 2);
  const helib::EncryptedArrayCx& ea = context.ea->getCx();
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  std::vector<double> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = (i - 3) / 10.0;
  }
  helib::Ptxt<helib::CKKS> pre_encryption(context, data);
  helib::Ctxt ctxt(public_key);

  public_key.Encrypt(ctxt, pre_encryption);

  helib::Ptxt<helib::CKKS> post_decryption(context);
  ea.decrypt(ctxt, secret_key, post_decryption);
  // ASSERT (not EXPECT): the loop below indexes post_decryption with
  // pre_encryption's indices, so a size mismatch must end the test here.
  ASSERT_EQ(pre_encryption.size(), post_decryption.size());
  for (std::size_t i = 0; i < pre_encryption.size(); ++i) {
    EXPECT_NEAR(pre_encryption[i].real(),
                post_decryption[i].real(),
                post_encryption_epsilon);
    EXPECT_NEAR(pre_encryption[i].imag(),
                post_decryption[i].imag(),
                post_encryption_epsilon);
  }
}
1429
TEST_P(TestPtxtCKKS, plusEqualsWithCiphertextWorks)
{
  // Ctxt += Ptxt must agree, after decryption, with Ptxt += Ptxt on the
  // unencrypted augend.
  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  // Encrypt the augend, addend is plaintext
  std::vector<std::complex<double>> augend_data(context.ea->size());
  std::vector<std::complex<double>> addend_data(context.ea->size());
  for (long i = 0; i < helib::lsize(augend_data); ++i) {
    augend_data[i] = {i / 10.0, -i * i / 63.0};
    addend_data[i] = {-i / 20.0, i * i * i * 2.6};
  }
  helib::Ptxt<helib::CKKS> augend_ptxt(context, augend_data);
  helib::Ptxt<helib::CKKS> addend(context, addend_data);
  helib::Ctxt augend(public_key);
  public_key.Encrypt(augend, augend_ptxt);

  augend += addend;
  augend_ptxt += addend;

  // augend_ptxt and augend should now match
  helib::Ptxt<helib::CKKS> result(context);
  secret_key.Decrypt(result, augend);
  // ASSERT so a size mismatch ends the test before the indexed loop below.
  ASSERT_EQ(result.size(), augend_ptxt.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    EXPECT_NEAR(std::norm(result[i] - augend_ptxt[i]),
                0,
                post_encryption_epsilon);
  }
}
1462
TEST_P(TestPtxtCKKS, addConstantCKKSWithCiphertextWorks)
{
  // Ctxt::addConstantCKKS(Ptxt) must agree, after decryption, with
  // Ptxt += Ptxt on the unencrypted augend.
  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  // Encrypt the augend, addend is plaintext
  std::vector<std::complex<double>> augend_data(context.ea->size());
  std::vector<std::complex<double>> addend_data(context.ea->size());
  for (long i = 0; i < helib::lsize(augend_data); ++i) {
    augend_data[i] = {i / 70.0, -i * 10.5};
    addend_data[i] = {-i / 10.0, i * 0.8};
  }
  helib::Ptxt<helib::CKKS> augend_ptxt(context, augend_data);
  helib::Ptxt<helib::CKKS> addend(context, addend_data);
  helib::Ctxt augend(public_key);
  public_key.Encrypt(augend, augend_ptxt);

  augend.addConstantCKKS(addend);
  augend_ptxt += addend;

  // augend_ptxt and augend should now match
  helib::Ptxt<helib::CKKS> result(context);
  secret_key.Decrypt(result, augend);
  // ASSERT so a size mismatch ends the test before the indexed loop below.
  ASSERT_EQ(result.size(), augend_ptxt.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    EXPECT_NEAR(std::norm(result[i] - augend_ptxt[i]),
                0,
                post_encryption_epsilon);
  }
}
1495
TEST_P(TestPtxtCKKS, minusEqualsWithCiphertextWorks)
{
  // Ctxt -= Ptxt must agree, after decryption, with Ptxt -= Ptxt on the
  // unencrypted minuend.
  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  // Encrypt the minuend, subtrahend is plaintext
  std::vector<std::complex<double>> minuend_data(context.ea->size());
  std::vector<std::complex<double>> subtrahend_data(context.ea->size());
  for (long i = 0; i < helib::lsize(minuend_data); ++i) {
    minuend_data[i] = {i * i / 30.0, i * i / 4.5};
    subtrahend_data[i] = {(i + 3) / 4.0, -i * i / 1.3};
  }
  helib::Ptxt<helib::CKKS> minuend_ptxt(context, minuend_data);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_data);
  helib::Ctxt minuend(public_key);
  public_key.Encrypt(minuend, minuend_ptxt);

  minuend -= subtrahend;
  minuend_ptxt -= subtrahend;

  // minuend_ptxt and minuend should now match
  helib::Ptxt<helib::CKKS> result(context);
  secret_key.Decrypt(result, minuend);
  // ASSERT so a size mismatch ends the test before the indexed loop below.
  ASSERT_EQ(result.size(), minuend_ptxt.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    EXPECT_NEAR(std::norm(result[i] - minuend_ptxt[i]),
                0,
                post_encryption_epsilon);
  }
}
1528
TEST_P(TestPtxtCKKS, multByConstantCKKSFromCiphertextWorks)
{
  // Ctxt::multByConstantCKKS(Ptxt) must agree, after decryption, with
  // Ptxt *= Ptxt on the unencrypted multiplier.
  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  // Encrypt the multiplier, multiplicand is plaintext
  std::vector<std::complex<double>> multiplier_data(context.ea->size());
  std::vector<std::complex<double>> multiplicand_data(context.ea->size());
  for (long i = 0; i < helib::lsize(multiplier_data); ++i) {
    multiplier_data[i] = {i * 4.5, -i * i / 12.5};
    multiplicand_data[i] = {(i - 2.5) / 3.5, i * 4.2};
  }
  helib::Ptxt<helib::CKKS> multiplier_ptxt(context, multiplier_data);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_data);

  helib::Ctxt multiplier(public_key);
  public_key.Encrypt(multiplier, multiplier_ptxt);

  multiplier.multByConstantCKKS(multiplicand);
  multiplier_ptxt *= multiplicand;

  // multiplier_ptxt and multiplier should now match
  helib::Ptxt<helib::CKKS> result(context);
  secret_key.Decrypt(result, multiplier);
  // ASSERT so a size mismatch ends the test before the indexed loop below.
  ASSERT_EQ(result.size(), multiplier_ptxt.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    EXPECT_NEAR(std::norm(result[i] - multiplier_ptxt[i]),
                0,
                post_encryption_epsilon);
  }
}
1562
TEST_P(TestPtxtCKKS, timesEqualsFromCiphertextWorks)
{
  // Ctxt *= Ptxt must agree, after decryption, with Ptxt *= Ptxt on the
  // unencrypted multiplier.
  helib::buildModChain(context, 150, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key(secret_key);

  // Encrypt the multiplier, multiplicand is plaintext
  std::vector<std::complex<double>> multiplier_data(context.ea->size());
  std::vector<std::complex<double>> multiplicand_data(context.ea->size());
  for (long i = 0; i < helib::lsize(multiplier_data); ++i) {
    multiplier_data[i] = {i * 4.5, -i * i / 3.3};
    multiplicand_data[i] = {(i - 2.5) / 3.5, i * i / 12.4};
  }
  helib::Ptxt<helib::CKKS> multiplier_ptxt(context, multiplier_data);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_data);

  helib::Ctxt multiplier(public_key);
  public_key.Encrypt(multiplier, multiplier_ptxt);

  multiplier *= multiplicand;
  multiplier_ptxt *= multiplicand;

  // multiplier_ptxt and multiplier should now match
  helib::Ptxt<helib::CKKS> result(context);
  secret_key.Decrypt(result, multiplier);
  // ASSERT so a size mismatch ends the test before the indexed loop below.
  ASSERT_EQ(result.size(), multiplier_ptxt.size());
  for (std::size_t i = 0; i < result.size(); ++i) {
    EXPECT_NEAR(std::norm(result[i] - multiplier_ptxt[i]),
                0,
                post_encryption_epsilon);
  }
}
1596
TEST_P(TestPtxtCKKS, plusOperatorWithOtherPtxtWorks)
{
  // Binary operator+ on two ptxts, assigned into a default-constructed ptxt,
  // must equal the slot-wise std::complex sum.
  const auto n_slots = context.ea->size();
  std::vector<std::complex<double>> augend_data(n_slots);
  std::vector<std::complex<double>> addend_data(n_slots);
  std::vector<std::complex<double>> expected_sum_data(n_slots);
  for (long i = 0; i < helib::lsize(augend_data); ++i) {
    const std::complex<double> lhs{i / 10.0, -i * i / 3.0};
    const std::complex<double> rhs{-i / 20.0, i * i * i * 42.6};
    augend_data[i] = lhs;
    addend_data[i] = rhs;
    expected_sum_data[i] = lhs + rhs;
  }
  helib::Ptxt<helib::CKKS> augend(context, augend_data);
  helib::Ptxt<helib::CKKS> addend(context, addend_data);

  helib::Ptxt<helib::CKKS> sum;
  sum = augend + addend;

  COMPARE_CXDOUBLE_VECS(expected_sum_data, sum);
}
1615
TEST_P(TestPtxtCKKS, minusOperatorWithOtherPtxtWorks)
{
  // Binary operator- on two ptxts, assigned into a default-constructed ptxt,
  // must equal the slot-wise std::complex difference.
  const auto n_slots = context.ea->size();
  std::vector<std::complex<double>> minuend_data(n_slots);
  std::vector<std::complex<double>> subtrahend_data(n_slots);
  std::vector<std::complex<double>> expected_diff_data(n_slots);
  for (long i = 0; i < helib::lsize(minuend_data); ++i) {
    const std::complex<double> lhs{i / 10.0, -i * i / 3.0};
    const std::complex<double> rhs{-i / 20.0, i * i * i * 42.6};
    minuend_data[i] = lhs;
    subtrahend_data[i] = rhs;
    expected_diff_data[i] = lhs - rhs;
  }
  helib::Ptxt<helib::CKKS> minuend(context, minuend_data);
  helib::Ptxt<helib::CKKS> subtrahend(context, subtrahend_data);

  helib::Ptxt<helib::CKKS> diff;
  diff = minuend - subtrahend;

  COMPARE_CXDOUBLE_VECS(expected_diff_data, diff);
}
1634
TEST_P(TestPtxtCKKS, timesOperatorWithOtherPtxtWorks)
{
  // Binary operator* on two ptxts, assigned into a default-constructed ptxt,
  // must equal the slot-wise std::complex product.
  const auto n_slots = context.ea->size();
  std::vector<std::complex<double>> multiplier_data(n_slots);
  std::vector<std::complex<double>> multiplicand_data(n_slots);
  std::vector<std::complex<double>> expected_product_data(n_slots);
  for (long i = 0; i < helib::lsize(multiplier_data); ++i) {
    const std::complex<double> lhs{i / 10.0, -i * i / 3.0};
    const std::complex<double> rhs{-i / 20.0, i * i * i * 42.6};
    multiplier_data[i] = lhs;
    multiplicand_data[i] = rhs;
    expected_product_data[i] = lhs * rhs;
  }
  helib::Ptxt<helib::CKKS> multiplier(context, multiplier_data);
  helib::Ptxt<helib::CKKS> multiplicand(context, multiplicand_data);

  helib::Ptxt<helib::CKKS> product;
  product = multiplier * multiplicand;

  COMPARE_CXDOUBLE_VECS(expected_product_data, product);
}
1653
1654 class TestPtxtBGV : public ::testing::TestWithParam<BGVParameters>
1655 {
1656 protected:
TestPtxtBGV()1657 TestPtxtBGV() :
1658 m(GetParam().m),
1659 p(GetParam().p),
1660 r(GetParam().r),
1661 ppowr(power(p, r)),
1662 context(m, p, r)
1663 {}
1664
power(long base,unsigned long exponent)1665 static long power(long base, unsigned long exponent)
1666 {
1667 long result = base;
1668 while (--exponent)
1669 result *= base;
1670 return result;
1671 }
1672
1673 const unsigned long m;
1674 const unsigned long p;
1675 const unsigned long r;
1676 const unsigned long ppowr;
1677
1678 helib::Context context;
1679 };
1680
TEST_P(TestPtxtBGV, canBeConstructedWithBGVContext)
{
  // Constructing a BGV ptxt from the fixture's context must succeed.
  const helib::Ptxt<helib::BGV> ptxt(context);
}
1685
TEST_P(TestPtxtBGV, canBeDefaultConstructed)
{
  // A BGV ptxt can be created without supplying a context.
  helib::Ptxt<helib::BGV> ptxt;
}
1687
TEST_P(TestPtxtBGV, canBeCopyConstructed)
{
  // Copy-construct from a context-built ptxt and check that the copy
  // compares equal to its source.
  helib::Ptxt<helib::BGV> ptxt(context);
  helib::Ptxt<helib::BGV> ptxt2(ptxt);
  EXPECT_EQ(ptxt2, ptxt);
}
1693
TEST_P(TestPtxtBGV, canBeAssignedFromOtherPtxt)
{
  // `Ptxt ptxt2 = ptxt;` is copy-initialisation and calls the copy
  // constructor. Default-construct first so operator= is actually exercised.
  helib::Ptxt<helib::BGV> ptxt(context);
  helib::Ptxt<helib::BGV> ptxt2;
  ptxt2 = ptxt;
  EXPECT_EQ(ptxt2, ptxt);
}
1699
TEST_P(TestPtxtBGV, reportsWhetherItIsValid)
{
  // isValid() is false for a default-constructed ptxt and true for one
  // built from a context.
  helib::Ptxt<helib::BGV> with_context(context);
  helib::Ptxt<helib::BGV> without_context;
  EXPECT_FALSE(without_context.isValid());
  EXPECT_TRUE(with_context.isValid());
}
1707
TEST_P(TestPtxtBGV, preservesLongDataPassedIntoConstructor)
{
  // Constructing from a vector of longs (0, 1, 2, ...) must give one slot
  // per element, each comparing equal to its input.
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, data);
  // ASSERT so a size mismatch ends the test before ptxt is indexed below.
  ASSERT_EQ(ptxt.size(), data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }
}
1718
TEST_P(TestPtxtBGV, preservesCoefficientVectorDataPassedIntoConstructor)
{
  // Constructing from per-slot coefficient vectors (here the constant {1})
  // must give one slot per element, each comparing equal to its input.
  std::vector<std::vector<long>> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = {1};
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);
  // ASSERT so a size mismatch ends the test before ptxt is indexed below.
  ASSERT_EQ(ptxt.size(), data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }
}
1731
TEST_P(TestPtxtBGV, preservesZzxDataPassedIntoConstructor)
{
  // Constructing from NTL::ZZX constants (0, 1, 2, ...) must give one slot
  // per element, each comparing equal to its input.
  std::vector<NTL::ZZX> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, data);
  // ASSERT so a size mismatch ends the test before ptxt is indexed below.
  ASSERT_EQ(ptxt.size(), data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }
}
1742
TEST_P(TestPtxtBGV, writesDataCorrectlyToOstream)
{
  // operator<< must print the slots as "[s0, s1, ...]", formatting each slot
  // with PolyMod's own stream operator.
  const long p2r = context.slotRing->p2r;
  const long d = context.zMStar.getOrdP();
  helib::PolyMod poly(context.slotRing);
  std::vector<helib::PolyMod> data(context.ea->size(), poly);

  std::stringstream ss;
  ss << "[";
  for (long i = 0; i < helib::lsize(data); ++i) {
    NTL::ZZX input;
    NTL::SetCoeff(input, 0, i % p2r);
    // A linear term is only added when d != 1 (presumably because a slot
    // with d == 1 holds just a constant -- confirm).
    if (d != 1)
      NTL::SetCoeff(input, 1, (i + 2) % p2r);
    data[i] = input;
    // Serialisation of data[i] (i.e. PolyMod) is tested in `TestPolyMod.cpp`
    if (i > 0)
      ss << ", ";
    ss << data[i];
  }
  ss << "]";
  const std::string expected = ss.str();

  helib::Ptxt<helib::BGV> ptxt(context, data);
  std::ostringstream os;
  os << ptxt;

  EXPECT_EQ(os.str(), expected);
}
1769
// operator>> must parse a "[slot, slot, ...]" list of PolyMod values back
// into the slots of the Ptxt.
TEST_P(TestPtxtBGV, readsDataCorrectlyFromIstream)
{
  std::vector<helib::PolyMod> slots(context.ea->size(),
                                    helib::PolyMod(context.slotRing));
  for (long idx = 0; idx < helib::lsize(slots); ++idx)
    slots[idx] = {idx, idx + 2};

  // Build the textual form by hand, separating slots with ", ".
  std::stringstream input;
  input << "[";
  for (std::size_t idx = 0; idx < slots.size(); ++idx) {
    if (idx != 0)
      input << ", ";
    input << slots[idx];
  }
  input << "]";

  helib::Ptxt<helib::BGV> ptxt(context);
  input >> ptxt;

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx) {
    EXPECT_EQ(ptxt[idx], slots[idx]);
  }
}
1794
// Writing a Ptxt with operator<< and reading it back with operator>> must
// round-trip its slot data.
TEST_P(TestPtxtBGV, deserializeIsInverseOfSerialize)
{
  helib::PolyMod poly(context.slotRing);
  std::vector<helib::PolyMod> data(context.ea->size(), poly);
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = {i, i + 2};
  }
  // Build the plaintext from `data` so that the round trip covers non-zero
  // slots. An all-zero plaintext would pass trivially.
  helib::Ptxt<helib::BGV> ptxt(context, data);

  std::stringstream str;
  str << ptxt;

  helib::Ptxt<helib::BGV> deserialized(context);
  str >> deserialized;

  EXPECT_EQ(ptxt, deserialized);
}
1812
// helib::serialize must write each constant slot as "[c]" with c reduced
// mod p^r, join slots with ", ", and wrap the whole list in "[...]".
TEST_P(TestPtxtBGV, serializeFunctionSerializesCorrectly)
{
  std::vector<helib::PolyMod> slots(context.ea->size(),
                                    helib::PolyMod(context.slotRing));
  std::stringstream expected;
  expected << "[";
  for (long idx = 0; idx < helib::lsize(slots); ++idx) {
    const long value = 2 * idx;
    slots[idx] = value;
    if (idx != 0)
      expected << ", ";
    expected << "[" << helib::mcMod(value, ppowr) << "]";
  }
  expected << "]";

  helib::Ptxt<helib::BGV> ptxt(context, slots);
  std::stringstream actual;
  helib::serialize(actual, ptxt);

  EXPECT_EQ(actual.str(), expected.str());
}
1833
// helib::deserialize must parse "[[c0,c1,...], [c0,c1,...], ...]" into the
// matching slot polynomials. Every slot here is sum_j j^2 X^j over all
// d = ordP coefficients.
TEST_P(TestPtxtBGV, deserializeFunctionDeserializesCorrectly)
{
  const long d = context.zMStar.getOrdP();
  std::vector<helib::PolyMod> slots(context.ea->size(),
                                    helib::PolyMod(context.slotRing));
  std::stringstream input;
  input << "[";
  for (long idx = 0; idx < helib::lsize(slots); ++idx) {
    if (idx != 0)
      input << ", ";
    NTL::ZZX poly;
    input << "[";
    for (long j = 0; j < d; ++j) {
      if (j != 0)
        input << ",";
      NTL::SetCoeff(poly, j, j * j);
      input << j * j;
    }
    input << "]";
    slots[idx] = poly;
  }
  input << "]";

  helib::Ptxt<helib::BGV> expected_ptxt(context, slots);
  helib::Ptxt<helib::BGV> parsed_ptxt(context);
  helib::deserialize(input, parsed_ptxt);

  EXPECT_EQ(expected_ptxt, parsed_ptxt);
}
1862
// helib::deserialize must throw helib::IOError when the stream holds more
// slot entries than the context provides.
TEST_P(TestPtxtBGV, deserializeFunctionThrowsIfMoreElementsThanSlots)
{
  // One entry more than there are slots. Entry i is the coefficient list
  // "[i, i*i]". Only the text matters, so no PolyMod vector is needed.
  const long num_entries = context.ea->size() + 1;
  std::stringstream ptxt_string_stream;
  ptxt_string_stream << "[";
  for (long i = 0; i < num_entries; ++i) {
    ptxt_string_stream << "[" << i << ", " << i * i << "]";
    if (i < num_entries - 1)
      ptxt_string_stream << ", ";
  }
  ptxt_string_stream << "]";

  helib::Ptxt<helib::BGV> deserialized_ptxt(context);

  EXPECT_THROW(helib::deserialize(ptxt_string_stream, deserialized_ptxt),
               helib::IOError);
}
1882
// operator>> must throw helib::IOError when the stream holds more slot
// entries than the context provides.
TEST_P(TestPtxtBGV, rightShiftOperatorThrowsIfMoreElementsThanSlots)
{
  // One entry more than there are slots. Entry i is the coefficient list
  // "[i, i*i]". Only the text matters, so no PolyMod vector is needed.
  const long num_entries = context.ea->size() + 1;
  std::stringstream ptxt_string_stream;
  ptxt_string_stream << "[";
  for (long i = 0; i < num_entries; ++i) {
    ptxt_string_stream << "[" << i << ", " << i * i << "]";
    if (i < num_entries - 1)
      ptxt_string_stream << ", ";
  }
  ptxt_string_stream << "]";

  helib::Ptxt<helib::BGV> deserialized_ptxt(context);
  EXPECT_THROW(ptxt_string_stream >> deserialized_ptxt, helib::IOError);
}
1900
// Several Ptxts written back-to-back (one per line) must be readable in
// sequence with repeated operator>> calls.
TEST_P(TestPtxtBGV, readsManyPtxtsFromStream)
{
  const helib::PolyMod zero_poly(context.slotRing);
  std::vector<helib::PolyMod> data1(context.ea->size(), zero_poly);
  std::vector<helib::PolyMod> data2(data1);
  std::vector<helib::PolyMod> data3(data1);
  // Slot i of ptxtK holds K*i + K*(i+2) X.
  for (long i = 0; i < helib::lsize(data1); ++i) {
    data1[i] = {i, i + 2};
    data2[i] = {2 * i, 2 * (i + 2)};
    data3[i] = {3 * i, 3 * (i + 2)};
  }
  helib::Ptxt<helib::BGV> ptxt1(context, data1);
  helib::Ptxt<helib::BGV> ptxt2(context, data2);
  helib::Ptxt<helib::BGV> ptxt3(context, data3);

  std::stringstream ss;
  ss << ptxt1 << std::endl << ptxt2 << std::endl << ptxt3 << std::endl;

  helib::Ptxt<helib::BGV> deserialized1(context);
  helib::Ptxt<helib::BGV> deserialized2(context);
  helib::Ptxt<helib::BGV> deserialized3(context);
  ss >> deserialized1 >> deserialized2 >> deserialized3;

  EXPECT_EQ(ptxt1, deserialized1);
  EXPECT_EQ(ptxt2, deserialized2);
  EXPECT_EQ(ptxt3, deserialized3);
}
1932
// Constructing from a vector of PolyMods places each one in its slot.
TEST_P(TestPtxtBGV, preservesPolyModDataPassedIntoConstructor)
{
  std::vector<helib::PolyMod> slots(context.ea->size(),
                                    helib::PolyMod(context.slotRing));
  helib::Ptxt<helib::BGV> ptxt(context, slots);

  EXPECT_EQ(ptxt.size(), slots.size());
  for (std::size_t k = 0; k < slots.size(); ++k) {
    EXPECT_EQ(ptxt[k], slots[k]);
  }
}
1943
// Constructing a Ptxt from PolyMods whose ring differs from the context's
// slot ring (in p^r or in G) must throw helib::RuntimeError. This holds both
// for a vector containing a single bad element and for a lone bad PolyMod.
TEST_P(TestPtxtBGV, throwsIfp2rAndGDoNotMatchThoseFromContext)
{
  const NTL::ZZX G = context.slotRing->G;
  const long p = context.slotRing->p;
  const long r = context.slotRing->r;
  // Non-matching p^r
  const auto badPolyModRing1 =
      std::make_shared<helib::PolyModRing>(p + 1, r, G);
  helib::PolyMod badPolyMod1(badPolyModRing1);
  // Non-matching G
  const auto badPolyModRing2 =
      std::make_shared<helib::PolyModRing>(p, r, G + 1);
  helib::PolyMod badPolyMod2(badPolyModRing2);
  // All good
  const auto goodPolyModRing = std::make_shared<helib::PolyModRing>(p, r, G);
  helib::PolyMod goodPolyMod(goodPolyModRing);

  std::vector<helib::PolyMod> data(context.ea->size(), goodPolyMod);

  // Make all of them good except 1, make sure it still notices
  data.back() = badPolyMod1;
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, data),
               helib::RuntimeError);
  data.back() = badPolyMod2;
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, data),
               helib::RuntimeError);

  // Make sure it complains if it's just given 1 bad PolyMod too
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, badPolyMod1),
               helib::RuntimeError);
  EXPECT_THROW(helib::Ptxt<helib::BGV> ptxt(context, badPolyMod2),
               helib::RuntimeError);
}
1978
// lsize() must report the number of slots as a signed value.
TEST_P(TestPtxtBGV, lsizeReportsCorrectSize)
{
  std::vector<long> data(context.ea->size());
  helib::Ptxt<helib::BGV> ptxt(context, data);
  // lsize() is presumably the signed (long) counterpart of size(). Compare it
  // with helib::lsize rather than the unsigned data.size() to avoid a
  // signed/unsigned comparison inside EXPECT_EQ.
  EXPECT_EQ(ptxt.lsize(), helib::lsize(data));
}
1985
// size() must report the number of slots.
TEST_P(TestPtxtBGV, sizeReportsCorrectSize)
{
  std::vector<long> zeros(context.ea->size(), 0l);
  helib::Ptxt<helib::BGV> ptxt(context, zeros);
  EXPECT_EQ(ptxt.size(), zeros.size());
}
1992
// Passing fewer values than slots must fill the leading slots (mod p^r) and
// zero-pad the remaining ones.
TEST_P(TestPtxtBGV, padsWithZerosWhenPassingInSmallDataVector)
{
  // Deliberately one value short of the slot count.
  std::vector<long> short_data(context.ea->size() - 1);
  std::iota(short_data.begin(), short_data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, short_data);

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    if (k < short_data.size()) {
      EXPECT_EQ(ptxt[k], helib::mcMod(short_data[k], ppowr));
    } else {
      EXPECT_EQ(ptxt[k], 0l);
    }
  }
}
2005
// A context-constructed Ptxt must have exactly as many slots as the
// context's EncryptedArray.
TEST_P(TestPtxtBGV, hasSameNumberOfSlotsAsContext)
{
  helib::Ptxt<helib::BGV> ptxt(context);
  // ea->size() is presumably signed (long). Compare it with the signed
  // lsize() rather than the unsigned size() to avoid a signed/unsigned
  // comparison inside EXPECT_EQ.
  EXPECT_EQ(context.ea->size(), ptxt.lsize());
}
2011
// random() must refill the slots: five independently randomised copies are
// expected not to be all equal.
TEST_P(TestPtxtBGV, randomSetsDataRandomly)
{
  helib::Ptxt<helib::BGV> seed_ptxt(context);
  seed_ptxt.random();
  std::vector<helib::Ptxt<helib::BGV>> ptxts(5, seed_ptxt);
  for (std::size_t k = 0; k < ptxts.size(); ++k)
    ptxts[k].random();

  // All equal iff every copy matches the first one.
  bool all_equal = true;
  for (std::size_t k = 1; all_equal && k < ptxts.size(); ++k) {
    if (ptxts[k] != ptxts[0])
      all_equal = false;
  }
  EXPECT_FALSE(all_equal) << "5 random ptxts are all equal - likely that"
                             " random() is not actually randomising!";
}
2029
// runningSums() must replace slot i with the sum of slots 0..i.
TEST_P(TestPtxtBGV, runningSumsWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 1);
  // Prefix sums of 1, 2, ..., n (the triangular numbers).
  std::vector<long> expected_result(data.size());
  std::partial_sum(data.begin(), data.end(), expected_result.begin());

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.runningSums();

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2045
// totalSums() must replace every slot with the sum of all slots.
TEST_P(TestPtxtBGV, totalSumsWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 1);
  const long total = std::accumulate(data.begin(), data.end(), 0l);
  std::vector<long> expected_result(data.size(), total);

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.totalSums();

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2061
// incrementalProduct() must replace slot i with the product of slots 0..i,
// reduced mod p^r.
TEST_P(TestPtxtBGV, incrementalProductWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 1);
  const long p2r = context.slotRing->p2r;
  // Prefix products, reduced at each step so they stay small.
  std::vector<long> expected_result(data.size());
  std::partial_sum(data.begin(),
                   data.end(),
                   expected_result.begin(),
                   [p2r](long acc, long x) { return (acc * x) % p2r; });

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.incrementalProduct();

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2078
// totalProduct() must replace every slot with the product of all slots,
// reduced mod p^r.
TEST_P(TestPtxtBGV, totalProductWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 1);
  const long p2r = context.slotRing->p2r;
  const long product =
      std::accumulate(data.begin(),
                      data.end(),
                      1l,
                      [p2r](long acc, long x) { return (acc * x) % p2r; });
  std::vector<long> expected_result(data.size(), product);

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.totalProduct();

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2097
// innerProduct over two length-4 vectors of Ptxts must equal the slot-wise
// sum of the pairwise products.
TEST_P(TestPtxtBGV, innerProductWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, data);
  std::vector<helib::Ptxt<helib::BGV>> lhs_vector(4, ptxt);
  ptxt += ptxt; // every slot is now doubled
  std::vector<helib::Ptxt<helib::BGV>> rhs_vector(4, ptxt);

  helib::Ptxt<helib::BGV> result(context);
  innerProduct(result, lhs_vector, rhs_vector);

  // Four identical terms x * (2x) per slot.
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const long x = data[k];
    expected_result[k] = 4 * (x * (2 * x));
  }

  for (std::size_t k = 0; k < result.size(); ++k) {
    EXPECT_EQ(result[k], expected_result[k]);
  }
}
2119
// mapTo01 must send multiples of p to 0 and every other value to 1, both as
// a member function and as a free function.
TEST_P(TestPtxtBGV, mapTo01MapsSlotsCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k)
    expected_result[k] = (k % p == 0) ? 0 : 1;

  helib::Ptxt<helib::BGV> via_member(context, data);
  helib::Ptxt<helib::BGV> via_free_fn(context, data);
  via_member.mapTo01();
  mapTo01(*(context.ea), via_free_fn);

  for (std::size_t k = 0; k < via_member.size(); ++k) {
    EXPECT_EQ(via_member[k], expected_result[k]);
    EXPECT_EQ(via_free_fn[k], expected_result[k]);
  }
}
2140
// automorph(2) on a fixed m = 45, p = 19 context (gens {11, 2}, ords {6, 2})
// must permute and transform the slots into the hard-coded expected values.
TEST(TestPtxtBGV, automorphWorksCorrectly)
{
  std::vector<long> gens = {11, 2};
  std::vector<long> ords = {6, 2};
  const helib::Context context(45, 19, 1, gens, ords);

  // Slot i holds i * (X + 1).
  std::vector<NTL::ZZX> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    NTL::SetX(data[i]);
    data[i] += 1;
    data[i] *= i;
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);

  // (constant, linear) coefficients of each slot after automorph(2).
  const std::vector<std::pair<long, long>> expected_coeffs = {{13, 10},
                                                              {0, 0},
                                                              {18, 8},
                                                              {2, 2},
                                                              {12, 18},
                                                              {4, 4},
                                                              {17, 16},
                                                              {6, 6},
                                                              {3, 14},
                                                              {8, 8},
                                                              {8, 12},
                                                              {10, 10}};
  helib::Ptxt<helib::BGV> expected_result(ptxt);
  for (std::size_t i = 0; i < expected_coeffs.size(); ++i)
    expected_result[i] = {expected_coeffs[i].first, expected_coeffs[i].second};

  ptxt.automorph(2);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], expected_result[i]);
  }
}
2173
// frobeniusAutomorph(k) for every k in [0, ordP] must leave constant slots
// unchanged.
TEST_P(TestPtxtBGV, frobeniusAutomorphWithConstantsWorksCorrectly)
{
  std::vector<long> constants(context.ea->size());
  std::iota(constants.begin(), constants.end(), 0);
  const helib::Ptxt<helib::BGV> original(context, constants);

  const long ord_p = context.zMStar.getOrdP();
  for (long i = 0; i <= ord_p; ++i) {
    helib::Ptxt<helib::BGV> ptxtUnderTest(original);
    ptxtUnderTest.frobeniusAutomorph(i);
    for (std::size_t j = 0; j < ptxtUnderTest.size(); ++j) {
      ASSERT_EQ(ptxtUnderTest[j], constants[j])
          << "i = " << i << " j = " << j << std::endl;
    }
  }
}
2189
// frobeniusAutomorph(1) on a fixed m = 45, p = 19 context must map the slots
// i * (X + 1) to the hard-coded expected polynomials.
TEST(TestPtxtBGV, frobeniusAutomorphWithPolynomialsWorksCorrectly)
{
  const helib::Context context(45, 19, 1);

  // Slot i holds i * (X + 1).
  std::vector<NTL::ZZX> data(context.ea->size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    NTL::SetX(data[i]);
    data[i] += 1;
    data[i] *= i;
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);

  // (constant, linear) coefficients of each slot after frobeniusAutomorph(1).
  const std::vector<std::pair<long, long>> expected_coeffs = {{0, 0},
                                                              {12, 18},
                                                              {5, 17},
                                                              {17, 16},
                                                              {10, 15},
                                                              {3, 14},
                                                              {15, 13},
                                                              {8, 12},
                                                              {1, 11},
                                                              {13, 10},
                                                              {6, 9},
                                                              {18, 8}};
  helib::Ptxt<helib::BGV> expected_result(ptxt);
  for (std::size_t i = 0; i < expected_coeffs.size(); ++i)
    expected_result[i] = {expected_coeffs[i].first, expected_coeffs[i].second};

  ptxt.frobeniusAutomorph(1);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], expected_result[i]);
  }
}
2220
// operator*= with another Ptxt must multiply slot-wise.
TEST_P(TestPtxtBGV, timesEqualsOtherPlaintextWorks)
{
  const std::size_t n_slots = context.ea->size();
  std::vector<long> product_data(n_slots, 3);
  std::vector<long> multiplier_data(n_slots);
  std::iota(multiplier_data.begin(), multiplier_data.end(), 0);

  std::vector<long> expected_result(n_slots);
  for (std::size_t k = 0; k < n_slots; ++k)
    expected_result[k] = product_data[k] * multiplier_data[k];

  helib::Ptxt<helib::BGV> product(context, product_data);
  helib::Ptxt<helib::BGV> multiplier(context, multiplier_data);
  product *= multiplier;

  for (std::size_t k = 0; k < product.size(); ++k) {
    EXPECT_EQ(product[k], expected_result[k]);
  }
}
2241
// operator-= with another Ptxt must subtract slot-wise.
TEST_P(TestPtxtBGV, minusEqualsOtherPlaintextWorks)
{
  const std::size_t n_slots = context.ea->size();
  std::vector<long> difference_data(n_slots, 1);
  std::vector<long> subtrahend_data(n_slots);
  std::iota(subtrahend_data.begin(), subtrahend_data.end(), 0);

  std::vector<long> expected_result(n_slots);
  for (std::size_t k = 0; k < n_slots; ++k)
    expected_result[k] = difference_data[k] - subtrahend_data[k];

  helib::Ptxt<helib::BGV> difference(context, difference_data);
  helib::Ptxt<helib::BGV> subtrahend(context, subtrahend_data);
  difference -= subtrahend;

  for (std::size_t k = 0; k < difference.size(); ++k) {
    EXPECT_EQ(difference[k], expected_result[k]);
  }
}
2262
// operator-= with a scalar must subtract it from every slot.
TEST_P(TestPtxtBGV, minusEqualsScalarWorks)
{
  const long scalar = 3;
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = static_cast<long>(k);
    expected_result[k] = data[k] - scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt -= scalar;

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2281
// operator+= with another Ptxt must add slot-wise.
TEST_P(TestPtxtBGV, plusEqualsOtherPlaintextWorks)
{
  const long n_slots = context.ea->size();
  std::vector<long> augend_data(n_slots);
  std::vector<long> addend_data(n_slots);
  std::vector<long> expected_result(n_slots);
  for (long k = 0; k < n_slots; ++k) {
    augend_data[k] = k;
    addend_data[k] = helib::mcMod(2 * k + 1, p);
    expected_result[k] = augend_data[k] + addend_data[k];
  }

  helib::Ptxt<helib::BGV> sum(context, augend_data);
  helib::Ptxt<helib::BGV> addend(context, addend_data);
  sum += addend;

  for (std::size_t k = 0; k < sum.size(); ++k) {
    EXPECT_EQ(sum[k], expected_result[k]);
  }
}
2301
// operator+= with a scalar must add it to every slot.
TEST_P(TestPtxtBGV, plusEqualsScalarWorks)
{
  const long scalar = 3;
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = static_cast<long>(k);
    expected_result[k] = data[k] + scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt += scalar;

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2320
// operator*= with a scalar must multiply every slot by it.
TEST_P(TestPtxtBGV, timesEqualsScalarWorks)
{
  const long scalar = 3;
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = static_cast<long>(k);
    expected_result[k] = data[k] * scalar;
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt *= scalar;

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2339
// operator== must hold for two Ptxts built from the same data.
TEST_P(TestPtxtBGV, equalityWithOtherPlaintextWorks)
{
  std::vector<long> values(context.ea->size());
  long next = 0;
  for (auto& v : values)
    v = next++;

  helib::Ptxt<helib::BGV> lhs(context, values);
  helib::Ptxt<helib::BGV> rhs(context, values);
  EXPECT_TRUE(lhs == rhs);
}
2349
// operator!= must hold for Ptxts whose slots differ (here by one everywhere).
TEST_P(TestPtxtBGV, notEqualsOperatorWithOtherPlaintextWorks)
{
  std::vector<long> data1(context.ea->size());
  std::vector<long> data2(data1.size());
  for (std::size_t k = 0; k < data1.size(); ++k) {
    data1[k] = static_cast<long>(k);
    data2[k] = static_cast<long>(k) + 1;
  }

  helib::Ptxt<helib::BGV> ptxt1(context, data1);
  helib::Ptxt<helib::BGV> ptxt2(context, data2);
  EXPECT_TRUE(ptxt1 != ptxt2);
}
2361
// negate() must replace every slot with its additive inverse.
TEST_P(TestPtxtBGV, negateNegatesCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    data[k] = static_cast<long>(k);
    expected_result[k] = -data[k];
  }

  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.negate();

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2378
// addConstant must add the given polynomial (or scalar) to every slot and
// return *this so that calls can be chained.
TEST_P(TestPtxtBGV, addConstantWorksCorrectly)
{
  // The constant polynomial 2 + X.
  NTL::ZZX constant_poly;
  NTL::SetCoeff(constant_poly, 0, 2);
  NTL::SetCoeff(constant_poly, 1, 1);

  helib::PolyMod base(constant_poly, context.slotRing);
  std::vector<helib::PolyMod> slots(context.ea->size());
  for (std::size_t k = 0; k < slots.size(); ++k)
    slots[k] = base + k;

  std::vector<helib::PolyMod> expected_result(slots);
  for (std::size_t k = 0; k < expected_result.size(); ++k) {
    expected_result[k] += constant_poly;
    expected_result[k] += 3L;
  }

  helib::Ptxt<helib::BGV> ptxt(context, slots);
  ptxt.addConstant(constant_poly).addConstant(3l);

  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2401
// multiplyBy(other) must multiply slot-wise, like operator*=.
TEST_P(TestPtxtBGV, multiplyByMultipliesCorrectly)
{
  const std::size_t n_slots = context.ea->size();
  std::vector<long> product_data(n_slots, 3);
  std::vector<long> multiplier_data(n_slots);
  std::iota(multiplier_data.begin(), multiplier_data.end(), 0);

  std::vector<long> expected_result(n_slots);
  for (std::size_t k = 0; k < n_slots; ++k)
    expected_result[k] = product_data[k] * multiplier_data[k];

  helib::Ptxt<helib::BGV> product(context, product_data);
  helib::Ptxt<helib::BGV> multiplier(context, multiplier_data);
  product.multiplyBy(multiplier);

  for (std::size_t k = 0; k < product.size(); ++k) {
    EXPECT_EQ(product[k], expected_result[k]);
  }
}
2422
// multiplyBy2(a, b) must multiply every slot by the matching slots of a and b.
TEST_P(TestPtxtBGV, multiplyBy2MultipliesCorrectly)
{
  const std::size_t n_slots = context.ea->size();
  std::vector<long> product_data(n_slots, 3);
  std::vector<long> multiplier_data1(n_slots);
  std::iota(multiplier_data1.begin(), multiplier_data1.end(), 0);
  std::vector<long> multiplier_data2(multiplier_data1);

  std::vector<long> expected_result(n_slots);
  for (std::size_t k = 0; k < n_slots; ++k)
    expected_result[k] =
        product_data[k] * multiplier_data1[k] * multiplier_data2[k];

  helib::Ptxt<helib::BGV> product(context, product_data);
  helib::Ptxt<helib::BGV> multiplier1(context, multiplier_data1);
  helib::Ptxt<helib::BGV> multiplier2(context, multiplier_data2);
  product.multiplyBy2(multiplier1, multiplier2);

  for (std::size_t k = 0; k < product.size(); ++k) {
    EXPECT_EQ(product[k], expected_result[k]);
  }
}
2447
// square() must square every slot.
TEST_P(TestPtxtBGV, squareSquaresCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const long x = static_cast<long>(k);
    data[k] = x;
    expected_result[k] = x * x;
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.square();
  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2461
// cube() must cube every slot.
TEST_P(TestPtxtBGV, cubeCubesCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::vector<long> expected_result(data.size());
  for (std::size_t k = 0; k < data.size(); ++k) {
    const long x = static_cast<long>(k);
    data[k] = x;
    expected_result[k] = x * x * x;
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);
  ptxt.cube();
  for (std::size_t k = 0; k < ptxt.size(); ++k) {
    EXPECT_EQ(ptxt[k], expected_result[k]);
  }
}
2475
// power(e) must raise every slot to the e-th power mod p^r for several
// exponents (including a large one), and power(0) must throw
// helib::InvalidArgument.
TEST_P(TestPtxtBGV, powerCorrectlyRaisesToPowers)
{
  std::vector<long> data(context.ea->size());
  // Centre the data around zero to exercise negative bases too.
  std::iota(data.begin(),
            data.end(),
            -(static_cast<long>(context.ea->size()) / 2));
  std::vector<long> exponents{1, 3, 4, 5, 300};

  // Repeated multiplication, reducing after every step so that intermediate
  // values stay small.
  const auto naive_powermod =
      [](long base, unsigned long exponent, unsigned long mod) {
        if (exponent == 0)
          return 1l;

        long result = base;
        while (--exponent)
          result = helib::mcMod(result * base, mod);
        return result;
      };

  for (const auto& exponent : exponents) {
    std::vector<long> expected_result(data);
    for (auto& num : expected_result)
      num = naive_powermod(num, exponent, ppowr);
    helib::Ptxt<helib::BGV> ptxt(context, data);
    ptxt.power(exponent);
    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_EQ(ptxt[i], expected_result[i]);
    }
  }

  // Make sure raising to 0 throws. This is a BGV test, so use a BGV
  // plaintext rather than a CKKS one built from a BGV context.
  helib::Ptxt<helib::BGV> ptxt(context, data);
  EXPECT_THROW(ptxt.power(0l), helib::InvalidArgument);
}
2510
// shift(3) must move every slot three places towards higher indices and fill
// the vacated leading slots with zero. Nothing wraps around.
TEST_P(TestPtxtBGV, shiftShiftsRightCorrectly)
{
  const long amount = 3;
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);

  // Slot i receives the old value of slot i - amount, for every i >= amount.
  // The boundary is inclusive: the old check `i > 3` only passed because
  // data[0] happens to be 0.
  std::vector<long> right_shifted_data(data.size(), 0);
  for (long i = amount; i < helib::lsize(data); ++i)
    right_shifted_data[i] = data[i - amount];

  helib::Ptxt<helib::BGV> ptxt(context, data);

  ptxt.shift(amount);
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], right_shifted_data[i]);
  }
}
2531
// shift(-3) must move every slot three places towards lower indices and fill
// the vacated trailing slots with zero.
TEST_P(TestPtxtBGV, shiftShiftsLeftCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::vector<long> left_shifted_data(data.size(), 0);
  for (std::size_t i = 0; i < data.size(); ++i) {
    data[i] = static_cast<long>(i);
    // `i + 3 < size` avoids the unsigned underflow of `size - 3` when the
    // size is below 3.
    if (i + 3 < data.size())
      left_shifted_data[i] = static_cast<long>(i + 3);
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);

  ptxt.shift(-3);
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], left_shifted_data[i]);
  }
}
2552
// shift1D(dim, 1) on the m = 45, p = 19 context must match reference models
// that hard-code its 12-slot layout. In dimension 0 a slot moves by two
// indices per step. In dimension 1 values move within (even, odd) index
// pairs. Vacated slots become zero.
TEST(TestPtxtBGV, shift1DShiftsRightCorrectly)
{
  long amount = 1;
  const helib::Context context(45, 19, 1);

  // Reference for dimension 0: slot src moves to src + 2 * by. Values pushed
  // outside [0, 12) are dropped.
  const auto shift_first_dim = [](long by, std::vector<long>& values) {
    std::vector<long> shifted(values.size(), 0l);
    for (long src = 0; src < helib::lsize(values); ++src) {
      const long dst = src + 2 * by;
      if (dst >= 0 && dst < 12)
        shifted[dst] = values[src];
    }
    values = std::move(shifted);
  };
  // Reference for dimension 1, modelled as pairs of adjacent slots.
  const auto shift_second_dim = [](long by, std::vector<long>& values) {
    const long len = helib::lsize(values);
    std::vector<long> shifted(values.size(), 0l);
    if (by == 0) {
      shifted = values;
    } else if (by == 1) {
      for (long i = 0; i + 1 < len; ++i)
        shifted[i + 1] = (i & 1) ? 0 : values[i];
    } else if (by == -1) {
      for (long i = 1; i < len; ++i)
        shifted[i - 1] = (i & 1) ? values[i] : 0;
    }
    values = std::move(shifted);
  };

  std::vector<long> data(context.ea->size());
  for (long i = 0; i < helib::lsize(data); ++i)
    data[i] = i;

  // Shift a fresh Ptxt along `dim` and compare it with the reference result.
  const auto check_dim = [&](long dim,
                             const std::vector<long>& right_shifted_data) {
    helib::Ptxt<helib::BGV> ptxt(context, data);
    ptxt.shift1D(dim, amount);
    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_EQ(ptxt[i], right_shifted_data[i]);
    }
  };

  std::vector<long> expected_dim0(data);
  shift_first_dim(amount, expected_dim0);
  check_dim(0, expected_dim0);

  std::vector<long> expected_dim1(data);
  shift_second_dim(amount, expected_dim1);
  check_dim(1, expected_dim1);
}
2613
// shift1D(dim, -1) on the m = 45, p = 19 context must match reference models
// that hard-code its 12-slot layout. In dimension 0 a slot moves by two
// indices per step. In dimension 1 values move within (even, odd) index
// pairs. Vacated slots become zero.
TEST(TestPtxtBGV, shift1DShiftsLeftCorrectly)
{
  long amount = -1;
  const helib::Context context(45, 19, 1);
  std::vector<long> data(context.ea->size());
  // Holds the expected result of a *left* shift. The old name
  // `right_shifted_data` was misleading.
  std::vector<long> left_shifted_data(context.ea->size());
  // Reference for dimension 0: slot i moves to i + 2 * by. Values pushed
  // outside [0, 12) are dropped.
  const auto shift_first_dim = [](long by, std::vector<long>& values) {
    std::vector<long> new_data(values.size(), 0l);
    for (long i = 0; i < helib::lsize(values); ++i)
      if (i + 2 * by < 12 && i + 2 * by >= 0)
        new_data[i + 2 * by] = values[i];
    values = std::move(new_data);
  };
  // Reference for dimension 1, modelled as pairs of adjacent slots.
  const auto shift_second_dim = [](long by, std::vector<long>& values) {
    std::vector<long> new_data(values.size(), 0l);
    for (long i = 0; i < helib::lsize(values); ++i)
      switch (by) {
      case 1l:
        if (i < helib::lsize(values) - 1)
          new_data[i + 1] = i & 1 ? 0 : values[i];
        break;
      case 0l:
        new_data[i] = values[i];
        break;
      case -1l:
        if (i > 0)
          new_data[i - 1] = i & 1 ? values[i] : 0;
        break;
      default:
        new_data[i] = 0;
        break;
      }
    values = std::move(new_data);
  };
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = i;
  }
  {
    helib::Ptxt<helib::BGV> ptxt(context, data);

    left_shifted_data = data;
    shift_first_dim(amount, left_shifted_data);
    ptxt.shift1D(0, amount);

    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_EQ(ptxt[i], left_shifted_data[i]);
    }
  }
  {
    helib::Ptxt<helib::BGV> ptxt(context, data);

    left_shifted_data = data;
    shift_second_dim(amount, left_shifted_data);
    ptxt.shift1D(1, amount);

    for (std::size_t i = 0; i < ptxt.size(); ++i) {
      EXPECT_EQ(ptxt[i], left_shifted_data[i]);
    }
  }
}
2674
// rotate1D on the m = 45, p = 19 context must match reference models that
// hard-code its 12-slot layout (a stride of 2 in dimension 0, adjacent pairs
// in dimension 1). Rotating back and forth must restore the original data.
TEST(TestPtxtBGV, rotate1DRotatesCorrectly)
{
  long amount = 1;
  const helib::Context context(45, 19, 1);
  std::vector<long> data(context.ea->size());
  // Reference for dimension 0: slot i moves to (i + 2 * by) mod 12.
  const auto rotate_first_dim = [](long by, std::vector<long>& values) {
    by = helib::mcMod(by, 12);
    std::vector<long> new_data(values);
    for (long i = 0; i < helib::lsize(values); ++i)
      new_data[(i + 2 * by) % 12] = values[i];
    values = std::move(new_data);
  };
  // Reference for dimension 1: an odd rotation swaps each (even, odd) pair.
  // The model treats this dimension as pairs, so an even rotation is the
  // identity. Check that once, outside the loop.
  const auto rotate_second_dim = [](long by, std::vector<long>& values) {
    if (by % 2 == 0)
      return;
    std::vector<long> new_data(values);
    for (long i = 0; i < helib::lsize(values); ++i)
      new_data[i + (i & 1 ? -1 : 1)] = values[i];
    values = std::move(new_data);
  };
  for (long i = 0; i < helib::lsize(data); ++i) {
    data[i] = i;
  }
  helib::Ptxt<helib::BGV> ptxt(context, data);

  // Rotate in first dimension (Good Dimension)
  rotate_first_dim(-amount, data);
  ptxt.rotate1D(0, -amount);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  rotate_first_dim(amount, data);
  ptxt.rotate1D(0, amount);
  // Rotating back and forth gives the original data back
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  // Rotate in second dimension (Bad Dimension)
  rotate_second_dim(-amount, data);
  ptxt.rotate1D(1, -amount);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }

  rotate_second_dim(amount, data);
  ptxt.rotate1D(1, amount);
  // Rotating back and forth gives the original data back
  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], data[i]);
  }
}
2730
// Ptxt::rotate over all slots. Slot i initially holds (i - 3) mod n, so a
// rotation by -3 must leave slot i holding i. Rotating by +3 afterwards must
// give back the initial data.
TEST_P(TestPtxtBGV, rotateRotatesCorrectly)
{
  const long n_slots = context.ea->size();
  std::vector<long> original(n_slots);
  std::vector<long> expected_after_left_rotation(n_slots);
  for (long slot = 0; slot < n_slots; ++slot) {
    // Non-negative remainder, also correct when n_slots < 3.
    original[slot] = ((slot - 3) % n_slots + n_slots) % n_slots;
    expected_after_left_rotation[slot] = slot;
  }
  helib::Ptxt<helib::BGV> ptxt(context, original);

  ptxt.rotate(-3);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot)
    EXPECT_EQ(ptxt[slot], expected_after_left_rotation[slot]);

  // Undoing the rotation must restore the original data.
  ptxt.rotate(3);
  for (std::size_t slot = 0; slot < ptxt.size(); ++slot)
    EXPECT_EQ(ptxt[slot], original[slot]);
}
2754
// helib::replicate applied to a Ptxt whose slots hold 0, 1, ..., n-1: after
// replicating the last slot, every slot is expected to hold that slot's value.
TEST_P(TestPtxtBGV, replicateReplicatesCorrectly)
{
  const long n_slots = context.ea->size();
  std::vector<long> values(n_slots);
  std::iota(values.begin(), values.end(), 0);
  const long last = n_slots - 1;

  helib::Ptxt<helib::BGV> ptxt(context, values);
  helib::replicate(*context.ea, ptxt, last);

  for (std::size_t i = 0; i < ptxt.size(); ++i)
    EXPECT_EQ(ptxt[i], values[last]);
}
2766
// Ptxt::replicateAll on data 0, 1, ..., n-1. The test expects one plaintext
// per input slot, where every slot of the i-th plaintext equals input slot i.
TEST_P(TestPtxtBGV, replicateAllWorksCorrectly)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, data);
  std::vector<helib::Ptxt<helib::BGV>> replicated_ptxts = ptxt.replicateAll();
  // Fatal check first: without it, a result with fewer entries than slots
  // would be indexed out of bounds below instead of failing cleanly.
  ASSERT_EQ(replicated_ptxts.size(), data.size());
  for (std::size_t i = 0; i < data.size(); ++i) {
    for (const auto& slot : replicated_ptxts[i].getSlotRepr()) {
      EXPECT_EQ(slot, data[i]);
    }
  }
}
2779
// Ptxt::clear must leave every slot equal to zero, whatever data the
// plaintext held before.
TEST_P(TestPtxtBGV, clearZeroesAllSlots)
{
  std::vector<long> initial(context.ea->size());
  std::iota(initial.begin(), initial.end(), 0);
  helib::Ptxt<helib::BGV> ptxt(context, initial);

  ptxt.clear();

  for (std::size_t idx = 0; idx < ptxt.size(); ++idx)
    EXPECT_EQ(ptxt[idx], 0l);
}
2791
// A default-constructed Ptxt (presumably one with no valid context attached)
// must refuse every operation. Each call below is expected to throw
// helib::RuntimeError, except lsize(), which is expected to throw
// helib::LogicError. All calls act in sequence on the same objects, so a
// call that unexpectedly succeeds could change what the later ones see.
TEST_P(TestPtxtBGV, defaultConstructedPtxtThrowsWhenOperatedOn)
{
  helib::Ptxt<helib::BGV> ptxt1;
  helib::Ptxt<helib::BGV> ptxt2;
  helib::PolyMod poly;
  // Data access and assignment.
  EXPECT_THROW(ptxt1.getSlotRepr(), helib::RuntimeError);
  EXPECT_THROW(ptxt1.setData(poly), helib::RuntimeError);
  EXPECT_THROW(ptxt1[0], helib::RuntimeError);
  // Arithmetic with another (also default-constructed) Ptxt and with scalars.
  EXPECT_THROW(ptxt1 *= ptxt2, helib::RuntimeError);
  EXPECT_THROW(ptxt1 += ptxt2, helib::RuntimeError);
  EXPECT_THROW(ptxt1 -= ptxt2, helib::RuntimeError);
  EXPECT_THROW(ptxt1 += 1l, helib::RuntimeError);
  EXPECT_THROW(ptxt1 *= 3l, helib::RuntimeError);
  EXPECT_THROW(ptxt1.negate(), helib::RuntimeError);
  EXPECT_THROW(ptxt1.multiplyBy(ptxt2), helib::RuntimeError);
  EXPECT_THROW(ptxt1.multiplyBy2(ptxt1, ptxt2), helib::RuntimeError);
  EXPECT_THROW(ptxt1.square(), helib::RuntimeError);
  EXPECT_THROW(ptxt1.cube(), helib::RuntimeError);
  EXPECT_THROW(ptxt1.power(4l), helib::RuntimeError);
  // Size queries and slot movement.
  EXPECT_THROW(ptxt1.size(), helib::RuntimeError);
  EXPECT_THROW(ptxt1.rotate(1), helib::RuntimeError);
  EXPECT_THROW(ptxt1.rotate1D(0, 1), helib::RuntimeError);
  EXPECT_THROW(ptxt1.shift(1), helib::RuntimeError);
  // NOTE(review): lsize() is expected to throw LogicError, unlike size()
  // above, which throws RuntimeError — confirm this asymmetry is intended.
  EXPECT_THROW(ptxt1.lsize(), helib::LogicError);
}
2817
// A default-constructed Ptxt must not be accepted as an operand by a valid
// Ptxt: every binary operation below, including multiplyBy2 with the invalid
// Ptxt in either argument position, is expected to throw helib::RuntimeError.
// (Despite the test name, it is the Ptxt that is default-constructed, not a
// Context.)
TEST_P(TestPtxtBGV, defaultConstructedContextCannotBeRightOperand)
{
  std::vector<long> data(context.ea->size());
  std::iota(data.begin(), data.end(), 0);
  helib::Ptxt<helib::BGV> valid_ptxt(context, data);
  helib::Ptxt<helib::BGV> invalid_ptxt;

  EXPECT_THROW(valid_ptxt *= invalid_ptxt, helib::RuntimeError);
  EXPECT_THROW(valid_ptxt += invalid_ptxt, helib::RuntimeError);
  EXPECT_THROW(valid_ptxt -= invalid_ptxt, helib::RuntimeError);
  EXPECT_THROW(valid_ptxt.multiplyBy(invalid_ptxt), helib::RuntimeError);
  // multiplyBy2 must reject the invalid operand in either position.
  EXPECT_THROW(valid_ptxt.multiplyBy2(invalid_ptxt, valid_ptxt),
               helib::RuntimeError);
  EXPECT_THROW(valid_ptxt.multiplyBy2(valid_ptxt, invalid_ptxt),
               helib::RuntimeError);
}
2834
// Ptxts built from two different contexts must not be combined: every binary
// operation between them is expected to throw helib::LogicError.
TEST_P(TestPtxtBGV, cannotOperateBetweenPtxtsWithDifferentContexts)
{
  // Same m and p as the fixture's context, but with r doubled.
  helib::Context other_context(m, p, 2 * r);
  std::vector<long> ones(context.ea->size(), 1);
  helib::Ptxt<helib::BGV> lhs(context, ones);
  helib::Ptxt<helib::BGV> rhs(other_context, ones);

  EXPECT_THROW(lhs *= rhs, helib::LogicError);
  EXPECT_THROW(lhs -= rhs, helib::LogicError);
  EXPECT_THROW(lhs += rhs, helib::LogicError);
  EXPECT_THROW(lhs.multiplyBy(rhs), helib::LogicError);
  EXPECT_THROW(lhs.multiplyBy2(lhs, rhs), helib::LogicError);
}
2847
// Constructing a Ptxt from the single polynomial x + 1 is expected to put
// x + 1 in every slot.
TEST_P(TestPtxtBGV, preservesDataPassedAsZZX)
{
  NTL::ZZX x_plus_one;
  SetCoeff(x_plus_one, 1, 1);
  SetCoeff(x_plus_one, 0, 1);

  helib::Ptxt<helib::BGV> ptxt(context, x_plus_one);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot)
    EXPECT_EQ(ptxt[slot], x_plus_one);
}
2861
// A Ptxt built from phi_m(X) + 1, a polynomial of the same degree as
// phi_m(X), is expected to hold 1 in every slot: the phi_m(X) term must
// vanish.
TEST_P(TestPtxtBGV, setDataWorksWithZZXSameOrderAsPhiMX)
{
  // Obtain phi_m(X) from whichever plaintext algebra the context uses.
  NTL::ZZX phi_mx;
  const auto tag = context.alMod.getTag();
  if (tag == helib::PA_GF2_tag) {
    phi_mx = NTL::conv<NTL::ZZX>(
        context.alMod.getDerived(helib::PA_GF2()).getPhimXMod());
  } else if (tag == helib::PA_zz_p_tag) {
    helib::convert(phi_mx,
                   context.alMod.getDerived(helib::PA_zz_p()).getPhimXMod());
  } else if (tag != helib::PA_cx_tag) {
    // PA_cx_tag (CKKS) is accepted and leaves phi_mx as zero; anything else
    // is an error.
    throw helib::LogicError("No valid tag found in EncryptedArray");
  }

  const NTL::ZZX phi_mx_plus_one = phi_mx + 1;
  helib::Ptxt<helib::BGV> ptxt(context, phi_mx_plus_one);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot)
    EXPECT_EQ(ptxt[slot], 1l);
}
2889
// Ptxt::decodeSetData applied to x + 1 must agree slot-by-slot with
// EncryptedArray::decode applied to the same polynomial.
TEST_P(TestPtxtBGV, decodeSetDataWorks)
{
  NTL::ZZX poly;
  SetCoeff(poly, 0, 1);
  SetCoeff(poly, 1, 1);

  helib::Ptxt<helib::BGV> ptxt(context);
  ptxt.decodeSetData(poly);

  // Reference slot values, computed directly by the encrypted array.
  std::vector<NTL::ZZX> expected;
  context.ea->decode(expected, poly);

  for (std::size_t slot = 0; slot < ptxt.size(); ++slot)
    EXPECT_EQ(ptxt[slot], expected[slot]);
}
2908
// Ptxt::decodeSetData applied to phi_m(X) + 1, a polynomial of the same
// degree as phi_m(X), must agree slot-by-slot with EncryptedArray::decode
// applied to the same polynomial.
TEST_P(TestPtxtBGV, decodeSetDataWorksWithZZXSameOrderAsPhiMX)
{
  NTL::ZZX phi_mx;
  switch (context.alMod.getTag()) {
  case helib::PA_GF2_tag:
    phi_mx = NTL::conv<NTL::ZZX>(
        context.alMod.getDerived(helib::PA_GF2()).getPhimXMod());
    break;
  case helib::PA_zz_p_tag:
    helib::convert(phi_mx,
                   context.alMod.getDerived(helib::PA_zz_p()).getPhimXMod());
    break;
  case helib::PA_cx_tag:
    // CKKS: do nothing
    break;
  default:
    throw helib::LogicError("No valid tag found in EncryptedArray");
  }
  // Put phi_mx + 1 as data
  NTL::ZZX input_polynomial(phi_mx + 1);

  // Call decodeSetData explicitly, as decodeSetDataWorks does. The ZZX
  // constructor copies the polynomial into every slot instead of decoding it
  // (see preservesDataPassedAsZZX), so it would not exercise decodeSetData.
  helib::Ptxt<helib::BGV> ptxt(context);
  ptxt.decodeSetData(input_polynomial);

  std::vector<NTL::ZZX> test_decoded;
  context.ea->decode(test_decoded, input_polynomial);

  for (std::size_t i = 0; i < ptxt.size(); ++i) {
    EXPECT_EQ(ptxt[i], test_decoded[i]);
  }
}
2939
// Round trip: a BGV Ptxt encrypted with the public key and decrypted with the
// matching secret key must come back with the same slot values.
TEST_P(TestPtxtBGV, canEncryptAndDecryptPtxts)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey sk(context);
  sk.GenSecKey();
  const helib::PubKey& pk = sk;

  std::vector<long> plain_values(context.ea->size());
  std::iota(plain_values.begin(), plain_values.end(), 0);
  helib::Ptxt<helib::BGV> original(context, plain_values);

  helib::Ctxt ciphertext(pk);
  pk.Encrypt(ciphertext, original);

  helib::Ptxt<helib::BGV> decrypted(context);
  sk.Decrypt(decrypted, ciphertext);

  for (std::size_t slot = 0; slot < original.size(); ++slot)
    EXPECT_EQ(original[slot], decrypted[slot]);
}
2958
// Ctxt += Ptxt must match the same addition done in the clear. The sum is
// computed both homomorphically and on the plaintext; the decryption of the
// former is then compared slot-by-slot with the latter.
TEST_P(TestPtxtBGV, plusEqualsWithCiphertextWorks)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey sk(context);
  sk.GenSecKey();
  const helib::PubKey& pk = sk;

  // The augend is encrypted; the addend stays in plaintext.
  std::vector<long> lhs_values(context.ea->size());
  std::vector<long> rhs_values(context.ea->size());
  std::iota(lhs_values.begin(), lhs_values.end(), 0);
  std::iota(rhs_values.begin(), rhs_values.end(), 7);
  helib::Ptxt<helib::BGV> expected(context, lhs_values);
  helib::Ptxt<helib::BGV> addend(context, rhs_values);
  helib::Ctxt encrypted(pk);
  pk.Encrypt(encrypted, expected);

  encrypted += addend;
  expected += addend;

  helib::Ptxt<helib::BGV> decrypted(context);
  sk.Decrypt(decrypted, encrypted);
  for (std::size_t i = 0; i < decrypted.size(); ++i)
    EXPECT_EQ(decrypted[i], expected[i]);
}
2986
// Ctxt::addConstant with a Ptxt must match the same addition done on the
// plaintext. The decrypted ciphertext result is compared slot-by-slot with
// the plaintext result.
TEST_P(TestPtxtBGV, addConstantFromCiphertextWorks)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret(context);
  secret.GenSecKey();
  const helib::PubKey& pub = secret;

  const long n_slots = context.ea->size();
  std::vector<long> base(n_slots);
  std::vector<long> offset(n_slots);
  std::iota(base.begin(), base.end(), 0);
  std::iota(offset.begin(), offset.end(), 7);

  // Encrypted augend, plaintext constant.
  helib::Ptxt<helib::BGV> plain_sum(context, base);
  helib::Ptxt<helib::BGV> constant(context, offset);
  helib::Ctxt cipher_sum(pub);
  pub.Encrypt(cipher_sum, plain_sum);

  cipher_sum.addConstant(constant);
  plain_sum += constant;

  helib::Ptxt<helib::BGV> decrypted(context);
  secret.Decrypt(decrypted, cipher_sum);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot)
    EXPECT_EQ(decrypted[slot], plain_sum[slot]);
}
3014
// Ctxt -= Ptxt must match the same subtraction done in the clear. The
// decrypted ciphertext result is compared slot-by-slot with the plaintext
// result.
TEST_P(TestPtxtBGV, minusEqualsWithCiphertextWorks)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey sk(context);
  sk.GenSecKey();
  const helib::PubKey& pk = sk;

  // Input data: minuend 0, 1, 2, ...; subtrahend 7, 8, 9, ...
  std::vector<long> minuend_values(context.ea->size());
  std::vector<long> subtrahend_values(context.ea->size());
  std::iota(minuend_values.begin(), minuend_values.end(), 0);
  std::iota(subtrahend_values.begin(), subtrahend_values.end(), 7);

  // The minuend is encrypted; the subtrahend stays in plaintext.
  helib::Ptxt<helib::BGV> expected(context, minuend_values);
  helib::Ptxt<helib::BGV> subtrahend(context, subtrahend_values);
  helib::Ctxt encrypted_minuend(pk);
  pk.Encrypt(encrypted_minuend, expected);

  // Same subtraction, once homomorphically and once in the clear.
  encrypted_minuend -= subtrahend;
  expected -= subtrahend;

  helib::Ptxt<helib::BGV> decrypted(context);
  sk.Decrypt(decrypted, encrypted_minuend);
  for (std::size_t i = 0; i < decrypted.size(); ++i) {
    EXPECT_EQ(decrypted[i], expected[i]);
  }
}
3042
// Ctxt *= Ptxt must match the same multiplication done in the clear. The
// decrypted ciphertext product is compared slot-by-slot with the plaintext
// product.
TEST_P(TestPtxtBGV, timesEqualsWithCiphertextWorks)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey secret_key(context);
  secret_key.GenSecKey();
  const helib::PubKey& public_key = secret_key;

  const long n_slots = context.ea->size();
  std::vector<long> left_factor(n_slots);
  std::vector<long> right_factor(n_slots);
  std::iota(left_factor.begin(), left_factor.end(), 0);
  std::iota(right_factor.begin(), right_factor.end(), 7);

  // Encrypted multiplier, plaintext multiplicand.
  helib::Ptxt<helib::BGV> plain_product(context, left_factor);
  helib::Ptxt<helib::BGV> factor(context, right_factor);
  helib::Ctxt cipher_product(public_key);
  public_key.Encrypt(cipher_product, plain_product);

  cipher_product *= factor;
  plain_product *= factor;

  helib::Ptxt<helib::BGV> decrypted(context);
  secret_key.Decrypt(decrypted, cipher_product);
  for (std::size_t idx = 0; idx < decrypted.size(); ++idx)
    EXPECT_EQ(decrypted[idx], plain_product[idx]);
}
3070
// Ctxt::multByConstant with a Ptxt must match the same multiplication done
// on the plaintext. The decrypted ciphertext product is compared slot-by-slot
// with the plaintext product.
TEST_P(TestPtxtBGV, multByConstantFromCiphertextWorks)
{
  helib::buildModChain(context, 30, 2);
  helib::SecKey sk(context);
  sk.GenSecKey();
  const helib::PubKey& pk = sk;

  // Slot inputs: 0, 1, 2, ... for the encrypted operand and 7, 8, 9, ...
  // for the plaintext constant.
  std::vector<long> cipher_input(context.ea->size());
  std::vector<long> constant_input(context.ea->size());
  std::iota(cipher_input.begin(), cipher_input.end(), 0);
  std::iota(constant_input.begin(), constant_input.end(), 7);

  helib::Ptxt<helib::BGV> reference(context, cipher_input);
  helib::Ptxt<helib::BGV> constant(context, constant_input);
  helib::Ctxt ctxt(pk);
  pk.Encrypt(ctxt, reference);

  ctxt.multByConstant(constant);
  reference *= constant;

  helib::Ptxt<helib::BGV> decrypted(context);
  sk.Decrypt(decrypted, ctxt);
  for (std::size_t slot = 0; slot < decrypted.size(); ++slot) {
    EXPECT_EQ(decrypted[slot], reference[slot]);
  }
}
3098
3099 // Useful for testing non-power of 2 for CKKS
3100 // INSTANTIATE_TEST_SUITE_P(various_parameters, TestPtxtCKKS,
3101 // ::testing::Values( 17, 168, 126, 78, 33, 50, 64)
3102 // );
3103
3104 INSTANTIATE_TEST_SUITE_P(
3105 various_Parameters,
3106 TestPtxtCKKS,
3107 ::testing::Values(2 << 1, 2 << 2, 2 << 3, 2 << 4, 2 << 5, 2 << 6, 2 << 7));
3108
3109 INSTANTIATE_TEST_SUITE_P(
3110 various_Parameters,
3111 TestPtxtBGV,
3112 ::testing::Values(BGVParameters(17, 2, 1),
3113 BGVParameters(17, 2, 3),
3114 BGVParameters(168, 13, 1),
3115 BGVParameters(126, 127, 1),
3116 BGVParameters(78, 79, 1),
3117 BGVParameters(33, 19, 2),
3118 // NOTE: This was used because it has 3 good dimensions
3119 // BGVParameters(10005, 37, 1),
3120 BGVParameters(50, 53, 1)));
3121
3122 } // namespace
3123