1 /** Test_AES.cpp - test program for homomorphic AES using HElib
2  */
3 #if defined(__unix__) || defined(__unix) || defined(unix)
4 #include <sys/time.h>
5 #include <sys/resource.h>
6 #endif
7 
8 namespace std {} using namespace std;
9 namespace NTL {} using namespace NTL;
10 #include <cstring>
11 #include "homAES.h"
12 #include "Ctxt.h"
13 
// Parameter sets selectable with the "sz" command-line argument (row index).
// Columns:
//   p        plaintext characteristic (2 in every row)
//   phi(m)   Euler's totient of m
//   m        cyclotomic index; equals the product of the nonzero m1..m3
//   d        multiplicative order of p modulo m; main() uses d/8 as the
//            extension degree of the packed AES representation
//   m1..m3   factorization of m (0 = unused)
//   g1..g3   generators passed to the Context (0 = unused)
//   ord1..3  orders of those generators (0 = unused); in every row the
//            product of the nonzero |ord_i| equals phi(m)/d. A negative
//            order presumably marks a "bad" dimension in HElib's
//            convention -- NOTE(review): confirm against Context docs.
//   c_m      ring constant scaled by 100 (it is used as c_m/100.0)
// The table is only ever read, so it is declared const.
static const long mValues[][14] = {
//{ p, phi(m),  m,   d, m1, m2, m3,   g1,    g2,   g3,ord1,ord2,ord3, c_m}
  { 2,  512,    771, 16,771,  0,  0,     5,    0,    0,-32,  0,  0, 100}, // m=(3)*{257} :-( m/phim(m)=1.5 C=77 D=2 E=4
  { 2, 4096,   4369, 16, 17, 257, 0,   258, 4115,    0, 16,-16,  0, 100}, // m=17*(257) :-( m/phim(m)=1.06 C=61 D=3 E=4
  { 2, 16384, 21845, 16, 17, 5, 257,  8996,17477,21591, 16,  4,-16,1600}, // m=5*17*(257) :-( m/phim(m)=1.33 C=65 D=4 E=4
  { 2, 23040, 28679, 24, 17, 7, 241, 15184, 4098,28204, 16,  6,-10,1500}, // m=7*17*(241) m/phim(m)=1.24    C=63  D=4 E=3
  { 2, 46080, 53261, 24, 17,13, 241, 43863,28680,15913, 16, 12,-10, 100}, // m=13*17*(241) m/phim(m)=1.15   C=69  D=4 E=3
  { 2, 64512, 65281, 48, 97,673,  0, 43073,22214,    0, 96,-14,  0, 100}  // m=97*(673) :-( m/phim(m)=1.01  C=169 D=3 E=4
};
23 
#ifdef DEBUG_PRINTOUT
// Debugging hooks, compiled in only when DEBUG_PRINTOUT is defined.
// The two pointers are defined outside this file (extern). main() points
// dbgKey at the secret key and dbgEa at a freshly allocated EncryptedArray,
// presumably so that debug code elsewhere can decrypt intermediate
// ciphertexts -- NOTE(review): confirm with the users of these globals.
extern SecKey* dbgKey;
extern EncryptedArray* dbgEa;
// Bit flags for the `flags` argument of decryptAndPrint(); they can be
// OR-ed together. Their exact meaning is defined by decryptAndPrint()
// itself (not visible here) -- presumably they select which representations
// (ZZX / polynomial / slot vector) are printed.
#define FLAG_PRINT_ZZX  1
#define FLAG_PRINT_POLY 2
#define FLAG_PRINT_VEC  4
// Defined outside this file; decrypts ctxt with sk and writes it to s.
extern void decryptAndPrint(ostream& s, const Ctxt& ctxt, const SecKey& sk,
			    const EncryptedArray& ea, long flags=0);
#endif
33 
// Writes (at most the first 32 bytes of) a byte vector to std::cerr in hex;
// defined at the end of this file.
void printState(Vec<uint8_t>& st);
// Plaintext (non-homomorphic) AES reference routines, defined outside this
// file. main() uses them to compute the expected ciphertext:
//   AESKeyExpansion(keySchedule, key, /*keyLength=*/128);
//   Cipher(out, in, keySchedule, /*numRounds=*/10);
// NOTE(review): parameter meanings above are inferred from that usage;
// confirm against the definitions.
extern long AESKeyExpansion(unsigned char RoundKey[],
			    unsigned char Key[], int NN);
extern void Cipher(unsigned char out[16],
		   unsigned char in[16], unsigned char RoundKey[], int Nr);
39 
main(int argc,char ** argv)40 int main(int argc, char **argv)
41 {
42   ArgMapping amap;
43 
44   long idx = 0;
45   amap.arg("sz", idx, "parameter-sets: toy=0 through huge=5");
46 
47   long c=3;
48   amap.arg("c", c, "number of columns in the key-switching matrices");
49 
50   long L=0;
51   amap.arg("L", L, "# of levels in the modulus chain",  "heuristic");
52 
53   long B=23;
54   amap.arg("B", B, "# of bits per level (only 64-bit machines)");
55 
56   bool boot=false;
57   amap.arg("boot", boot, "includes bootstrapping");
58 
59   bool packed=true;
60   amap.arg("packed", packed, "use packed bootstrapping");
61 
62   amap.parse(argc, argv);
63   if (idx>5) idx = 5;
64 
65   Vec<long> mvec;
66   vector<long> gens;
67   vector<long> ords;
68 
69   if (boot) {
70     if (L<23) L=23;
71     if (idx<1) idx=1; // the sz=0 params are incompatible with bootstrapping
72   } else {
73 #if (NTL_SP_NBITS<50)
74     if (L<46) L=46;
75 #else
76     if (L<42) L=42;
77 #endif
78   }
79 
80   long p = mValues[idx][0];
81   //  long phim = mValues[idx][1];
82   long m = mValues[idx][2];
83 
84   append(mvec, mValues[idx][4]);
85   if (mValues[idx][5]>1) append(mvec, mValues[idx][5]);
86   if (mValues[idx][6]>1) append(mvec, mValues[idx][6]);
87 
88   gens.push_back(mValues[idx][7]);
89   if (mValues[idx][8]>1)   gens.push_back(mValues[idx][8]);
90   if (mValues[idx][9]>1) gens.push_back(mValues[idx][9]);
91 
92   ords.push_back(mValues[idx][10]);
93   if (abs(mValues[idx][11])>1) ords.push_back(mValues[idx][11]);
94   if (abs(mValues[idx][12])>1) ords.push_back(mValues[idx][12]);
95 
96   cout << "*** Test_AES: c=" << c
97        << ", L=" << L
98        << ", B=" << B
99        << ", boot=" << boot
100        << ", packed=" << packed
101        << ", m=" << m
102        << " (=" << mvec << "), gens="<<gens<<", ords="<<ords
103        << endl;
104 
105   setTimersOn();
106   double tm = -GetTime();
107   cout << "computing key-independent tables..." << std::flush;
108   Context context(m, p, /*r=*/1, gens, ords);
109 #if (NTL_SP_NBITS>=50) // 64-bit machines
110   context.bitsPerLevel = B;
111 #endif
112   context.zMStar.set_cM(mValues[idx][13]/100.0); // the ring constant
113   buildModChain(context, L, c);
114 
115   if (boot) context.makeBootstrappable(mvec);
116   tm += GetTime();
117   cout << "done in "<<tm<<" seconds\n";
118 
119   //  context.zMStar.printout();
120   {IndexSet allPrimes(0,context.numPrimes()-1);
121    cout <<"  "<<context.numPrimes()<<" primes ("
122        <<context.ctxtPrimes.card()<<" ctxt/"
123        <<context.specialPrimes.card()<<" special), total bitsize="
124 	<<context.logOfProduct(allPrimes)
125 	<<", security level: "<<context.securityLevel() << endl;}
126 
127   long e = mValues[idx][3] /8; // extension degree
128   cout << "  "<<context.zMStar.getNSlots()<<" slots ("
129        << (context.zMStar.getNSlots()/16)<<" blocks) per ctxt";
130   if (boot && packed)
131     cout << ". x"<<e<<" ctxts";
132   cout << endl;
133 
134   cout << "computing key-dependent tables..." << std::flush;
135   tm = -GetTime();
136   SecKey secretKey(context);
137   PubKey& publicKey = secretKey;
138   secretKey.GenSecKey(64);      // A Hamming-weight-64 secret key
139 
140   // Add key-switching matrices for the automorphisms that we need
141   long ord = context.zMStar.OrderOf(0);
142   for (long i = 1; i < 16; i++) { // rotation along 1st dim by size i*ord/16
143     long exp = i*ord/16;
144     long val = PowerMod(context.zMStar.ZmStarGen(0), exp, m); // val = g^exp
145 
146     // From s(X^val) to s(X)
147     secretKey.GenKeySWmatrix(1, val);
148     if (!context.zMStar.SameOrd(0))
149       // also from s(X^{1/val}) to s(X)
150       secretKey.GenKeySWmatrix(1, InvMod(val,m));
151   }
152 
153   addFrbMatrices(secretKey);      // Also add Frobenius key-switching
154   if (boot) { // more tables
155     addSome1DMatrices(secretKey);   // compute more key-switching matrices
156     secretKey.genRecryptData();
157   }
158   tm += GetTime();
159   cout << "done in "<<tm<<" seconds\n";
160 
161 #ifdef DEBUG_PRINTOUT
162   dbgKey = &secretKey; // debugging key and ea
163 
164   ZZX aesPoly;         // X^8+X^4+X^3+X+1
165   SetCoeff(aesPoly,8);  SetCoeff(aesPoly,4);
166   SetCoeff(aesPoly,3);  SetCoeff(aesPoly,1);  SetCoeff(aesPoly,0);
167   dbgEa = new EncryptedArray(context, aesPoly);
168 #endif
169 
170   cout << "computing AES tables..." << std::flush;
171   tm = -GetTime();
172   HomAES hAES(context); // compute AES-specific key-independent tables
173   const EncryptedArrayDerived<PA_GF2>& ea2 = hAES.getEA();
174   long blocksPerCtxt = ea2.size() / 16;
175 
176   long nBlocks;
177   if (boot && packed)
178     nBlocks = blocksPerCtxt * e;
179   else
180     nBlocks = blocksPerCtxt;
181 
182   Vec<uint8_t> ptxt(INIT_SIZE, nBlocks*16);
183   Vec<uint8_t> aesCtxt(INIT_SIZE, nBlocks*16);
184   Vec<uint8_t> aesKey(INIT_SIZE, 16); // AES-128
185   uint8_t keySchedule[240];
186 
187   // Choose random key, data
188   {GF2X rnd;
189   random(rnd, 8*ptxt.length());
190   BytesFromGF2X(ptxt.data(), rnd, aesKey.length());
191   random(rnd, 8*aesKey.length());
192   BytesFromGF2X(aesKey.data(), rnd, aesKey.length());}
193 
194   // Encrypt the AES key under the HE key
195   vector< Ctxt > encryptedAESkey;
196   hAES.encryptAESkey(encryptedAESkey, aesKey, publicKey);
197   tm += GetTime();
198   cout << "done in "<<tm<<" seconds\n";
199 
200   // Perform homomorphic AES
201   cout << "AES encryption "<< std::flush;
202   vector< Ctxt > doublyEncrypted;
203   tm = -GetTime();
204   hAES.homAESenc(doublyEncrypted, encryptedAESkey, ptxt);
205   tm += GetTime();
206 
207   // Check that AES succeeeded
208   Vec<ZZX> poly(INIT_SIZE, doublyEncrypted.size());
209   for (long i=0; i<poly.length(); i++)
210     secretKey.Decrypt(poly[i], doublyEncrypted[i]);
211   decode4AES(aesCtxt, poly, hAES.getEA());
212 
213   AESKeyExpansion(keySchedule, aesKey.data(), /*keyLength=*/128);
214   Vec<uint8_t> tmpBytes(INIT_SIZE, nBlocks*16);
215   for (long i=0; i<nBlocks; i++) {
216     Vec<uint8_t> tmp(INIT_SIZE, 16);
217     Cipher(&tmpBytes[16*i], &ptxt[16*i], keySchedule, /*numRounds=*/10);
218   }
219   if (aesCtxt != tmpBytes) {
220     cerr << "@ encryption error\n";
221     if (aesCtxt.length()!=tmpBytes.length())
222       cerr << "  size mismatch, should be "<<tmpBytes.length()
223 	   << " but is "<<aesCtxt.length()<<endl;
224     else {
225       cerr << "  input = "; printState(ptxt); cerr << endl;
226       cerr << "  output ="; printState(aesCtxt); cerr << endl;
227       cerr << "should be "; printState(tmpBytes); cerr << endl;
228     }
229   }
230   else {
231     cout << "in "<<tm<<" seconds\n";
232     printNamedTimer(cout, "batchRecrypt");
233     printNamedTimer(cout, "recryption");
234   }
235   resetAllTimers();
236 
237   // Decrypt and check that you have the same thing as before
238   cout << "AES decryption "<< std::flush;
239   tm = -GetTime();
240   hAES.homAESdec(doublyEncrypted, encryptedAESkey, aesCtxt);
241   tm += GetTime();
242 
243   for (long i=0; i<poly.length(); i++)
244     secretKey.Decrypt(poly[i], doublyEncrypted[i]);
245   decode4AES(tmpBytes, poly, hAES.getEA());
246   if (ptxt != tmpBytes) {
247     cerr << "@ decryption error\n";
248     if (ptxt.length()!=tmpBytes.length())
249       cerr << "  size mismatch, should be "<<tmpBytes.length()
250 	   << " but is "<<ptxt.length()<<endl;
251     else {
252       cerr << "  input = "; printState(aesCtxt); cerr << endl;
253       cerr << "  output ="; printState(tmpBytes); cerr << endl;
254       cerr << "should be "; printState(ptxt); cerr << endl;
255     }
256   }
257   else {
258     cout << "in "<<tm<<" seconds\n";
259     printNamedTimer(cout, "batchRecrypt");
260     printNamedTimer(cout, "recryption");
261   }
262 #if (defined(__unix__) || defined(__unix) || defined(unix))
263   struct rusage rusage;
264   getrusage( RUSAGE_SELF, &rusage );
265   cout << "rusage.ru_maxrss="<<rusage.ru_maxrss << endl << endl;
266 #endif
267 }
268 
269 #include <iomanip>
printState(Vec<uint8_t> & st)270 void printState(Vec<uint8_t>& st)
271 {
272   cerr << "[";
273   for (long i=0; i<st.length() && i<32; i++) {
274     cerr << std::hex << std::setw(2) << (long) st[i] << " ";
275   }
276   if (st.length()>32) cerr << "...";
277   cerr << std::dec << "]";
278 }
279