1 /* Copyright (C) 2012-2020 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12
13 #include <iostream>
14 #include <NTL/BasicThreadPool.h>
15 #include <helib/intraSlot.h>
16 #include <helib/tableLookup.h>
17 #include <helib/debugging.h>
18
19 #include "gtest/gtest.h"
20 #include "test_common.h"
21
22 namespace {
23
// Settings for one parameterized instance of the table-lookup tests.
// A plain value bundle: the constructor copies each argument into the field
// of the same name, and operator<< renders all fields on one line.
struct Parameters
{
  long prm;       // which parameter set to use (0-tiny,...,4-huge)
  long bitSize;   // bit length of the input integers (the fixture's
                  // validateBitSize only accepts values in 1..7)
  long outSize;   // bit length of the output integers
  long nTests;    // how many tests to run
  bool bootstrap; // exercise multiplication with bootstrapping
  long seed;      // seed for the PRG
  long nthreads;  // how many threads to use

  Parameters(long prm,
             long bitSize,
             long outSize,
             long nTests,
             bool bootstrap,
             long seed,
             long nthreads) :
      prm{prm},
      bitSize{bitSize},
      outSize{outSize},
      nTests{nTests},
      bootstrap{bootstrap},
      seed{seed},
      nthreads{nthreads}
  {}

  // Writes "{prm=..,bitSize=..,...,nthreads=..}" to `os` and returns `os`.
  // `bootstrap` is streamed as a plain bool (no boolalpha is set here).
  friend std::ostream& operator<<(std::ostream& os, const Parameters& params)
  {
    os << "{";
    os << "prm=" << params.prm << ",";
    os << "bitSize=" << params.bitSize << ",";
    os << "outSize=" << params.outSize << ",";
    os << "nTests=" << params.nTests << ",";
    os << "bootstrap=" << params.bootstrap << ",";
    os << "seed=" << params.seed << ",";
    os << "nthreads=" << params.nthreads << "}";
    return os;
  }
};
61
62 class GTestTableLookup : public ::testing::TestWithParam<Parameters>
63 {
64 protected:
// clang-format off
// Candidate parameter sets, one row per accepted `prm` value (validatePrm
// allows 0..4, matching the five rows; m grows from row 0 to row 4).
// Columns read by this fixture: [0] p, [2] m, [4..6] factors of m (mvec),
// [7..9] generators, [10..12] their orders, [14] c. The phi(m), d and B
// columns are not read anywhere in this file.
// Optional slots (m2/m3, g2/g3, ord2/ord3) holding 0 are skipped by the
// calculateMvec/calculateGens/calculateOrds helpers.
// NOTE(review): negative orders presumably follow HElib's convention for
// flagging a dimension - confirm against the HElib documentation.
static constexpr long mValues[][15] = {
  // {p,phi(m), m, d, m1, m2, m3, g1, g2, g3,ord1,ord2,ord3, B, c}
  {2, 48, 105, 12, 3, 35, 0, 71, 76, 0, 2, 2, 0, 25, 2},
  {2, 600, 1023, 10, 11, 93, 0, 838, 584, 0, 10, 6, 0, 25, 2},
  {2, 2304, 4641, 24, 7, 3,221, 3979, 3095, 3760, 6, 2, -8, 25, 3},
  {2, 15004,15709, 22, 23,683, 0, 4099,13663, 0, 22, 31, 0, 25, 3},
  {2, 27000,32767, 15, 31, 7,151,11628,28087,25824, 30, 6, -10, 28, 4}
};
// clang-format on
75
76 // Utility encryption/decryption methods
encryptIndex(std::vector<helib::Ctxt> & ei,long index,const helib::SecKey & sKey)77 static void encryptIndex(std::vector<helib::Ctxt>& ei,
78 long index,
79 const helib::SecKey& sKey)
80 {
81 for (long i = 0; i < helib::lsize(ei); i++)
82 sKey.Encrypt(ei[i], NTL::to_ZZX((index >> i) & 1)); // i'th bit of index
83 }
84
decryptIndex(std::vector<helib::Ctxt> & ei,const helib::SecKey & sKey)85 static long decryptIndex(std::vector<helib::Ctxt>& ei,
86 const helib::SecKey& sKey)
87 {
88 long num = 0;
89 for (long i = 0; i < helib::lsize(ei); i++) {
90 NTL::ZZX poly;
91 sKey.Decrypt(poly, ei[i]);
92 num += to_long(NTL::ConstTerm(poly)) << i;
93 }
94 return num;
95 }
96
// Returns `prm` unchanged when it lies in 0..4 (the five supported
// parameter sets); otherwise throws std::invalid_argument.
static long validatePrm(const long prm)
{
  const bool inRange = (0 <= prm) && (prm < 5);
  if (!inRange) {
    throw std::invalid_argument("Invalid prm value");
  }
  return prm;
}
103
// Returns `bitSize` unchanged when it lies in 1..7; otherwise throws
// std::invalid_argument with a message naming the violated bound.
static long validateBitSize(const long bitSize)
{
  if (bitSize <= 0) {
    throw std::invalid_argument("Invalid bitSize value: must be >0");
  }
  if (bitSize > 7) {
    throw std::invalid_argument("Invalid bitSize value: must be <=7");
  }
  return bitSize;
}
112
calculateMvec(const long * vals)113 static NTL::Vec<long> calculateMvec(const long* vals)
114 {
115 NTL::Vec<long> mvec;
116 append(mvec, vals[4]);
117 if (vals[5] > 1)
118 append(mvec, vals[5]);
119 if (vals[6] > 1)
120 append(mvec, vals[6]);
121 return mvec;
122 };
123
// Collects the generators from one parameter row: column 7 (g1) is always
// taken, columns 8 and 9 (g2, g3) only when their value exceeds 1.
static std::vector<long> calculateGens(const long* vals)
{
  const long* const g = vals + 7; // g[0..2] == g1, g2, g3
  std::vector<long> result(1, g[0]);
  for (int k = 1; k <= 2; ++k) {
    if (g[k] > 1) {
      result.push_back(g[k]);
    }
  }
  return result;
}
134
// Collects the generator orders from one parameter row: column 10 (ord1) is
// always taken, columns 11 and 12 (ord2, ord3) only when their magnitude
// exceeds 1. The sign of an order is preserved in the result.
//
// The magnitude test is written as an explicit two-sided comparison rather
// than an unqualified abs() call: which abs overload is visible depends on
// transitively included headers, and binding to the C `int abs(int)` would
// silently truncate a long argument.
static std::vector<long> calculateOrds(const long* vals)
{
  std::vector<long> ords;
  ords.push_back(vals[10]);
  for (int col = 11; col <= 12; ++col) {
    const long ord = vals[col];
    if (ord > 1 || ord < -1) {
      ords.push_back(ord);
    }
  }
  return ords;
}
145
// Picks the value passed as `L` to buildModChain: a fixed 900 when
// bootstrapping is enabled ("that should be enough"), otherwise
// 30 * (5 + bitSize).
static long calculateLevels(const bool bootstrap, const long bitSize)
{
  return bootstrap ? 900 : 30 * (5 + bitSize);
}
155
printPreContextPrepDiagnostics(const long bitSize,const long outSize,const long nTests,const long nthreads)156 static void printPreContextPrepDiagnostics(const long bitSize,
157 const long outSize,
158 const long nTests,
159 const long nthreads)
160 {
161 if (helib_test::verbose) {
162 std::cout << "input bitSize=" << bitSize
163 << ", output size bound=" << outSize << ", running " << nTests
164 << " tests for each function\n";
165 if (nthreads > 1)
166 std::cout << " using " << NTL::AvailableThreads() << " threads\n";
167 std::cout << "computing key-independent tables..." << std::flush;
168 }
169 };
170
printPostContextPrepDiagnostics(const helib::Context & context,const long L)171 static void printPostContextPrepDiagnostics(const helib::Context& context,
172 const long L)
173 {
174 if (helib_test::verbose) {
175 std::cout << " done.\n";
176 context.zMStar.printout();
177 std::cout << " L=" << L << std::endl;
178 };
179 }
180
// Not static as many instance variables are required.
//
// Finishes setting up `context` and returns the same reference so the call
// can be used directly as the initializer of `secretKey` in the constructor.
// Steps, in order: print the pre-preparation diagnostics, build the modulus
// chain from L and c, enable bootstrapping with mvec when `bootstrap` is
// set, fill `unpackSlotEncoding` from the context's ea, then print the
// post-preparation diagnostics.
// Because it runs inside the constructor's initializer list, every member it
// reads or writes (unpackSlotEncoding, bitSize, outSize, nTests, nthreads,
// bootstrap, mvec, c, L) must be declared before `secretKey`.
helib::Context& prepareContext(helib::Context& context)
{
  printPreContextPrepDiagnostics(bitSize, outSize, nTests, nthreads);
  helib::buildModChain(context, L, c, /*willBeBootstrappable*/ bootstrap);
  if (bootstrap) {
    context.enableBootStrapping(mvec);
  }
  helib::buildUnpackSlotEncoding(unpackSlotEncoding, *context.ea);
  printPostContextPrepDiagnostics(context, L);
  return context;
};
193
prepareSecKey(helib::SecKey & secretKey,const bool bootstrap)194 static void prepareSecKey(helib::SecKey& secretKey, const bool bootstrap)
195 {
196 if (helib_test::verbose)
197 std::cout << "\ncomputing key-dependent tables..." << std::flush;
198 secretKey.GenSecKey();
199 helib::addSome1DMatrices(secretKey); // compute key-switching matrices
200 helib::addFrbMatrices(secretKey);
201 if (bootstrap)
202 secretKey.genRecryptData();
203 if (helib_test::verbose)
204 std::cout << " done\n";
205 };
206
setSeedIfNeeded(const long seed)207 static void setSeedIfNeeded(const long seed)
208 {
209 if (seed)
210 NTL::SetSeed(NTL::ZZ(seed));
211 ;
212 };
213
setThreadsIfNeeded(const long nthreads)214 static void setThreadsIfNeeded(const long nthreads)
215 {
216 if (nthreads > 1)
217 NTL::SetNumThreads(nthreads);
218 };
219
// Builds the whole fixture from the current test parameter (GetParam()).
//
// Members are initialized in their declaration order, and several
// initializers depend on earlier ones: `vals` needs the validated `prm`;
// p, m, mvec, gens, ords and c are read from `vals`; L needs bootstrap and
// bitSize; `context` needs m, p, gens and ords; and `secretKey` is built
// from the context returned by prepareContext, which in turn uses L, c,
// bootstrap, mvec and unpackSlotEncoding.
// validatePrm/validateBitSize throw std::invalid_argument on out-of-range
// parameters, so an invalid parameter aborts construction before any
// HElib object is created.
GTestTableLookup() :
    prm(validatePrm(GetParam().prm)),
    bitSize(validateBitSize(GetParam().bitSize)),
    outSize(GetParam().outSize),
    nTests(GetParam().nTests),
    bootstrap(GetParam().bootstrap),
    // Comma operator: run the side effect (seed NTL / set the thread count)
    // while initializing the member, i.e. before `context` is constructed.
    seed((setSeedIfNeeded(GetParam().seed), GetParam().seed)),
    nthreads((setThreadsIfNeeded(GetParam().nthreads), GetParam().nthreads)),
    vals(mValues[prm]), // row of the parameter table selected by prm
    p(vals[0]),
    m(vals[2]),
    mvec(calculateMvec(vals)),
    gens(calculateGens(vals)),
    ords(calculateOrds(vals)),
    c(vals[14]),
    L(calculateLevels(bootstrap, bitSize)),
    context(m, p, /*r=*/1, gens, ords),
    secretKey(prepareContext(context))
{
  // Key generation happens in the body, once secretKey exists.
  prepareSecKey(secretKey, bootstrap);
};
241
// Fixture state. NOTE: the declaration order below is also the
// initialization order and the constructor's initializer list relies on it;
// do not reorder these members without revisiting the constructor.

// Filled by prepareContext (buildUnpackSlotEncoding) and handed to
// tableWriteIn in the write-in test. Declared first so that it already
// exists when prepareContext runs during the initialization of secretKey.
std::vector<helib::zzX> unpackSlotEncoding;
const long prm;       // validated row index into mValues
const long bitSize;   // validated bit size of the table index
const long outSize;   // output size passed to buildLookupTable
const long nTests;    // number of write-in operations performed
const bool bootstrap; // whether the context is made bootstrappable
const long seed;      // PRG seed (0 = NTL is not re-seeded)
const long nthreads;  // requested thread count
const long* vals;     // points at the selected mValues row
const long p;         // vals[0]
const long m;         // vals[2]
const NTL::Vec<long> mvec;    // from vals[4..6], see calculateMvec
const std::vector<long> gens; // from vals[7..9], see calculateGens
const std::vector<long> ords; // from vals[10..12], see calculateOrds
const long c;                 // vals[14], passed to buildModChain
const long L;                 // from calculateLevels, passed to buildModChain
helib::Context context;       // constructed from m, p, r=1, gens, ords
helib::SecKey secretKey;      // constructed from the prepared context
260
SetUp()261 void SetUp() override
262 {
263 helib::activeContext = &context; // make things a little easier sometimes
264 helib::setupDebugGlobals(&secretKey, context.ea);
265 };
266
TearDown()267 virtual void TearDown() override
268 {
269 #ifdef HELIB_DEBUG
270 helib::cleanupDebugGlobals();
271 #endif
272 }
273
274 public:
TearDownTestCase()275 static void TearDownTestCase()
276 {
277 if (helib_test::verbose) {
278 helib::printAllTimers(std::cout);
279 }
280 };
281 };
282
// Out-of-class definition of the static constexpr table. Before C++17 it is
// required because the constructor odr-uses the array (mValues[prm]); from
// C++17 on, static constexpr data members are implicitly inline and this
// redeclaration is redundant but still accepted.
constexpr long GTestTableLookup::mValues[][15];
284
// Looks up every entry of a plaintext table through an encrypted index and
// checks that decrypting the result gives back exactly that entry.
TEST_P(GTestTableLookup, lookupFunctionsCorrectly)
{
  // Function tabulated below; per the original note the table is meant to
  // satisfy T[i] = 2^{outSize -1}/(i+1), i=0,...,2^bitSize -1
  const auto reciprocal = [](double x) { return 1 / (x + 1.0); };

  std::vector<helib::zzX> T;
  helib::buildLookupTable(T,
                          reciprocal,
                          bitSize,
                          /*scale_in=*/0,
                          /*sign_in=*/0,
                          outSize,
                          /*scale_out=*/1 - outSize,
                          /*sign_out=*/0,
                          *(secretKey.getContext().ea));

  // One table entry per possible index value.
  ASSERT_EQ(helib::lsize(T), 1L << bitSize);

  for (long i = 0; i < helib::lsize(T); i++) {
    // Encrypt the index i bit by bit.
    helib::Ctxt entry(secretKey);
    std::vector<helib::Ctxt> indexBits(bitSize, entry);
    encryptIndex(indexBits, i, secretKey);

    // Fetch the encrypted entry selected by the encrypted index.
    helib::tableLookup(entry, T, helib::CtPtrs_vectorCt(indexBits));

    // Decrypt, convert to zzX and compare against the plaintext entry.
    NTL::ZZX decrypted;
    secretKey.Decrypt(decrypted, entry);
    helib::zzX poly2;
    helib::convert(poly2, decrypted);
    EXPECT_EQ(poly2, T[i]) << "testLookup error: decrypted T[" << i << "]\n";
  }
}
316
// Performs nTests encrypted "add 1 at a secret index" operations on an
// encrypted table while mirroring them on a plaintext copy, then checks
// that both tables agree modulo the plaintext space.
TEST_P(GTestTableLookup, writeinFunctionsCorrectly)
{
  const long tSize = 1L << bitSize; // table size

  // Encrypt a random table: every one of the tSize entries starts out as a
  // random bit. (The loop previously stopped at bitSize, which left all but
  // the first few entries as untouched default ciphertexts.)
  std::vector<long> pT(tSize, 0);                            // plaintext table
  std::vector<helib::Ctxt> T(tSize, helib::Ctxt(secretKey)); // encrypted table
  for (long i = 0; i < tSize; i++) {
    long bit = NTL::RandomBits_long(1); // a random bit
    secretKey.Encrypt(T[i], NTL::to_ZZX(bit));
    pT[i] = bit;
  }

  // Add 1 to nTests randomly chosen entries of the table (the same entry
  // may be picked more than once).
  for (long count = 0; count < nTests; count++) {
    // encrypt a random index into the table
    long index = NTL::RandomBnd(tSize); // 0 <= index < tSize
    std::vector<helib::Ctxt> I(bitSize, helib::Ctxt(secretKey));
    encryptIndex(I, index, secretKey);

    // do the table write-in
    tableWriteIn(helib::CtPtrs_vectorCt(T),
                 helib::CtPtrs_vectorCt(I),
                 &unpackSlotEncoding);
    pT[index]++; // add 1 to entry 'index' in the plaintext table
  }

  // Check that the ciphertext and plaintext tables still match.
  // The index is a long to match tSize, and the plaintext-space local has
  // its own name so it no longer shadows the fixture member `p`.
  for (long i = 0; i < tSize; i++) {
    NTL::ZZX poly;
    secretKey.Decrypt(poly, T[i]);
    const long decrypted = to_long(NTL::ConstTerm(poly));
    const long ptxtSpace = T[i].getPtxtSpace();
    ASSERT_EQ((pT[i] - decrypted) % ptxtSpace, 0) // should be equal mod p
        << "testWritein error: decrypted T[" << i << "]=" << decrypted
        << " but should be " << pT[i] << " (mod " << ptxtSpace << ")\n";
  }
}
355
// Instantiates the fixture once per Parameters value listed below.
// Constructor argument order:
//   Parameters(prm, bitSize, outSize, nTests, bootstrap, seed, nthreads)
// The active entry uses prm=1 (the m=1023 row of mValues); the disabled
// "FAST" alternative differs only in prm=0 (the m=105 row).
INSTANTIATE_TEST_SUITE_P(typicalParameters,
                         GTestTableLookup,
                         ::testing::Values(
                             // SLOW
                             Parameters(1, 5, 0, 3, false, 0, 1)
                             // FAST
                             // Parameters(0, 5, 0, 3, false, 0, 1)
                             ));
364
365 } // namespace
366