1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <vector>

#include <NTL/BasicThreadPool.h>
#include <helib/intraSlot.h>
#include <helib/tableLookup.h>
#include <helib/debugging.h>

#include "gtest/gtest.h"
#include "test_common.h"
21 
22 namespace {
23 
// Value type describing one instantiation of the GTestTableLookup suite.
// operator<< renders it as a single "{key=value,...}" line.
struct Parameters
{
  Parameters(long prm,
             long bitSize,
             long outSize,
             long nTests,
             bool bootstrap,
             long seed,
             long nthreads) :
      prm(prm),
      bitSize(bitSize),
      outSize(outSize),
      nTests(nTests),
      bootstrap(bootstrap),
      seed(seed),
      nthreads(nthreads)
  {}

  long prm;       // parameter-set row (0-tiny,...,4-huge)
  long bitSize;   // bitSize of input integers (the fixture accepts 1..7)
  long outSize;   // bitSize of output integers
  long nTests;    // number of random tests to run (write-in iterations)
  bool bootstrap; // build a bootstrappable context and recryption data
  long seed;      // PRG seed (0 leaves NTL's seed unchanged)
  long nthreads;  // number of threads (applied only when > 1)

  // Prints every field in declaration order; `bootstrap` prints as 0/1.
  friend std::ostream& operator<<(std::ostream& os, const Parameters& params)
  {
    return os << "{"
              << "prm=" << params.prm << ","
              << "bitSize=" << params.bitSize << ","
              << "outSize=" << params.outSize << ","
              << "nTests=" << params.nTests << ","
              << "bootstrap=" << params.bootstrap << ","
              << "seed=" << params.seed << ","
              << "nthreads=" << params.nthreads << "}";
  }
};
61 
62 class GTestTableLookup : public ::testing::TestWithParam<Parameters>
63 {
64 protected:
65   // clang-format off
66   static constexpr long mValues[][15] = {
67   //  {p,phi(m),    m,  d, m1, m2, m3,   g1,   g2,   g3,ord1,ord2,ord3,  B, c}
68       {2,    48,  105, 12,  3, 35,  0,   71,   76,    0,   2,   2,   0, 25, 2},
69       {2,   600, 1023, 10, 11, 93,  0,  838,  584,    0,  10,   6,   0, 25, 2},
70       {2,  2304, 4641, 24,  7,  3,221, 3979, 3095, 3760,   6,   2,  -8, 25, 3},
71       {2, 15004,15709, 22, 23,683,  0, 4099,13663,    0,  22,  31,   0, 25, 3},
72       {2, 27000,32767, 15, 31,  7,151,11628,28087,25824,  30,   6, -10, 28, 4}
73   };
74   // clang-format on
75 
76   // Utility encryption/decryption methods
encryptIndex(std::vector<helib::Ctxt> & ei,long index,const helib::SecKey & sKey)77   static void encryptIndex(std::vector<helib::Ctxt>& ei,
78                            long index,
79                            const helib::SecKey& sKey)
80   {
81     for (long i = 0; i < helib::lsize(ei); i++)
82       sKey.Encrypt(ei[i], NTL::to_ZZX((index >> i) & 1)); // i'th bit of index
83   }
84 
decryptIndex(std::vector<helib::Ctxt> & ei,const helib::SecKey & sKey)85   static long decryptIndex(std::vector<helib::Ctxt>& ei,
86                            const helib::SecKey& sKey)
87   {
88     long num = 0;
89     for (long i = 0; i < helib::lsize(ei); i++) {
90       NTL::ZZX poly;
91       sKey.Decrypt(poly, ei[i]);
92       num += to_long(NTL::ConstTerm(poly)) << i;
93     }
94     return num;
95   }
96 
validatePrm(const long prm)97   static long validatePrm(const long prm)
98   {
99     if (prm < 0 || prm >= 5)
100       throw std::invalid_argument("Invalid prm value");
101     return prm;
102   };
103 
validateBitSize(const long bitSize)104   static long validateBitSize(const long bitSize)
105   {
106     if (bitSize > 7)
107       throw std::invalid_argument("Invalid bitSize value: must be <=7");
108     else if (bitSize <= 0)
109       throw std::invalid_argument("Invalid bitSize value: must be >0");
110     return bitSize;
111   };
112 
calculateMvec(const long * vals)113   static NTL::Vec<long> calculateMvec(const long* vals)
114   {
115     NTL::Vec<long> mvec;
116     append(mvec, vals[4]);
117     if (vals[5] > 1)
118       append(mvec, vals[5]);
119     if (vals[6] > 1)
120       append(mvec, vals[6]);
121     return mvec;
122   };
123 
calculateGens(const long * vals)124   static std::vector<long> calculateGens(const long* vals)
125   {
126     std::vector<long> gens;
127     gens.push_back(vals[7]);
128     if (vals[8] > 1)
129       gens.push_back(vals[8]);
130     if (vals[9] > 1)
131       gens.push_back(vals[9]);
132     return gens;
133   };
134 
calculateOrds(const long * vals)135   static std::vector<long> calculateOrds(const long* vals)
136   {
137     std::vector<long> ords;
138     ords.push_back(vals[10]);
139     if (abs(vals[11]) > 1)
140       ords.push_back(vals[11]);
141     if (abs(vals[12]) > 1)
142       ords.push_back(vals[12]);
143     return ords;
144   };
145 
calculateLevels(const bool bootstrap,const long bitSize)146   static long calculateLevels(const bool bootstrap, const long bitSize)
147   {
148     long L;
149     if (bootstrap)
150       L = 900; // that should be enough
151     else
152       L = 30 * (5 + bitSize);
153     return L;
154   };
155 
printPreContextPrepDiagnostics(const long bitSize,const long outSize,const long nTests,const long nthreads)156   static void printPreContextPrepDiagnostics(const long bitSize,
157                                              const long outSize,
158                                              const long nTests,
159                                              const long nthreads)
160   {
161     if (helib_test::verbose) {
162       std::cout << "input bitSize=" << bitSize
163                 << ", output size bound=" << outSize << ", running " << nTests
164                 << " tests for each function\n";
165       if (nthreads > 1)
166         std::cout << "  using " << NTL::AvailableThreads() << " threads\n";
167       std::cout << "computing key-independent tables..." << std::flush;
168     }
169   };
170 
printPostContextPrepDiagnostics(const helib::Context & context,const long L)171   static void printPostContextPrepDiagnostics(const helib::Context& context,
172                                               const long L)
173   {
174     if (helib_test::verbose) {
175       std::cout << " done.\n";
176       context.zMStar.printout();
177       std::cout << " L=" << L << std::endl;
178     };
179   }
180 
181   // Not static as many instance variables are required.
prepareContext(helib::Context & context)182   helib::Context& prepareContext(helib::Context& context)
183   {
184     printPreContextPrepDiagnostics(bitSize, outSize, nTests, nthreads);
185     helib::buildModChain(context, L, c, /*willBeBootstrappable*/ bootstrap);
186     if (bootstrap) {
187       context.enableBootStrapping(mvec);
188     }
189     helib::buildUnpackSlotEncoding(unpackSlotEncoding, *context.ea);
190     printPostContextPrepDiagnostics(context, L);
191     return context;
192   };
193 
prepareSecKey(helib::SecKey & secretKey,const bool bootstrap)194   static void prepareSecKey(helib::SecKey& secretKey, const bool bootstrap)
195   {
196     if (helib_test::verbose)
197       std::cout << "\ncomputing key-dependent tables..." << std::flush;
198     secretKey.GenSecKey();
199     helib::addSome1DMatrices(secretKey); // compute key-switching matrices
200     helib::addFrbMatrices(secretKey);
201     if (bootstrap)
202       secretKey.genRecryptData();
203     if (helib_test::verbose)
204       std::cout << " done\n";
205   };
206 
setSeedIfNeeded(const long seed)207   static void setSeedIfNeeded(const long seed)
208   {
209     if (seed)
210       NTL::SetSeed(NTL::ZZ(seed));
211     ;
212   };
213 
setThreadsIfNeeded(const long nthreads)214   static void setThreadsIfNeeded(const long nthreads)
215   {
216     if (nthreads > 1)
217       NTL::SetNumThreads(nthreads);
218   };
219 
GTestTableLookup()220   GTestTableLookup() :
221       prm(validatePrm(GetParam().prm)),
222       bitSize(validateBitSize(GetParam().bitSize)),
223       outSize(GetParam().outSize),
224       nTests(GetParam().nTests),
225       bootstrap(GetParam().bootstrap),
226       seed((setSeedIfNeeded(GetParam().seed), GetParam().seed)),
227       nthreads((setThreadsIfNeeded(GetParam().nthreads), GetParam().nthreads)),
228       vals(mValues[prm]),
229       p(vals[0]),
230       m(vals[2]),
231       mvec(calculateMvec(vals)),
232       gens(calculateGens(vals)),
233       ords(calculateOrds(vals)),
234       c(vals[14]),
235       L(calculateLevels(bootstrap, bitSize)),
236       context(m, p, /*r=*/1, gens, ords),
237       secretKey(prepareContext(context))
238   {
239     prepareSecKey(secretKey, bootstrap);
240   };
241 
242   std::vector<helib::zzX> unpackSlotEncoding;
243   const long prm;
244   const long bitSize;
245   const long outSize;
246   const long nTests;
247   const bool bootstrap;
248   const long seed;
249   const long nthreads;
250   const long* vals;
251   const long p;
252   const long m;
253   const NTL::Vec<long> mvec;
254   const std::vector<long> gens;
255   const std::vector<long> ords;
256   const long c;
257   const long L;
258   helib::Context context;
259   helib::SecKey secretKey;
260 
SetUp()261   void SetUp() override
262   {
263     helib::activeContext = &context; // make things a little easier sometimes
264     helib::setupDebugGlobals(&secretKey, context.ea);
265   };
266 
TearDown()267   virtual void TearDown() override
268   {
269 #ifdef HELIB_DEBUG
270     helib::cleanupDebugGlobals();
271 #endif
272   }
273 
274 public:
TearDownTestCase()275   static void TearDownTestCase()
276   {
277     if (helib_test::verbose) {
278       helib::printAllTimers(std::cout);
279     }
280   };
281 };
282 
// Out-of-class definition of the static constexpr table. Before C++17 it is
// required because mValues is odr-used (the constructor indexes it to set
// `vals = mValues[prm]`). From C++17 on, static constexpr data members are
// implicitly inline, so this line is redundant but harmless.
constexpr long GTestTableLookup::mValues[][15];
284 
// Checks helib::tableLookup end to end:
//   1. build a plaintext lookup table from the lambda 1/(x+1) over
//      2^bitSize inputs;
//   2. for every index i, encrypt i bit by bit and look it up homomorphically;
//   3. compare the decrypted result (as a zzX) with T[i].
TEST_P(GTestTableLookup, lookupFunctionsCorrectly)
{
  // Build a table s.t. T[i] = 2^{outSize -1}/(i+1), i=0,...,2^bitSize -1
  std::vector<helib::zzX> T;
  helib::buildLookupTable(
      T,
      [](double x) { return 1 / (x + 1.0); },
      bitSize,
      /*scale_in=*/0,
      /*sign_in=*/0,
      outSize,
      /*scale_out=*/1 - outSize,
      /*sign_out=*/0,
      *(secretKey.getContext().ea));

  ASSERT_EQ(helib::lsize(T), 1L << bitSize);
  for (long i = 0; i < helib::lsize(T); i++) {
    // Named `ctxt` (not `c`) so it does not shadow the fixture member `c`.
    helib::Ctxt ctxt(secretKey);
    std::vector<helib::Ctxt> ei(bitSize, ctxt);
    encryptIndex(ei, i, secretKey); // encrypt the index
    helib::tableLookup(ctxt,
                       T,
                       helib::CtPtrs_vectorCt(ei)); // get the encrypted entry
    // decrypt and compare
    NTL::ZZX poly;
    secretKey.Decrypt(poly, ctxt); // decrypt
    helib::zzX poly2;
    helib::convert(poly2, poly); // convert to zzX
    EXPECT_EQ(poly2, T[i]) << "testLookup error: decrypted T[" << i << "]\n";
  }
}
316 
// Checks helib::tableWriteIn:
//   1. encrypt a table of 2^bitSize random bits;
//   2. nTests times, homomorphically add 1 at an encrypted random index,
//      mirroring each update on a plaintext copy;
//   3. verify each decrypted entry matches the plaintext copy modulo the
//      ciphertext's plaintext space.
TEST_P(GTestTableLookup, writeinFunctionsCorrectly)
{
  const long tSize = 1L << bitSize; // table size

  // Encrypt a random table: every one of the tSize entries gets a random bit.
  std::vector<long> pT(tSize, 0);                            // plaintext table
  std::vector<helib::Ctxt> T(tSize, helib::Ctxt(secretKey)); // encrypted table
  for (long i = 0; i < tSize; i++) {
    const long bit = NTL::RandomBits_long(1); // a random bit
    secretKey.Encrypt(T[i], NTL::to_ZZX(bit));
    pT[i] = bit;
  }

  // Add 1 to nTests random entries in the table
  for (long count = 0; count < nTests; count++) {
    // encrypt a random index into the table
    const long index = NTL::RandomBnd(tSize); // 0 <= index < tSize
    std::vector<helib::Ctxt> I(bitSize, helib::Ctxt(secretKey));
    encryptIndex(I, index, secretKey);

    // do the table write-in
    helib::tableWriteIn(helib::CtPtrs_vectorCt(T),
                        helib::CtPtrs_vectorCt(I),
                        &unpackSlotEncoding);
    pT[index]++; // add 1 to entry 'index' in the plaintext table
  }

  // Check that the ciphertext and plaintext tables still match
  for (long i = 0; i < tSize; i++) {
    NTL::ZZX poly;
    secretKey.Decrypt(poly, T[i]);
    const long decrypted = NTL::to_long(NTL::ConstTerm(poly));
    // Named `ptxtSpace` so it does not shadow the fixture member `p`.
    const long ptxtSpace = T[i].getPtxtSpace();
    ASSERT_EQ((pT[i] - decrypted) % ptxtSpace, 0) // should be equal mod p
        << "testWritein error: decrypted T[" << i << "]=" << decrypted
        << " but should be " << pT[i] << " (mod " << ptxtSpace << ")\n";
  }
}
355 
// Instantiate the suite once. Parameters arguments are, in order:
// (prm, bitSize, outSize, nTests, bootstrap, seed, nthreads).
// prm=1 selects the m=1023 row of GTestTableLookup::mValues. The
// commented-out prm=0 line selects the smallest row (m=105) for a faster run.
INSTANTIATE_TEST_SUITE_P(typicalParameters,
                         GTestTableLookup,
                         ::testing::Values(
                             // SLOW
                             Parameters(1, 5, 0, 3, false, 0, 1)
                             // FAST
                             // Parameters(0, 5, 0, 3, false, 0, 1)
                             ));
364 
365 } // namespace
366