1 /* Copyright (C) 2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 
13 #include <fstream>
14 #include <ctime>
15 
16 #include <helib/helib.h>
17 #include <helib/ArgMap.h>
18 #include <helib/debugging.h>
19 
20 #include <NTL/BasicThreadPool.h>
21 
22 #include "common.h"
23 
// Options collected from the command line (filled in by the ArgMap in main).
struct CmdLineOpts
{
  // Positional <params-file> argument: path of the parameters file.
  std::string paramFileName{};
  // Value of -o: prefix for the generated files. When left empty, main
  // derives a prefix from the parameters file name.
  std::string outputPrefixPath{};
  // Value of --scheme: "BGV" or "CKKS".
  std::string scheme{"BGV"};
  // Value of --bootstrap: "NONE", "THIN" or "THICK".
  std::string bootstrappable{"NONE"};
  // --no-skm toggle: do not generate switch-key matrices.
  bool noSKM{false};
  // --frob-skm toggle: also generate Frobenius switch-key matrices.
  bool frobSKM{false};
  // --info-file toggle: write the info printout to "<prefix>.info" instead of
  // standard output.
  bool infoFile{false};
};
34 
35 // Captures parameters of both BGV and CKKS
36 struct ParamsFileOpts
37 {
38   long m = 0;
39   long p = 0;
40   long r = 0;
41   long c = 0;
42   long Qbits = 0;
43   long scale = 4;
44   long c_m = 100;
45   NTL::Vec<long> mvec;
46   NTL::Vec<long> gens;
47   NTL::Vec<long> ords;
48 };
49 
50 // Write context.printout to file.out
printoutToStream(const helib::Context & context,std::ostream & out,bool noSKM,bool frobSKM,bool bootstrappable)51 void printoutToStream(const helib::Context& context,
52                       std::ostream& out,
53                       bool noSKM,
54                       bool frobSKM,
55                       bool bootstrappable)
56 {
57   if (!noSKM || bootstrappable)
58     out << "Key switching matrices created.\n";
59   if (frobSKM || bootstrappable)
60     out << "Frobenius matrices created.\n";
61   if (bootstrappable)
62     out << "Recrypt data created.\n";
63 
64   // write the algebra info
65   context.printout(out);
66 }
67 
68 // sk is child of pk in HElib.
writeKeyToFile(std::string & pathPrefix,helib::Context & context,helib::SecKey & secretKey,bool pkNotSk)69 void writeKeyToFile(std::string& pathPrefix,
70                     helib::Context& context,
71                     helib::SecKey& secretKey,
72                     bool pkNotSk)
73 {
74   std::string path = pathPrefix + (pkNotSk ? ".pk" : ".sk");
75   std::ofstream keysFile(path, std::ios::binary);
76   if (!keysFile.is_open()) {
77     std::runtime_error("Cannot write keys to file at '" + path);
78   }
79 
80   // write the context
81   helib::writeContextBaseBinary(keysFile, context);
82   helib::writeContextBinary(keysFile, context);
83 
84   // write the keys
85   if (pkNotSk)
86     helib::writePubKeyBinary(keysFile, secretKey);
87   else
88     helib::writeSecKeyBinary(keysFile, secretKey);
89 }
90 
main(int argc,char * argv[])91 int main(int argc, char* argv[])
92 {
93   CmdLineOpts cmdLineOpts;
94 
95   // clang-format off
96   helib::ArgMap()
97         .toggle()
98          .arg("--no-skm", cmdLineOpts.noSKM,
99                "disable switch-key matrices.", nullptr)
100          .arg("--frob-skm", cmdLineOpts.frobSKM,
101                "generate Frobenius switch-key matrices.", nullptr)
102          .arg("--info-file", cmdLineOpts.infoFile,
103                "print algebra info to file.", nullptr)
104         .separator(helib::ArgMap::Separator::WHITESPACE)
105         .named()
106           .arg("--scheme", cmdLineOpts.scheme,
107                "choose scheme BGV | CKKS.")
108           .arg("-o", cmdLineOpts.outputPrefixPath,
109                "choose an output prefix path.", nullptr)
110           .arg("--bootstrap", cmdLineOpts.bootstrappable,
111                "choose boostrapping option NONE | THIN | THICK.")
112         .required()
113         .positional()
114           .arg("<params-file>", cmdLineOpts.paramFileName,
115                "the parameters file.", nullptr)
116         .parse(argc, argv);
117   // clang-format on
118 
119   ParamsFileOpts paramsOpts;
120 
121   try {
122     // clang-format off
123     helib::ArgMap()
124           .arg("p", paramsOpts.p, "require p.", "")
125           .arg("m", paramsOpts.m, "require m.", "")
126           .arg("r", paramsOpts.r, "require r.", "")
127           .arg("c", paramsOpts.c, "require c.", "")
128           .arg("Qbits", paramsOpts.Qbits, "require Q bits.", "")
129           .optional()
130             .arg("scale", paramsOpts.scale, "require scale for CKKS")
131             .arg("c_m", paramsOpts.c_m, "require c_m for bootstrapping.", "")
132             .arg("mvec", paramsOpts.mvec, "require mvec for bootstrapping.", "")
133             .arg("gens", paramsOpts.gens, "require gens for bootstrapping.", "")
134             .arg("ords", paramsOpts.ords, "require ords for bootstrapping.", "")
135           .parse(cmdLineOpts.paramFileName);
136     // clang-format on
137   } catch (const helib::RuntimeError& e) {
138     std::cerr << e.what() << std::endl;
139     return EXIT_FAILURE;
140   }
141 
142   // Create the FHE context
143   long p;
144   if (cmdLineOpts.scheme.empty() || cmdLineOpts.scheme == "BGV") {
145     if (paramsOpts.p < 2) {
146       std::cerr << "BGV invalid plaintext modulus. "
147                    "In BGV it must be a prime number greater than 1."
148                 << std::endl;
149       return EXIT_FAILURE;
150     }
151     p = paramsOpts.p;
152   } else if (cmdLineOpts.scheme == "CKKS") {
153     if (paramsOpts.p != -1) {
154       std::cerr << "CKKS invalid plaintext modulus. "
155                    "In CKKS it must be set to -1."
156                 << std::endl;
157       return EXIT_FAILURE;
158     }
159     p = -1;
160     if (cmdLineOpts.bootstrappable != "NONE") {
161       std::cerr << "CKKS does not currently support bootstrapping."
162                 << std::endl;
163       return EXIT_FAILURE;
164     }
165   } else {
166     std::cerr << "Unrecognized scheme '" << cmdLineOpts.scheme << "'."
167               << std::endl;
168     return EXIT_FAILURE;
169   }
170 
171   if (cmdLineOpts.noSKM && cmdLineOpts.frobSKM) {
172     std::cerr << "Frobenius matrices reqires switch-key matrices to be "
173                  "generated."
174               << std::endl;
175     return EXIT_FAILURE;
176   }
177 
178   if (cmdLineOpts.bootstrappable != "NONE" &&
179       cmdLineOpts.bootstrappable != "THIN" &&
180       cmdLineOpts.bootstrappable != "THICK") {
181     std::cerr << "Bad boostrap option: " << cmdLineOpts.bootstrappable
182               << ".  Allowed options are NONE, THIN, THICK." << std::endl;
183     return EXIT_FAILURE;
184   }
185 
186   if (cmdLineOpts.bootstrappable != "NONE") {
187     if (cmdLineOpts.noSKM) {
188       std::cerr << "Cannot generate bootstrappable context without switch-key "
189                    "and frobenius matrices."
190                 << std::endl;
191       return EXIT_FAILURE;
192     }
193     if (paramsOpts.mvec.length() == 0) {
194       std::cerr << "Missing mvec parameter for bootstrapping in "
195                 << cmdLineOpts.paramFileName << "." << std::endl;
196       return EXIT_FAILURE;
197     }
198     if (paramsOpts.gens.length() == 0) {
199       std::cerr << "Missing gens parameter for bootstrapping in "
200                 << cmdLineOpts.paramFileName << "." << std::endl;
201       return EXIT_FAILURE;
202     }
203     if (paramsOpts.ords.length() == 0) {
204       std::cerr << "Missing ords parameter for bootstrapping in "
205                 << cmdLineOpts.paramFileName << "." << std::endl;
206       return EXIT_FAILURE;
207     }
208   }
209 
210   try {
211     helib::Context context(paramsOpts.m,
212                            p,
213                            paramsOpts.r,
214                            helib::convert<std::vector<long>>(paramsOpts.gens),
215                            helib::convert<std::vector<long>>(paramsOpts.ords));
216     if (cmdLineOpts.bootstrappable == "NONE") {
217       helib::buildModChain(context, paramsOpts.Qbits, paramsOpts.c);
218     } else {
219       context.zMStar.set_cM(paramsOpts.c_m / 100.0);
220       helib::buildModChain(context,
221                            paramsOpts.Qbits,
222                            paramsOpts.c,
223                            /*willBeBootstrappable=*/true);
224       if (cmdLineOpts.bootstrappable == "THICK")
225         context.enableBootStrapping(paramsOpts.mvec,
226                                     /*build_cache=*/false,
227                                     /*alsoThick=*/true);
228       else if (cmdLineOpts.bootstrappable == "THIN")
229         context.enableBootStrapping(paramsOpts.mvec,
230                                     /*build_cache=*/false,
231                                     /*alsoThick=*/false);
232     }
233 
234     if (p == -1)
235       context.scale = paramsOpts.scale;
236 
237     // and a new secret/public key
238     helib::SecKey secretKey(context);
239     secretKey.GenSecKey(); // A +-1/0 secret key
240 
241     // compute key-switching matrices
242     if (!cmdLineOpts.noSKM || cmdLineOpts.bootstrappable != "NONE") {
243       helib::addSome1DMatrices(secretKey);
244       if (cmdLineOpts.frobSKM || cmdLineOpts.bootstrappable != "NONE") {
245         helib::addFrbMatrices(secretKey);
246       }
247     }
248 
249     if (cmdLineOpts.bootstrappable != "NONE") {
250       secretKey.genRecryptData();
251     }
252 
253     // If not set by user, returns params file name with truncated UTC
254     if (cmdLineOpts.outputPrefixPath.empty()) {
255       cmdLineOpts.outputPrefixPath =
256           stripExtension(cmdLineOpts.paramFileName) +
257           std::to_string(std::time(nullptr) % 100000);
258       std::cout << "File prefix: " << cmdLineOpts.outputPrefixPath << std::endl;
259     }
260 
261     // Printout important info
262     if (cmdLineOpts.infoFile) {
263       // outputPrefixPath should be set further up main.
264       std::string path = cmdLineOpts.outputPrefixPath + ".info";
265       std::ofstream out(path);
266       if (!out.is_open()) {
267         throw std::runtime_error("Cannot write keys to file at '" + path +
268                                  "'.");
269       }
270       printoutToStream(context,
271                        out,
272                        cmdLineOpts.noSKM,
273                        cmdLineOpts.frobSKM,
274                        cmdLineOpts.bootstrappable != "NONE");
275     } else {
276       printoutToStream(context,
277                        std::cout,
278                        cmdLineOpts.noSKM,
279                        cmdLineOpts.frobSKM,
280                        cmdLineOpts.bootstrappable != "NONE");
281     }
282 
283     NTL::SetNumThreads(2);
284 
285     NTL_EXEC_INDEX(2, skOrPk)
286     writeKeyToFile(cmdLineOpts.outputPrefixPath, context, secretKey, skOrPk);
287     NTL_EXEC_INDEX_END
288 
289   } catch (const std::invalid_argument& e) {
290     std::cerr << "Exit due to invalid argument thrown:\n"
291               << e.what() << std::endl;
292     return EXIT_FAILURE;
293   } catch (const helib::IOError& e) {
294     std::cerr << "Exit due to IOError thrown:\n" << e.what() << std::endl;
295     return EXIT_FAILURE;
296   } catch (const std::runtime_error& e) {
297     std::cerr << "Exit due to runtime error thrown:\n" << e.what() << std::endl;
298     return EXIT_FAILURE;
299   } catch (const std::logic_error& e) {
300     std::cerr << "Exit due to logic error thrown:\n" << e.what() << std::endl;
301     return EXIT_FAILURE;
302   } catch (const std::exception& e) {
303     std::cerr << "Exit due to unknown exception thrown:\n"
304               << e.what() << std::endl;
305     return EXIT_FAILURE;
306   }
307 
308   return EXIT_SUCCESS;
309 }
310