1 /** Test_AES.cpp - test program for homomorphic AES using HElib
2 */
3 #if defined(__unix__) || defined(__unix) || defined(unix)
4 #include <sys/time.h>
5 #include <sys/resource.h>
6 #endif
7
8 namespace std {} using namespace std;
9 namespace NTL {} using namespace NTL;
10 #include <cstring>
11 #include "homAES.h"
12 #include "Ctxt.h"
13
/* Candidate parameter sets, one row per value of the "sz" command-line
 * argument (main clamps the index into the row range). Columns, as main uses
 * them:
 *   p        plaintext modulus passed to the Context (always 2 here)
 *   phi(m)   informational only; main does not read it
 *   m        cyclotomic index passed to the Context
 *   d        main divides it by 8 to get the extension degree e (presumably
 *            d is the order of p in Z_m^* -- confirm)
 *   m1..m3   factorization of m; entries >1 are collected into mvec, which
 *            is handed to makeBootstrappable()
 *   g1..g3   generators; entries >1 are collected into gens
 *   ord1..3  orders of the generators; entries with |ord|>1 are collected
 *            into ords with their sign preserved (NOTE(review): the negative
 *            sign presumably marks a dimension with special semantics in
 *            HElib -- confirm against Context/PAlgebra)
 *   c_m      ring constant times 100; main divides by 100.0 for set_cM()
 */
static long mValues[][14] = {
//{ p, phi(m), m, d, m1, m2, m3, g1, g2, g3,ord1,ord2,ord3, c_m}
{ 2, 512, 771, 16,771, 0, 0, 5, 0, 0,-32, 0, 0, 100}, // m=(3)*{257} :-( m/phim(m)=1.5 C=77 D=2 E=4
{ 2, 4096, 4369, 16, 17, 257, 0, 258, 4115, 0, 16,-16, 0, 100}, // m=17*(257) :-( m/phim(m)=1.06 C=61 D=3 E=4
{ 2, 16384, 21845, 16, 17, 5, 257, 8996,17477,21591, 16, 4,-16,1600}, // m=5*17*(257) :-( m/phim(m)=1.33 C=65 D=4 E=4
{ 2, 23040, 28679, 24, 17, 7, 241, 15184, 4098,28204, 16, 6,-10,1500}, // m=7*17*(241) m/phim(m)=1.24 C=63 D=4 E=3
{ 2, 46080, 53261, 24, 17,13, 241, 43863,28680,15913, 16, 12,-10, 100}, // m=13*17*(241) m/phim(m)=1.15 C=69 D=4 E=3
{ 2, 64512, 65281, 48, 97,673, 0, 43073,22214, 0, 96,-14, 0, 100} // m=97*(673) :-( m/phim(m)=1.01 C=169 D=3 E=4
};
23
#ifdef DEBUG_PRINTOUT
// Debugging hooks defined in another translation unit; main() points them
// at the secret key and at an EncryptedArray built over the AES polynomial.
extern SecKey* dbgKey;
extern EncryptedArray* dbgEa;
// Values 1/2/4 look like bit flags, presumably for decryptAndPrint's
// `flags` argument -- confirm against its definition.
#define FLAG_PRINT_ZZX 1
#define FLAG_PRINT_POLY 2
#define FLAG_PRINT_VEC 4
extern void decryptAndPrint(ostream& s, const Ctxt& ctxt, const SecKey& sk,
const EncryptedArray& ea, long flags=0);
#endif

// Hex dump of a byte vector to cerr (defined at the bottom of this file).
void printState(Vec<uint8_t>& st);
// Plaintext reference AES, defined elsewhere. main() calls AESKeyExpansion
// with a 16-byte key and NN=128 into a 240-byte schedule, and Cipher with
// Nr=10 on one 16-byte block at a time, to check the homomorphic result.
extern long AESKeyExpansion(unsigned char RoundKey[],
unsigned char Key[], int NN);
extern void Cipher(unsigned char out[16],
unsigned char in[16], unsigned char RoundKey[], int Nr);
39
main(int argc,char ** argv)40 int main(int argc, char **argv)
41 {
42 ArgMapping amap;
43
44 long idx = 0;
45 amap.arg("sz", idx, "parameter-sets: toy=0 through huge=5");
46
47 long c=3;
48 amap.arg("c", c, "number of columns in the key-switching matrices");
49
50 long L=0;
51 amap.arg("L", L, "# of levels in the modulus chain", "heuristic");
52
53 long B=23;
54 amap.arg("B", B, "# of bits per level (only 64-bit machines)");
55
56 bool boot=false;
57 amap.arg("boot", boot, "includes bootstrapping");
58
59 bool packed=true;
60 amap.arg("packed", packed, "use packed bootstrapping");
61
62 amap.parse(argc, argv);
63 if (idx>5) idx = 5;
64
65 Vec<long> mvec;
66 vector<long> gens;
67 vector<long> ords;
68
69 if (boot) {
70 if (L<23) L=23;
71 if (idx<1) idx=1; // the sz=0 params are incompatible with bootstrapping
72 } else {
73 #if (NTL_SP_NBITS<50)
74 if (L<46) L=46;
75 #else
76 if (L<42) L=42;
77 #endif
78 }
79
80 long p = mValues[idx][0];
81 // long phim = mValues[idx][1];
82 long m = mValues[idx][2];
83
84 append(mvec, mValues[idx][4]);
85 if (mValues[idx][5]>1) append(mvec, mValues[idx][5]);
86 if (mValues[idx][6]>1) append(mvec, mValues[idx][6]);
87
88 gens.push_back(mValues[idx][7]);
89 if (mValues[idx][8]>1) gens.push_back(mValues[idx][8]);
90 if (mValues[idx][9]>1) gens.push_back(mValues[idx][9]);
91
92 ords.push_back(mValues[idx][10]);
93 if (abs(mValues[idx][11])>1) ords.push_back(mValues[idx][11]);
94 if (abs(mValues[idx][12])>1) ords.push_back(mValues[idx][12]);
95
96 cout << "*** Test_AES: c=" << c
97 << ", L=" << L
98 << ", B=" << B
99 << ", boot=" << boot
100 << ", packed=" << packed
101 << ", m=" << m
102 << " (=" << mvec << "), gens="<<gens<<", ords="<<ords
103 << endl;
104
105 setTimersOn();
106 double tm = -GetTime();
107 cout << "computing key-independent tables..." << std::flush;
108 Context context(m, p, /*r=*/1, gens, ords);
109 #if (NTL_SP_NBITS>=50) // 64-bit machines
110 context.bitsPerLevel = B;
111 #endif
112 context.zMStar.set_cM(mValues[idx][13]/100.0); // the ring constant
113 buildModChain(context, L, c);
114
115 if (boot) context.makeBootstrappable(mvec);
116 tm += GetTime();
117 cout << "done in "<<tm<<" seconds\n";
118
119 // context.zMStar.printout();
120 {IndexSet allPrimes(0,context.numPrimes()-1);
121 cout <<" "<<context.numPrimes()<<" primes ("
122 <<context.ctxtPrimes.card()<<" ctxt/"
123 <<context.specialPrimes.card()<<" special), total bitsize="
124 <<context.logOfProduct(allPrimes)
125 <<", security level: "<<context.securityLevel() << endl;}
126
127 long e = mValues[idx][3] /8; // extension degree
128 cout << " "<<context.zMStar.getNSlots()<<" slots ("
129 << (context.zMStar.getNSlots()/16)<<" blocks) per ctxt";
130 if (boot && packed)
131 cout << ". x"<<e<<" ctxts";
132 cout << endl;
133
134 cout << "computing key-dependent tables..." << std::flush;
135 tm = -GetTime();
136 SecKey secretKey(context);
137 PubKey& publicKey = secretKey;
138 secretKey.GenSecKey(64); // A Hamming-weight-64 secret key
139
140 // Add key-switching matrices for the automorphisms that we need
141 long ord = context.zMStar.OrderOf(0);
142 for (long i = 1; i < 16; i++) { // rotation along 1st dim by size i*ord/16
143 long exp = i*ord/16;
144 long val = PowerMod(context.zMStar.ZmStarGen(0), exp, m); // val = g^exp
145
146 // From s(X^val) to s(X)
147 secretKey.GenKeySWmatrix(1, val);
148 if (!context.zMStar.SameOrd(0))
149 // also from s(X^{1/val}) to s(X)
150 secretKey.GenKeySWmatrix(1, InvMod(val,m));
151 }
152
153 addFrbMatrices(secretKey); // Also add Frobenius key-switching
154 if (boot) { // more tables
155 addSome1DMatrices(secretKey); // compute more key-switching matrices
156 secretKey.genRecryptData();
157 }
158 tm += GetTime();
159 cout << "done in "<<tm<<" seconds\n";
160
161 #ifdef DEBUG_PRINTOUT
162 dbgKey = &secretKey; // debugging key and ea
163
164 ZZX aesPoly; // X^8+X^4+X^3+X+1
165 SetCoeff(aesPoly,8); SetCoeff(aesPoly,4);
166 SetCoeff(aesPoly,3); SetCoeff(aesPoly,1); SetCoeff(aesPoly,0);
167 dbgEa = new EncryptedArray(context, aesPoly);
168 #endif
169
170 cout << "computing AES tables..." << std::flush;
171 tm = -GetTime();
172 HomAES hAES(context); // compute AES-specific key-independent tables
173 const EncryptedArrayDerived<PA_GF2>& ea2 = hAES.getEA();
174 long blocksPerCtxt = ea2.size() / 16;
175
176 long nBlocks;
177 if (boot && packed)
178 nBlocks = blocksPerCtxt * e;
179 else
180 nBlocks = blocksPerCtxt;
181
182 Vec<uint8_t> ptxt(INIT_SIZE, nBlocks*16);
183 Vec<uint8_t> aesCtxt(INIT_SIZE, nBlocks*16);
184 Vec<uint8_t> aesKey(INIT_SIZE, 16); // AES-128
185 uint8_t keySchedule[240];
186
187 // Choose random key, data
188 {GF2X rnd;
189 random(rnd, 8*ptxt.length());
190 BytesFromGF2X(ptxt.data(), rnd, aesKey.length());
191 random(rnd, 8*aesKey.length());
192 BytesFromGF2X(aesKey.data(), rnd, aesKey.length());}
193
194 // Encrypt the AES key under the HE key
195 vector< Ctxt > encryptedAESkey;
196 hAES.encryptAESkey(encryptedAESkey, aesKey, publicKey);
197 tm += GetTime();
198 cout << "done in "<<tm<<" seconds\n";
199
200 // Perform homomorphic AES
201 cout << "AES encryption "<< std::flush;
202 vector< Ctxt > doublyEncrypted;
203 tm = -GetTime();
204 hAES.homAESenc(doublyEncrypted, encryptedAESkey, ptxt);
205 tm += GetTime();
206
207 // Check that AES succeeeded
208 Vec<ZZX> poly(INIT_SIZE, doublyEncrypted.size());
209 for (long i=0; i<poly.length(); i++)
210 secretKey.Decrypt(poly[i], doublyEncrypted[i]);
211 decode4AES(aesCtxt, poly, hAES.getEA());
212
213 AESKeyExpansion(keySchedule, aesKey.data(), /*keyLength=*/128);
214 Vec<uint8_t> tmpBytes(INIT_SIZE, nBlocks*16);
215 for (long i=0; i<nBlocks; i++) {
216 Vec<uint8_t> tmp(INIT_SIZE, 16);
217 Cipher(&tmpBytes[16*i], &ptxt[16*i], keySchedule, /*numRounds=*/10);
218 }
219 if (aesCtxt != tmpBytes) {
220 cerr << "@ encryption error\n";
221 if (aesCtxt.length()!=tmpBytes.length())
222 cerr << " size mismatch, should be "<<tmpBytes.length()
223 << " but is "<<aesCtxt.length()<<endl;
224 else {
225 cerr << " input = "; printState(ptxt); cerr << endl;
226 cerr << " output ="; printState(aesCtxt); cerr << endl;
227 cerr << "should be "; printState(tmpBytes); cerr << endl;
228 }
229 }
230 else {
231 cout << "in "<<tm<<" seconds\n";
232 printNamedTimer(cout, "batchRecrypt");
233 printNamedTimer(cout, "recryption");
234 }
235 resetAllTimers();
236
237 // Decrypt and check that you have the same thing as before
238 cout << "AES decryption "<< std::flush;
239 tm = -GetTime();
240 hAES.homAESdec(doublyEncrypted, encryptedAESkey, aesCtxt);
241 tm += GetTime();
242
243 for (long i=0; i<poly.length(); i++)
244 secretKey.Decrypt(poly[i], doublyEncrypted[i]);
245 decode4AES(tmpBytes, poly, hAES.getEA());
246 if (ptxt != tmpBytes) {
247 cerr << "@ decryption error\n";
248 if (ptxt.length()!=tmpBytes.length())
249 cerr << " size mismatch, should be "<<tmpBytes.length()
250 << " but is "<<ptxt.length()<<endl;
251 else {
252 cerr << " input = "; printState(aesCtxt); cerr << endl;
253 cerr << " output ="; printState(tmpBytes); cerr << endl;
254 cerr << "should be "; printState(ptxt); cerr << endl;
255 }
256 }
257 else {
258 cout << "in "<<tm<<" seconds\n";
259 printNamedTimer(cout, "batchRecrypt");
260 printNamedTimer(cout, "recryption");
261 }
262 #if (defined(__unix__) || defined(__unix) || defined(unix))
263 struct rusage rusage;
264 getrusage( RUSAGE_SELF, &rusage );
265 cout << "rusage.ru_maxrss="<<rusage.ru_maxrss << endl << endl;
266 #endif
267 }
268
269 #include <iomanip>
printState(Vec<uint8_t> & st)270 void printState(Vec<uint8_t>& st)
271 {
272 cerr << "[";
273 for (long i=0; i<st.length() && i<32; i++) {
274 cerr << std::hex << std::setw(2) << (long) st[i] << " ";
275 }
276 if (st.length()>32) cerr << "...";
277 cerr << std::dec << "]";
278 }
279