1 /* Copyright (C) 2012-2019 IBM Corp.
2 * This program is Licensed under the Apache License, Version 2.0
3 * (the "License"); you may not use this file except in compliance
4 * with the License. You may obtain a copy of the License at
5 * http://www.apache.org/licenses/LICENSE-2.0
6 * Unless required by applicable law or agreed to in writing, software
7 * distributed under the License is distributed on an "AS IS" BASIS,
8 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 * See the License for the specific language governing permissions and
10 * limitations under the License. See accompanying LICENSE file.
11 */
12 namespace std {} using namespace std;
13 namespace NTL {} using namespace NTL;
14
15 #include <NTL/BasicThreadPool.h>
16
17 #include <cassert>
18
19 #include <helib/EvalMap.h>
20 #include <helib/hypercube.h>
21 #include <helib/powerful.h>
22 #include <helib/ArgMap.h>
23
24 NTL_CLIENT
25 using namespace helib;
26
// File-local run-time switches; main() exposes both as command-line options.
namespace {
bool dry = false;     // dry-run flag, handed to setDryRun() once the context exists
bool noPrint = true;  // when set, TestIt prints only its GOOD/BAD verdicts
} // namespace
29
TestIt(long p,long r,long c,long _k,long L,Vec<long> & mvec,Vec<long> & gens,Vec<long> & ords,long useCache)30 void TestIt(long p, long r, long c, long _k,
31 long L, Vec<long>& mvec,
32 Vec<long>& gens, Vec<long>& ords, long useCache)
33 {
34 if (lsize(mvec)<1) { // use default values
35 mvec.SetLength(3); gens.SetLength(3); ords.SetLength(3);
36 mvec[0] = 7; mvec[1] = 3; mvec[2] = 221;
37 gens[0] = 3979; gens[1] = 3095; gens[2] = 3760;
38 ords[0] = 6; ords[1] = 2; ords[2] = -8;
39 }
40 if (!noPrint)
41 cout << "*** TestIt"
42 << (dry? " (dry run):" : ":")
43 << " p=" << p
44 << ", r=" << r
45 << ", c=" << c
46 << ", k=" << _k
47 << ", L=" << L
48 << ", mvec=" << mvec << ", "
49 << ", useCache = " << useCache
50 << endl;
51
52 setTimersOn();
53 setDryRun(false); // Need to get a "real context" to test EvalMap
54
55 // mvec is supposed to include the prime-power factorization of m
56 long nfactors = mvec.length();
57 for (long i = 0; i < nfactors; i++)
58 for (long j = i+1; j < nfactors; j++)
59 assert(GCD(mvec[i], mvec[j]) == 1);
60
61 // multiply all the prime powers to get m itself
62 long m = computeProd(mvec);
63 assert(GCD(p, m) == 1);
64
65 // build a context with these generators and orders
66 vector<long> gens1, ords1;
67 convert(gens1, gens);
68 convert(ords1, ords);
69 Context context(m, p, r, gens1, ords1);
70 buildModChain(context, L, c);
71
72 if (!noPrint) {
73 context.zMStar.printout(); // print structure of Zm* /(p) to cout
74 cout << endl;
75 }
76 long d = context.zMStar.getOrdP();
77 long phim = context.zMStar.getPhiM();
78 long nslots = phim/d;
79
80 setDryRun(dry); // Now we can set the dry-run flag if desired
81
82 SecKey secretKey(context);
83 const PubKey& publicKey = secretKey;
84 secretKey.GenSecKey(); // A Hamming-weight-w secret key
85 addSome1DMatrices(secretKey); // compute key-switching matrices that we need
86 addFrbMatrices(secretKey); // compute key-switching matrices that we need
87
88 // GG defines the plaintext space Z_p[X]/GG(X)
89 ZZX GG;
90 GG = context.alMod.getFactorsOverZZ()[0];
91 EncryptedArray ea(context, GG);
92
93 zz_p::init(context.alMod.getPPowR());
94 zz_pX F;
95 random(F, phim); // a random polynomial of degree phi(m)-1 modulo p
96
97 // convert F to powerful representation: cube represents a multi-variate
98 // polynomial with as many variables Xi as factors mi in mvec. cube has
99 // degree phi(mi) in the variable Xi, and the coefficients are given
100 // in lexicographic order.
101
102 // compute tables for converting between powerful and zz_pX
103 PowerfulTranslationIndexes ind(mvec); // indpendent of p
104 PowerfulConversion pConv(ind); // depends on p
105
106 HyperCube<zz_p> cube(pConv.getShortSig());
107 pConv.polyToPowerful(cube, F);
108
109 // Sanity check: convert back and compare
110 zz_pX F2;
111 pConv.powerfulToPoly(F2, cube);
112 if (F != F2) {
113 cout << "BAD\n";
114 if (!noPrint) cout << " @@@ conversion error ):\n";
115 }
116 // pack the coefficients from cube in the plaintext slots: the j'th
117 // slot contains the polynomial pj(X) = \sum_{t=0}^{d-1} cube[jd+t] X^t
118 vector<ZZX> val1;
119 val1.resize(nslots);
120 for (long i = 0; i < phim; i++) {
121 val1[i/d] += conv<ZZX>(conv<ZZ>(cube[i])) << (i % d);
122 }
123 PlaintextArray pa1(ea);
124 encode(ea, pa1, val1);
125
126 Ctxt ctxt(publicKey);
127 ea.encrypt(ctxt, publicKey, pa1);
128
129 resetAllTimers();
130 HELIB_NTIMER_START(ALL);
131
132 // Compute homomorphically the transformation that takes the
133 // coefficients packed in the slots and produces the polynomial
134 // corresponding to cube
135
136 if (!noPrint) CheckCtxt(ctxt, "init");
137
138 if (!noPrint) cout << "build EvalMap\n";
139 EvalMap map(ea, /*minimal=*/false, mvec,
140 /*invert=*/false, /*build_cache=*/false, /*normal_basis=*/false);
141 // compute the transformation to apply
142
143 if (!noPrint) cout << "apply EvalMap\n";
144 if (useCache) map.upgrade();
145 map.apply(ctxt); // apply the transformation to ctxt
146 if (!noPrint) CheckCtxt(ctxt, "EvalMap");
147 if (!noPrint) cout << "check results\n";
148
149 ZZX FF1;
150 secretKey.Decrypt(FF1, ctxt);
151 zz_pX F1 = conv<zz_pX>(FF1);
152
153 if (F1 == F)
154 cout << "GOOD\n";
155 else
156 cout << "BAD\n";
157
158 publicKey.Encrypt(ctxt, balanced_zzX(F1));
159 if (!noPrint) CheckCtxt(ctxt, "init");
160
161 // Compute homomorphically the inverse transformation that takes the
162 // polynomial corresponding to cube and produces the coefficients
163 // packed in the slots
164
165 if (!noPrint) cout << "build EvalMap\n";
166 EvalMap imap(ea, /*minimal=*/false, mvec,
167 /*invert=*/true, /*build_cache=*/false, /*normal_basis=*/false);
168 // compute the transformation to apply
169 if (!noPrint) cout << "apply EvalMap\n";
170 if (useCache) imap.upgrade();
171 imap.apply(ctxt); // apply the transformation to ctxt
172 if (!noPrint) {
173 CheckCtxt(ctxt, "EvalMap");
174 cout << "check results\n";
175 }
176 PlaintextArray pa2(ea);
177 ea.decrypt(ctxt, secretKey, pa2);
178
179 if (equals(ea, pa1, pa2))
180 cout << "GOOD\n";
181 else
182 cout << "BAD\n";
183 HELIB_NTIMER_STOP(ALL);
184
185 if (!noPrint) {
186 cout << "\n*********\n";
187 printAllTimers();
188 cout << endl;
189 }
190 }
191
192
/* Usage: Test_EvalMap_x.exe [ name=value ]...
 *   p         plaintext base  [ default=2 ]
 *   r         lifting  [ default=1 ]
 *   c         number of columns in the key-switching matrices  [ default=2 ]
 *   k         security parameter  [ default=80 ]
 *   L         # of bits in the modulus chain  [ default=300 ]
 *   s         minimum number of slots  [ default=0 ]
 *   seed      PRG seed  [ default=0 ]
 *   mvec      use specified factorization of m
 *               e.g., mvec='[7 3 221]'
 *   gens      use specified vector of generators
 *               e.g., gens='[3979 3095 3760]'
 *   ords      use specified vector of orders
 *               e.g., ords='[6 2 -8]', negative means 'bad'
 *   dry       a dry-run flag to check the noise  [ default=0 ]
 *   nthreads  number of threads  [ default=1 ]
 *   noPrint   suppress printouts  [ default=1 ]
 *   useCache  0: zzX cache, 2: DCRT cache  [ default=0 ]
 */
main(int argc,char * argv[])208 int main(int argc, char *argv[])
209 {
210 ArgMap amap;
211
212 long p=2;
213 amap.arg("p", p, "plaintext base");
214
215 long r=1;
216 amap.arg("r", r, "lifting");
217
218 long c=2;
219 amap.arg("c", c, "number of columns in the key-switching matrices");
220
221 long k=80;
222 amap.arg("k", k, "security parameter");
223
224 long L=300;
225 amap.arg("L", L, "# of bits in the modulus chain");
226
227 long s=0;
228 amap.arg("s", s, "minimum number of slots");
229
230 long seed=0;
231 amap.arg("seed", seed, "PRG seed");
232
233 Vec<long> mvec;
234 amap.arg("mvec", mvec, "use specified factorization of m", nullptr);
235 amap.note("e.g., mvec='[7 3 221]'");
236
237 Vec<long> gens;
238 amap.arg("gens", gens, "use specified vector of generators", nullptr);
239 amap.note("e.g., gens='[3979 3095 3760]'");
240
241 Vec<long> ords;
242 amap.arg("ords", ords, "use specified vector of orders", nullptr);
243 amap.note("e.g., ords='[6 2 -8]', negative means 'bad'");
244
245 amap.arg("dry", dry, "a dry-run flag to check the noise");
246
247 long nthreads=1;
248 amap.arg("nthreads", nthreads, "number of threads");
249
250 amap.arg("noPrint", noPrint, "suppress printouts");
251
252 long useCache=0;
253 amap.arg("useCache", useCache, "0: zzX cache, 2: DCRT cache");
254
255 amap.parse(argc, argv);
256
257 SetNumThreads(nthreads);
258
259 SetSeed(conv<ZZ>(seed));
260 TestIt(p, r, c, k, L, mvec, gens, ords, useCache);
261 }
262
263 // ./Test_EvalMap_x mvec="[73 433]" gens="[18620 12995]" ords="[72 -6]"
264