1 // @file bfvrnsB-impl.cpp - template instantiations and methods for the BFVrnsB
2 // scheme
3 // @author TPOC: contact@palisade-crypto.org
4 //
5 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
6 // All rights reserved.
7 // Redistribution and use in source and binary forms, with or without
8 // modification, are permitted provided that the following conditions are met:
9 // 1. Redistributions of source code must retain the above copyright notice,
10 // this list of conditions and the following disclaimer.
11 // 2. Redistributions in binary form must reproduce the above copyright notice,
12 // this list of conditions and the following disclaimer in the documentation
13 // and/or other materials provided with the distribution. THIS SOFTWARE IS
14 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
15 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
16 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
17 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
18 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 /*
25 Description:
26 
This code implements an RNS variant of the Brakerski-Fan-Vercauteren (BFV)
homomorphic encryption scheme.  This scheme is also referred to as the FV
scheme.
30 
31 The BFV scheme is introduced in the following papers:
32    - Zvika Brakerski (2012). Fully Homomorphic Encryption without Modulus
33 Switching from Classical GapSVP. Cryptology ePrint Archive, Report 2012/078.
34 (https://eprint.iacr.org/2012/078)
35    - Junfeng Fan and Frederik Vercauteren (2012). Somewhat Practical Fully
36 Homomorphic Encryption.  Cryptology ePrint Archive, Report 2012/144.
37 (https://eprint.iacr.org/2012/144.pdf)
38 
39  Our implementation builds from the designs here:
40    - Lepoint T., Naehrig M. (2014) A Comparison of the Homomorphic Encryption
41 Schemes FV and YASHE. In: Pointcheval D., Vergnaud D. (eds) Progress in
42 Cryptology – AFRICACRYPT 2014. AFRICACRYPT 2014. Lecture Notes in Computer
43 Science, vol 8469. Springer, Cham. (https://eprint.iacr.org/2014/062.pdf)
44    - Jean-Claude Bajard and Julien Eynard and Anwar Hasan and Vincent Zucca
45 (2016). A Full RNS Variant of FV like Somewhat Homomorphic Encryption Schemes.
46 Cryptology ePrint Archive, Report 2016/510. (https://eprint.iacr.org/2016/510)
47    - Ahmad Al Badawi and Yuriy Polyakov and Khin Mi Mi Aung and Bharadwaj
48 Veeravalli and Kurt Rohloff (2018). Implementation and Performance Evaluation of
49 RNS Variants of the BFV Homomorphic Encryption Scheme. Cryptology ePrint
Archive, Report 2018/589. (https://eprint.iacr.org/2018/589)
51  */
52 
53 #include "bfvrnsB.cpp"
54 #include "cryptocontext.h"
55 
56 // #define USE_KARATSUBA
57 
58 namespace lbcrypto {
59 
// Shared body for every Poly specialization below: BFVrnsB is implemented
// only for DCRTPoly, so these entry points throw not_implemented_error.
// The expansion is a braced block so that errMsg stays local to it.
#define NOPOLY                                                  \
  {                                                             \
    std::string errMsg =                                        \
        "BFVrnsB does not support Poly. Use DCRTPoly instead."; \
    PALISADE_THROW(not_implemented_error, errMsg);              \
  }

// Same as NOPOLY, for the NativePoly specializations.
#define NONATIVEPOLY                                                  \
  {                                                                   \
    std::string errMsg =                                              \
        "BFVrnsB does not support NativePoly. Use DCRTPoly instead."; \
    PALISADE_THROW(not_implemented_error, errMsg);                    \
  }
68 
69 template <>
LPCryptoParametersBFVrnsB()70 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB()
71     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
72   NOPOLY
73 }
74 
75 template <>
LPCryptoParametersBFVrnsB()76 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB()
77     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
78   NONATIVEPOLY
79 }
80 
81 template <>
LPCryptoParametersBFVrnsB(const LPCryptoParametersBFVrnsB & rhs)82 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
83     const LPCryptoParametersBFVrnsB &rhs)
84     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
85   NOPOLY
86 }
87 
88 template <>
LPCryptoParametersBFVrnsB(const LPCryptoParametersBFVrnsB & rhs)89 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
90     const LPCryptoParametersBFVrnsB &rhs)
91     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
92   NONATIVEPOLY
93 }
94 
95 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,const PlaintextModulus & plaintextModulus,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)96 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
97     shared_ptr<ParmType> params, const PlaintextModulus &plaintextModulus,
98     float distributionParameter, float assuranceMeasure, float securityLevel,
99     usint relinWindow, MODE mode, int depth, int maxDepth)
100     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
101   NOPOLY
102 }
103 
104 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,const PlaintextModulus & plaintextModulus,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)105 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
106     shared_ptr<ParmType> params, const PlaintextModulus &plaintextModulus,
107     float distributionParameter, float assuranceMeasure, float securityLevel,
108     usint relinWindow, MODE mode, int depth, int maxDepth)
109     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
110   NONATIVEPOLY
111 }
112 
113 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,EncodingParams encodingParams,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)114 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
115     shared_ptr<ParmType> params, EncodingParams encodingParams,
116     float distributionParameter, float assuranceMeasure, float securityLevel,
117     usint relinWindow, MODE mode, int depth, int maxDepth)
118     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
119   NOPOLY
120 }
121 
122 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,EncodingParams encodingParams,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)123 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
124     shared_ptr<ParmType> params, EncodingParams encodingParams,
125     float distributionParameter, float assuranceMeasure, float securityLevel,
126     usint relinWindow, MODE mode, int depth, int maxDepth)
127     : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
128   NONATIVEPOLY
129 }
130 
131 // Parameter generation for BFV-RNS
132 template <>
PrecomputeCRTTables()133 bool LPCryptoParametersBFVrnsB<Poly>::PrecomputeCRTTables() {
134   NOPOLY
135 }
136 
137 template <>
PrecomputeCRTTables()138 bool LPCryptoParametersBFVrnsB<NativePoly>::PrecomputeCRTTables() {
139   NONATIVEPOLY
140 }
141 
142 template <>
LPPublicKeyEncryptionSchemeBFVrnsB()143 LPPublicKeyEncryptionSchemeBFVrnsB<Poly>::LPPublicKeyEncryptionSchemeBFVrnsB() {
144   NOPOLY
145 }
146 
147 template <>
148 LPPublicKeyEncryptionSchemeBFVrnsB<
LPPublicKeyEncryptionSchemeBFVrnsB()149     NativePoly>::LPPublicKeyEncryptionSchemeBFVrnsB() {
150   NONATIVEPOLY
151 }
152 
153 template <>
ParamsGen(shared_ptr<LPCryptoParameters<Poly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t n) const154 bool LPAlgorithmParamsGenBFVrnsB<Poly>::ParamsGen(
155     shared_ptr<LPCryptoParameters<Poly>> cryptoParams, int32_t evalAddCount,
156     int32_t evalMultCount, int32_t keySwitchCount, size_t dcrtBits,
157     uint32_t n) const {
158   NOPOLY
159 }
160 
161 template <>
ParamsGen(shared_ptr<LPCryptoParameters<NativePoly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t n) const162 bool LPAlgorithmParamsGenBFVrnsB<NativePoly>::ParamsGen(
163     shared_ptr<LPCryptoParameters<NativePoly>> cryptoParams,
164     int32_t evalAddCount, int32_t evalMultCount, int32_t keySwitchCount,
165     size_t dcrtBits, uint32_t n) const {
166   NONATIVEPOLY
167 }
168 
169 template <>
Encrypt(const LPPublicKey<Poly> publicKey,Poly ptxt) const170 Ciphertext<Poly> LPAlgorithmBFVrnsB<Poly>::Encrypt(
171     const LPPublicKey<Poly> publicKey, Poly ptxt) const {
172   NOPOLY
173 }
174 
175 template <>
Encrypt(const LPPublicKey<NativePoly> publicKey,NativePoly ptxt) const176 Ciphertext<NativePoly> LPAlgorithmBFVrnsB<NativePoly>::Encrypt(
177     const LPPublicKey<NativePoly> publicKey, NativePoly ptxt) const {
178   NONATIVEPOLY
179 }
180 
181 template <>
Decrypt(const LPPrivateKey<Poly> privateKey,ConstCiphertext<Poly> ciphertext,NativePoly * plaintext) const182 DecryptResult LPAlgorithmBFVrnsB<Poly>::Decrypt(
183     const LPPrivateKey<Poly> privateKey, ConstCiphertext<Poly> ciphertext,
184     NativePoly *plaintext) const {
185   NOPOLY
186 }
187 
188 template <>
Decrypt(const LPPrivateKey<NativePoly> privateKey,ConstCiphertext<NativePoly> ciphertext,NativePoly * plaintext) const189 DecryptResult LPAlgorithmBFVrnsB<NativePoly>::Decrypt(
190     const LPPrivateKey<NativePoly> privateKey,
191     ConstCiphertext<NativePoly> ciphertext, NativePoly *plaintext) const {
192   NONATIVEPOLY
193 }
194 
195 template <>
Encrypt(const LPPrivateKey<Poly> privateKey,Poly ptxt) const196 Ciphertext<Poly> LPAlgorithmBFVrnsB<Poly>::Encrypt(
197     const LPPrivateKey<Poly> privateKey, Poly ptxt) const {
198   NOPOLY
199 }
200 
201 template <>
Encrypt(const LPPrivateKey<NativePoly> privateKey,NativePoly ptxt) const202 Ciphertext<NativePoly> LPAlgorithmBFVrnsB<NativePoly>::Encrypt(
203     const LPPrivateKey<NativePoly> privateKey, NativePoly ptxt) const {
204   NONATIVEPOLY
205 }
206 
207 template <>
EvalMult(ConstCiphertext<Poly> ciphertext1,ConstCiphertext<Poly> ciphertext2) const208 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalMult(
209     ConstCiphertext<Poly> ciphertext1,
210     ConstCiphertext<Poly> ciphertext2) const {
211   NOPOLY
212 }
213 
214 template <>
EvalMult(ConstCiphertext<NativePoly> ciphertext1,ConstCiphertext<NativePoly> ciphertext2) const215 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalMult(
216     ConstCiphertext<NativePoly> ciphertext1,
217     ConstCiphertext<NativePoly> ciphertext2) const {
218   NONATIVEPOLY
219 }
220 
221 template <>
EvalAdd(ConstCiphertext<Poly> ct,ConstPlaintext pt) const222 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalAdd(ConstCiphertext<Poly> ct,
223                                                       ConstPlaintext pt) const {
224   NOPOLY
225 }
226 
227 template <>
EvalAdd(ConstCiphertext<NativePoly> ct,ConstPlaintext pt) const228 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalAdd(
229     ConstCiphertext<NativePoly> ct, ConstPlaintext pt) const {
230   NONATIVEPOLY
231 }
232 
233 template <>
EvalSub(ConstCiphertext<Poly> ct,ConstPlaintext pt) const234 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalSub(ConstCiphertext<Poly> ct,
235                                                       ConstPlaintext pt) const {
236   NOPOLY
237 }
238 
239 template <>
EvalSub(ConstCiphertext<NativePoly> ct,ConstPlaintext pt) const240 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalSub(
241     ConstCiphertext<NativePoly> ct, ConstPlaintext pt) const {
242   NONATIVEPOLY
243 }
244 
245 template <>
KeySwitchGen(const LPPrivateKey<Poly> originalPrivateKey,const LPPrivateKey<Poly> newPrivateKey) const246 LPEvalKey<Poly> LPAlgorithmSHEBFVrnsB<Poly>::KeySwitchGen(
247     const LPPrivateKey<Poly> originalPrivateKey,
248     const LPPrivateKey<Poly> newPrivateKey) const {
249   NOPOLY
250 }
251 
252 template <>
KeySwitchGen(const LPPrivateKey<NativePoly> originalPrivateKey,const LPPrivateKey<NativePoly> newPrivateKey) const253 LPEvalKey<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::KeySwitchGen(
254     const LPPrivateKey<NativePoly> originalPrivateKey,
255     const LPPrivateKey<NativePoly> newPrivateKey) const {
256   NONATIVEPOLY
257 }
258 
259 template <>
KeySwitchInPlace(const LPEvalKey<Poly> keySwitchHint,Ciphertext<Poly> & cipherText) const260 void LPAlgorithmSHEBFVrnsB<Poly>::KeySwitchInPlace(
261     const LPEvalKey<Poly> keySwitchHint,
262     Ciphertext<Poly>& cipherText) const {
263   NOPOLY
264 }
265 
266 template <>
KeySwitchInPlace(const LPEvalKey<NativePoly> keySwitchHint,Ciphertext<NativePoly> & cipherText) const267 void LPAlgorithmSHEBFVrnsB<NativePoly>::KeySwitchInPlace(
268     const LPEvalKey<NativePoly> keySwitchHint,
269     Ciphertext<NativePoly>& cipherText) const {
270   NONATIVEPOLY
271 }
272 
273 template <>
EvalMultAndRelinearize(ConstCiphertext<Poly> ct1,ConstCiphertext<Poly> ct,const vector<LPEvalKey<Poly>> & ek) const274 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalMultAndRelinearize(
275     ConstCiphertext<Poly> ct1, ConstCiphertext<Poly> ct,
276     const vector<LPEvalKey<Poly>> &ek) const {
277   NOPOLY
278 }
279 
280 template <>
281 Ciphertext<NativePoly>
EvalMultAndRelinearize(ConstCiphertext<NativePoly> ct1,ConstCiphertext<NativePoly> ct,const vector<LPEvalKey<NativePoly>> & ek) const282 LPAlgorithmSHEBFVrnsB<NativePoly>::EvalMultAndRelinearize(
283     ConstCiphertext<NativePoly> ct1, ConstCiphertext<NativePoly> ct,
284     const vector<LPEvalKey<NativePoly>> &ek) const {
285   NONATIVEPOLY
286 }
287 
288 template <>
MultipartyDecryptFusion(const vector<Ciphertext<Poly>> & ciphertextVec,NativePoly * plaintext) const289 DecryptResult LPAlgorithmMultipartyBFVrnsB<Poly>::MultipartyDecryptFusion(
290     const vector<Ciphertext<Poly>> &ciphertextVec,
291     NativePoly *plaintext) const {
292   NOPOLY
293 }
294 
295 template <>
MultipartyDecryptFusion(const vector<Ciphertext<NativePoly>> & ciphertextVec,NativePoly * plaintext) const296 DecryptResult LPAlgorithmMultipartyBFVrnsB<NativePoly>::MultipartyDecryptFusion(
297     const vector<Ciphertext<NativePoly>> &ciphertextVec,
298     NativePoly *plaintext) const {
299   NONATIVEPOLY
300 }
301 
302 template <>
MultiKeySwitchGen(const LPPrivateKey<Poly> originalPrivateKey,const LPPrivateKey<Poly> newPrivateKey,const LPEvalKey<Poly> ek) const303 LPEvalKey<Poly> LPAlgorithmMultipartyBFVrnsB<Poly>::MultiKeySwitchGen(
304     const LPPrivateKey<Poly> originalPrivateKey,
305     const LPPrivateKey<Poly> newPrivateKey, const LPEvalKey<Poly> ek) const {
306   NOPOLY
307 }
308 
309 template <>
310 LPEvalKey<NativePoly>
MultiKeySwitchGen(const LPPrivateKey<NativePoly> originalPrivateKey,const LPPrivateKey<NativePoly> newPrivateKey,const LPEvalKey<NativePoly> ek) const311 LPAlgorithmMultipartyBFVrnsB<NativePoly>::MultiKeySwitchGen(
312     const LPPrivateKey<NativePoly> originalPrivateKey,
313     const LPPrivateKey<NativePoly> newPrivateKey,
314     const LPEvalKey<NativePoly> ek) const {
315   NONATIVEPOLY
316 }
317 
318 template class LPCryptoParametersBFVrnsB<Poly>;
319 template class LPPublicKeyEncryptionSchemeBFVrnsB<Poly>;
320 template class LPAlgorithmBFVrnsB<Poly>;
321 template class LPAlgorithmSHEBFVrnsB<Poly>;
322 template class LPAlgorithmMultipartyBFVrnsB<Poly>;
323 template class LPAlgorithmParamsGenBFVrnsB<Poly>;
324 
325 template class LPCryptoParametersBFVrnsB<NativePoly>;
326 template class LPPublicKeyEncryptionSchemeBFVrnsB<NativePoly>;
327 template class LPAlgorithmBFVrnsB<NativePoly>;
328 template class LPAlgorithmSHEBFVrnsB<NativePoly>;
329 template class LPAlgorithmMultipartyBFVrnsB<NativePoly>;
330 template class LPAlgorithmParamsGenBFVrnsB<NativePoly>;
331 
332 #undef NOPOLY
333 #undef NONATIVEPOLY
334 
335 // Precomputation of CRT tables encryption, decryption, and homomorphic
336 // multiplication
337 template <>
PrecomputeCRTTables()338 bool LPCryptoParametersBFVrnsB<DCRTPoly>::PrecomputeCRTTables() {
339   // read values for the CRT basis
340 
341   size_t sizeQ = GetElementParams()->GetParams().size();
342   auto ringDim = GetElementParams()->GetRingDimension();
343 
344   vector<NativeInteger> moduliQ(sizeQ);
345   vector<NativeInteger> rootsQ(sizeQ);
346 
347   const BigInteger BarrettBase128Bit(
348       "340282366920938463463374607431768211456");       // 2^128
349   const BigInteger TwoPower64("18446744073709551616");  // 2^64
350 
351   m_moduliQ.resize(sizeQ);
352   for (size_t i = 0; i < sizeQ; i++) {
353     moduliQ[i] = GetElementParams()->GetParams()[i]->GetModulus();
354     rootsQ[i] = GetElementParams()->GetParams()[i]->GetRootOfUnity();
355     m_moduliQ[i] = moduliQ[i];
356   }
357 
358   // compute the CRT delta table floor(Q/p) mod qi - used for encryption
359 
360   const BigInteger modulusQ = GetElementParams()->GetModulus();
361 
362   const BigInteger QDivt = modulusQ.DividedBy(GetPlaintextModulus());
363 
364   std::vector<NativeInteger> QDivtModq(sizeQ);
365 
366   for (size_t i = 0; i < sizeQ; i++) {
367     BigInteger qi = BigInteger(moduliQ[i].ConvertToInt());
368     BigInteger QDivtModqi = QDivt.Mod(qi);
369     QDivtModq[i] = NativeInteger(QDivtModqi.ConvertToInt());
370   }
371 
372   m_QDivtModq = QDivtModq;
373 
374   m_modqBarrettMu.resize(sizeQ);
375   for (uint32_t i = 0; i < m_modqBarrettMu.size(); i++) {
376     BigInteger mu = BarrettBase128Bit / BigInteger(m_moduliQ[i]);
377     uint64_t val[2];
378     val[0] = (mu % TwoPower64).ConvertToInt();
379     val[1] = mu.RShift(64).ConvertToInt();
380 
381     memcpy(&m_modqBarrettMu[i], val, sizeof(DoubleNativeInt));
382   }
383 
384   ChineseRemainderTransformFTT<NativeVector>::PreCompute(rootsQ, 2 * ringDim,
385                                                          moduliQ);
386 
387   // Compute Bajard's et al. RNS variant lookup tables
388 
389   // Populate EvalMulrns tables
390   // find the a suitable size of B
391   m_numq = sizeQ;
392 
393   BigInteger t = BigInteger(GetPlaintextModulus());
394   BigInteger Q(GetElementParams()->GetModulus());
395 
396   BigInteger B = 1;
397   BigInteger maxConvolutionValue =
398       BigInteger(2) * BigInteger(ringDim) * Q * Q * t;
399 
400   m_moduliB.push_back(
401       PreviousPrime<NativeInteger>(moduliQ[m_numq - 1], 2 * ringDim));
402   m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_moduliB[0]));
403   B = B * BigInteger(m_moduliB[0]);
404 
405   for (usint i = 1; i < m_numq; i++) {  // we already added one prime
406     m_moduliB.push_back(
407         PreviousPrime<NativeInteger>(m_moduliB[i - 1], 2 * ringDim));
408     m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_moduliB[i]));
409 
410     B = B * BigInteger(m_moduliB[i]);
411   }
412 
413   m_numb = m_numq;
414 
415   m_msk = PreviousPrime<NativeInteger>(m_moduliB[m_numq - 1], 2 * ringDim);
416 
417   usint s = 0;
418   NativeInteger tmp = m_msk;
419   while (tmp > 0) {
420     tmp >>= 1;
421     s++;
422   }
423 
424   // check msk is large enough
425   while (Q * B * BigInteger(m_msk) < maxConvolutionValue) {
426     NativeInteger firstInteger = FirstPrime<NativeInteger>(s + 1, 2 * ringDim);
427 
428     m_msk = NextPrime<NativeInteger>(firstInteger, 2 * ringDim);
429     s++;
430     if (s >= 60) PALISADE_THROW(math_error, "msk is larger than 60 bits");
431   }
432   m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_msk));
433 
434   m_moduliBsk = m_moduliB;
435   m_moduliBsk.push_back(m_msk);
436 
437   m_paramsBsk = std::make_shared<ILDCRTParams<BigInteger>>(
438       2 * ringDim, m_moduliBsk, m_rootsBsk);
439 
440   ChineseRemainderTransformFTT<NativeVector>::PreCompute(
441       m_rootsBsk, 2 * ringDim, m_moduliBsk);
442 
443   // populate Barrett constant for m_BskModuli
444   m_modbskBarrettMu.resize(m_moduliBsk.size());
445   for (uint32_t i = 0; i < m_modbskBarrettMu.size(); i++) {
446     BigInteger mu = BarrettBase128Bit / BigInteger(m_moduliBsk[i]);
447     uint64_t val[2];
448     val[0] = (mu % TwoPower64).ConvertToInt();
449     val[1] = mu.RShift(64).ConvertToInt();
450 
451     memcpy(&m_modbskBarrettMu[i], val, sizeof(DoubleNativeInt));
452   }
453 
454   // Populate [(Q/q_i)^-1]_{q_i}
455   m_QHatInvModq.resize(m_numq);
456   for (uint32_t i = 0; i < m_QHatInvModq.size(); i++) {
457     BigInteger QHatInvModqi;
458     QHatInvModqi = Q.DividedBy(moduliQ[i]);
459     QHatInvModqi = QHatInvModqi.Mod(moduliQ[i]);
460     QHatInvModqi = QHatInvModqi.ModInverse(moduliQ[i]);
461     m_QHatInvModq[i] = QHatInvModqi.ConvertToInt();
462   }
463 
464   // Populate [t*(Q/q_i)^-1]_{q_i}
465   m_tQHatInvModq.resize(m_numq);
466   m_tQHatInvModqPrecon.resize(m_numq);
467   for (uint32_t i = 0; i < m_tQHatInvModq.size(); i++) {
468     BigInteger tQHatInvModqi;
469     tQHatInvModqi = Q.DividedBy(moduliQ[i]);
470     tQHatInvModqi = tQHatInvModqi.Mod(moduliQ[i]);
471     tQHatInvModqi = tQHatInvModqi.ModInverse(moduliQ[i]);
472     tQHatInvModqi = tQHatInvModqi.ModMul(t.ConvertToInt(), moduliQ[i]);
473     m_tQHatInvModq[i] = tQHatInvModqi.ConvertToInt();
474     m_tQHatInvModqPrecon[i] = m_tQHatInvModq[i].PrepModMulConst(moduliQ[i]);
475   }
476 
477   // Populate [Q/q_i]_{bsk_j, mtilde}
478   m_QHatModbsk.resize(m_numq);
479   m_QHatModmtilde.resize(m_numq);
480   for (uint32_t i = 0; i < m_QHatModbsk.size(); i++) {
481     m_QHatModbsk[i].resize(m_numb + 1);
482 
483     BigInteger QHati = Q.DividedBy(moduliQ[i]);
484     for (uint32_t j = 0; j < m_QHatModbsk[i].size(); j++) {
485       BigInteger QHatiModbskj = QHati.Mod(m_moduliBsk[j]);
486       m_QHatModbsk[i][j] = QHatiModbskj.ConvertToInt();
487     }
488     m_QHatModmtilde[i] = QHati.Mod(m_mtilde).ConvertToInt();
489   }
490 
491   // Populate [1/q_i]_{bsk_j}
492   m_qInvModbsk.resize(m_numq);
493   for (uint32_t i = 0; i < m_qInvModbsk.size(); i++) {
494     m_qInvModbsk[i].resize(m_numb + 1);
495     for (uint32_t j = 0; j < m_qInvModbsk[i].size(); j++)
496       m_qInvModbsk[i][j] = moduliQ[i].ModInverse(m_moduliBsk[j]);
497   }
498 
499   // Populate [mtilde*(Q/q_i)^{-1}]_{q_i}
500   m_mtildeQHatInvModq.resize(m_numq);
501   m_mtildeQHatInvModqPrecon.resize(m_numq);
502 
503   BigInteger bmtilde(m_mtilde);
504   for (uint32_t i = 0; i < m_mtildeQHatInvModq.size(); i++) {
505     BigInteger mtildeQHatInvModqi = Q.DividedBy(moduliQ[i]);
506     mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
507     mtildeQHatInvModqi = mtildeQHatInvModqi.ModInverse(moduliQ[i]);
508     mtildeQHatInvModqi = mtildeQHatInvModqi * bmtilde;
509     mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
510     m_mtildeQHatInvModq[i] = mtildeQHatInvModqi.ConvertToInt();
511     m_mtildeQHatInvModqPrecon[i] =
512         m_mtildeQHatInvModq[i].PrepModMulConst(moduliQ[i]);
513   }
514 
515   // Populate [-Q^{-1}]_{mtilde}
516   BigInteger negQInvModmtilde =
517       (BigInteger(m_mtilde - 1) * Q.ModInverse(m_mtilde));
518   negQInvModmtilde = negQInvModmtilde.Mod(m_mtilde);
519   m_negQInvModmtilde = negQInvModmtilde.ConvertToInt();
520 
521   // Populate [Q]_{bski_j}
522   m_QModbsk.resize(m_numq + 1);
523   m_QModbskPrecon.resize(m_numq + 1);
524 
525   for (uint32_t j = 0; j < m_QModbsk.size(); j++) {
526     BigInteger QModbskij = Q.Mod(m_moduliBsk[j]);
527     m_QModbsk[j] = QModbskij.ConvertToInt();
528     m_QModbskPrecon[j] = m_QModbsk[j].PrepModMulConst(m_moduliBsk[j]);
529   }
530 
531   // Populate [mtilde^{-1}]_{bsk_j}
532   m_mtildeInvModbsk.resize(m_numb + 1);
533   m_mtildeInvModbskPrecon.resize(m_numb + 1);
534   for (uint32_t j = 0; j < m_mtildeInvModbsk.size(); j++) {
535     BigInteger mtildeInvModbskij = m_mtilde % m_moduliBsk[j];
536     mtildeInvModbskij = mtildeInvModbskij.ModInverse(m_moduliBsk[j]);
537     m_mtildeInvModbsk[j] = mtildeInvModbskij.ConvertToInt();
538     m_mtildeInvModbskPrecon[j] =
539         m_mtildeInvModbsk[j].PrepModMulConst(m_moduliBsk[j]);
540   }
541 
542   // Populate {t/Q}_{bsk_j}
543   m_tQInvModbsk.resize(m_numb + 1);
544   m_tQInvModbskPrecon.resize(m_numb + 1);
545 
546   for (uint32_t i = 0; i < m_tQInvModbsk.size(); i++) {
547     BigInteger tDivqModBski = Q.ModInverse(m_moduliBsk[i]);
548     tDivqModBski.ModMulEq(t.ConvertToInt(), m_moduliBsk[i]);
549     m_tQInvModbsk[i] = tDivqModBski.ConvertToInt();
550     m_tQInvModbskPrecon[i] = m_tQInvModbsk[i].PrepModMulConst(m_moduliBsk[i]);
551   }
552 
553   // Populate [(B/b_j)^{-1}]_{b_j}
554   m_BHatInvModb.resize(m_numb);
555   m_BHatInvModbPrecon.resize(m_numb);
556 
557   for (uint32_t i = 0; i < m_BHatInvModb.size(); i++) {
558     BigInteger BDivBi;
559     BDivBi = B.DividedBy(m_moduliB[i]);
560     BDivBi = BDivBi.Mod(m_moduliB[i]);
561     BDivBi = BDivBi.ModInverse(m_moduliB[i]);
562     m_BHatInvModb[i] = BDivBi.ConvertToInt();
563     m_BHatInvModbPrecon[i] = m_BHatInvModb[i].PrepModMulConst(m_moduliB[i]);
564   }
565 
566   // Populate [B/b_j]_{q_i}
567   m_BHatModq.resize(m_numb);
568   for (uint32_t i = 0; i < m_BHatModq.size(); i++) {
569     m_BHatModq[i].resize(m_numq);
570     BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
571     for (uint32_t j = 0; j < m_BHatModq[i].size(); j++) {
572       BigInteger BDivBiModqj = BDivBi.Mod(moduliQ[j]);
573       m_BHatModq[i][j] = BDivBiModqj.ConvertToInt();
574     }
575   }
576 
577   // Populate [B/b_j]_{msk}
578   m_BHatModmsk.resize(m_numb);
579   for (uint32_t i = 0; i < m_BHatModmsk.size(); i++) {
580     BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
581     m_BHatModmsk[i] = (BDivBi.Mod(m_msk)).ConvertToInt();
582   }
583 
584   // Populate [B^{-1}]_{msk}
585   m_BInvModmsk = (B.ModInverse(m_msk)).ConvertToInt();
586   m_BInvModmskPrecon = m_BInvModmsk.PrepModMulConst(m_msk);
587 
588   // Populate [B]_{q_i}
589   m_BModq.resize(m_numq);
590   m_BModqPrecon.resize(m_numq);
591   for (uint32_t i = 0; i < m_BModq.size(); i++) {
592     m_BModq[i] = (B.Mod(moduliQ[i])).ConvertToInt();
593     m_BModqPrecon[i] = m_BModq[i].PrepModMulConst(moduliQ[i]);
594   }
595 
596   // Populate Decrns lookup tables
597 
598   NativeInteger tgamma = NativeInteger(t.ConvertToInt() * m_gamma);  // t*gamma
599 
600   m_tgamma = tgamma;
601 
602   // Populate [-1/q_i]_{t*gamma} (t*gamma < 2^58)
603   m_negInvqModtgamma.resize(m_numq);
604   m_negInvqModtgammaPrecon.resize(m_numq);
605   for (uint32_t i = 0; i < m_negInvqModtgamma.size(); i++) {
606     BigInteger imod(moduliQ[i]);
607     BigInteger negInvqi = BigInteger((tgamma - 1)) * imod.ModInverse(tgamma);
608 
609     BigInteger negInvqiModtgamma = negInvqi.Mod(tgamma);
610     m_negInvqModtgamma[i] = negInvqiModtgamma.ConvertToInt();
611     m_negInvqModtgammaPrecon[i] = m_negInvqModtgamma[i].PrepModMulConst(tgamma);
612   }
613 
614   // populate [t*gamma*(Q/q_i)^(-1)]_{q_i}
615   m_tgammaQHatInvModq.resize(m_numq);
616   m_tgammaQHatInvModqPrecon.resize(m_numq);
617 
618   BigInteger bmgamma(m_gamma);
619   for (uint32_t i = 0; i < m_tgammaQHatInvModq.size(); i++) {
620     BigInteger qDivqi = Q.DividedBy(moduliQ[i]);
621     BigInteger imod(moduliQ[i]);
622     qDivqi = qDivqi.ModInverse(moduliQ[i]);
623     BigInteger gammaqDivqi = (qDivqi * bmgamma) % imod;
624     BigInteger tgammaqDivqi = (gammaqDivqi * t) % imod;
625     m_tgammaQHatInvModq[i] = tgammaqDivqi.ConvertToInt();
626     m_tgammaQHatInvModqPrecon[i] =
627         m_tgammaQHatInvModq[i].PrepModMulConst(moduliQ[i]);
628   }
629 
630   return true;
631 }
632 
633 // Parameter generation for BFV-RNS
634 template <>
ParamsGen(shared_ptr<LPCryptoParameters<DCRTPoly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t nCustom) const635 bool LPAlgorithmParamsGenBFVrnsB<DCRTPoly>::ParamsGen(
636     shared_ptr<LPCryptoParameters<DCRTPoly>> cryptoParams, int32_t evalAddCount,
637     int32_t evalMultCount, int32_t keySwitchCount, size_t dcrtBits,
638     uint32_t nCustom) const {
639   if (!cryptoParams)
640     PALISADE_THROW(not_available_error,
641                    "No crypto parameters are supplied to BFVrns ParamsGen");
642 
643   if ((dcrtBits < 30) || (dcrtBits > 60))
644     PALISADE_THROW(math_error,
645                    "BFVrns.ParamsGen: Number of bits in CRT moduli should be "
646                    "in the range from 30 to 60");
647 
648   const auto cryptoParamsBFVrnsB =
649       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
650           cryptoParams);
651 
652   double sigma = cryptoParamsBFVrnsB->GetDistributionParameter();
653   double alpha = cryptoParamsBFVrnsB->GetAssuranceMeasure();
654   double hermiteFactor = cryptoParamsBFVrnsB->GetSecurityLevel();
655   double p = static_cast<double>(cryptoParamsBFVrnsB->GetPlaintextModulus());
656   uint32_t relinWindow = cryptoParamsBFVrnsB->GetRelinWindow();
657   SecurityLevel stdLevel = cryptoParamsBFVrnsB->GetStdLevel();
658 
659   // Bound of the Gaussian error polynomial
660   double Berr = sigma * sqrt(alpha);
661 
662   // Bound of the key polynomial
663   double Bkey;
664 
665   DistributionType distType;
666 
667   // supports both discrete Gaussian (RLWE) and ternary uniform distribution
668   // (OPTIMIZED) cases
669   if (cryptoParamsBFVrnsB->GetMode() == RLWE) {
670     Bkey = sigma * sqrt(alpha);
671     distType = HEStd_error;
672   } else {
673     Bkey = 1;
674     distType = HEStd_ternary;
675   }
676 
677   // expansion factor delta
678   auto delta = [](uint32_t n) -> double { return (2. * sqrt(n)); };
679 
680   // norm of fresh ciphertext polynomial
681   auto Vnorm = [&](uint32_t n) -> double {
682     return Berr * (1. + 2. * delta(n) * Bkey);
683   };
684 
685   // RLWE security constraint
686   auto nRLWE = [&](double logq) -> double {
687     if (stdLevel == HEStd_NotSet) {
688       return (logq - log(sigma)) / (4. * log(hermiteFactor));
689     } else {
690       return static_cast<double>(StdLatticeParm::FindRingDim(
691           distType, stdLevel, static_cast<long>(ceil(logq / log(2)))));
692     }
693   };
694 
695   // initial values
696   uint32_t n = (nCustom > 0) ? nCustom : 512;
697 
698   double logq = 0.;
699 
700   // only public key encryption and EvalAdd (optional when evalAddCount = 0)
701   // operations are supported the correctness constraint from section 3.5 of
702   // https://eprint.iacr.org/2014/062.pdf is used
703   if ((evalMultCount == 0) && (keySwitchCount == 0)) {
704     // Correctness constraint
705     auto logqBFV = [&](uint32_t n) -> double {
706       return log(p *
707                  (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount * p) + p));
708     };
709 
710     // initial value
711     logq = logqBFV(n);
712 
713     if ((nRLWE(logq) > n) && (nCustom > 0))
714       PALISADE_THROW(config_error,
715                      "Ring dimension n specified by the user does not meet the "
716                      "security requirement. Please increase it.");
717 
718     while (nRLWE(logq) > n) {
719       n = 2 * n;
720       logq = logqBFV(n);
721     }
722 
723     // this code updates n and q to account for the discrete size of CRT moduli
724     // = dcrtBits
725 
726     int32_t k =
727         static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
728 
729     double logqCeil = k * dcrtBits * log(2);
730 
731     while (nRLWE(logqCeil) > n) {
732       n = 2 * n;
733       logq = logqBFV(n);
734       k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
735       logqCeil = k * dcrtBits * log(2);
736     }
737   } else if ((evalMultCount == 0) && (keySwitchCount > 0) &&
738              (evalAddCount == 0)) {
739     // this case supports automorphism w/o any other operations
740     // base for relinearization
741 
742     double w = relinWindow == 0 ? pow(2, dcrtBits) : pow(2, relinWindow);
743 
744     // Correctness constraint
745     auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
746       return log(
747           p * (4 * (Vnorm(n) + keySwitchCount * delta(n) *
748                                    (floor(logqPrev / (log(2) * dcrtBits)) + 1) *
749                                    w * Berr) +
750                p));
751     };
752 
753     // initial values
754     double logqPrev = 6 * log(10);
755     logq = logqBFV(n, logqPrev);
756     logqPrev = logq;
757 
758     if ((nRLWE(logq) > n) && (nCustom > 0))
759       PALISADE_THROW(config_error,
760                      "Ring dimension n specified by the user does not meet the "
761                      "security requirement. Please increase it.");
762 
763     // this "while" condition is needed in case the iterative solution for q
764     // changes the requirement for n, which is rare but still theoretically
765     // possible
766     while (nRLWE(logq) > n) {
767       while (nRLWE(logq) > n) {
768         n = 2 * n;
769         logq = logqBFV(n, logqPrev);
770         logqPrev = logq;
771       }
772 
773       logq = logqBFV(n, logqPrev);
774 
775       while (fabs(logq - logqPrev) > log(1.001)) {
776         logqPrev = logq;
777         logq = logqBFV(n, logqPrev);
778       }
779 
780       // this code updates n and q to account for the discrete size of CRT
781       // moduli = dcrtBits
782 
783       int32_t k =
784           static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
785 
786       double logqCeil = k * dcrtBits * log(2);
787       logqPrev = logqCeil;
788 
789       while (nRLWE(logqCeil) > n) {
790         n = 2 * n;
791         logq = logqBFV(n, logqPrev);
792         k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
793         logqCeil = k * dcrtBits * log(2);
794         logqPrev = logqCeil;
795       }
796     }
797   } else if ((evalAddCount == 0) && (evalMultCount > 0) &&
798              (keySwitchCount == 0)) {
799     // Only EvalMult operations are used in the correctness constraint
800     // the correctness constraint from section 3.5 of
801     // https://eprint.iacr.org/2014/062.pdf is used
802 
803     // base for relinearization
804     double w = relinWindow == 0 ? pow(2, dcrtBits) : pow(2, relinWindow);
805 
806     // function used in the EvalMult constraint
807     auto epsilon1 = [&](uint32_t n) -> double { return 5 / (delta(n) * Bkey); };
808 
809     // function used in the EvalMult constraint
810     auto C1 = [&](uint32_t n) -> double {
811       return (1 + epsilon1(n)) * delta(n) * delta(n) * p * Bkey;
812     };
813 
814     // function used in the EvalMult constraint
815     auto C2 = [&](uint32_t n, double logqPrev) -> double {
816       return delta(n) * delta(n) * Bkey * ((1 + 0.5) * Bkey + p * p) +
817              delta(n) * (floor(logqPrev / (log(2) * dcrtBits)) + 1) * w * Berr;
818     };
819 
820     // main correctness constraint
821     auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
822       return log(4 * p) + (evalMultCount - 1) * log(C1(n)) +
823              log(C1(n) * Vnorm(n) + evalMultCount * C2(n, logqPrev));
824     };
825 
826     // initial values
827     double logqPrev = 6. * log(10);
828     logq = logqBFV(n, logqPrev);
829     logqPrev = logq;
830 
831     if ((nRLWE(logq) > n) && (nCustom > 0))
832       PALISADE_THROW(config_error,
833                      "Ring dimension n specified by the user does not meet the "
834                      "security requirement. Please increase it.");
835 
836     // this "while" condition is needed in case the iterative solution for q
837     // changes the requirement for n, which is rare but still theoretically
838     // possible
839     while (nRLWE(logq) > n) {
840       while (nRLWE(logq) > n) {
841         n = 2 * n;
842         logq = logqBFV(n, logqPrev);
843         logqPrev = logq;
844       }
845 
846       logq = logqBFV(n, logqPrev);
847 
848       while (fabs(logq - logqPrev) > log(1.001)) {
849         logqPrev = logq;
850         logq = logqBFV(n, logqPrev);
851       }
852 
853       // this code updates n and q to account for the discrete size of CRT
854       // moduli = dcrtBits
855 
856       int32_t k =
857           static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
858 
859       double logqCeil = k * dcrtBits * log(2);
860       logqPrev = logqCeil;
861 
862       while (nRLWE(logqCeil) > n) {
863         n = 2 * n;
864         logq = logqBFV(n, logqPrev);
865         k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
866         logqCeil = k * dcrtBits * log(2);
867         logqPrev = logqCeil;
868       }
869     }
870   }
871 
872   size_t sizeQ =
873       static_cast<size_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
874 
875   vector<NativeInteger> moduliQ(sizeQ);
876   vector<NativeInteger> rootsQ(sizeQ);
877 
878   // makes sure the first integer is less than 2^60-1 to take advantage of NTL
879   // optimizations
880   NativeInteger firstInteger = FirstPrime<NativeInteger>(dcrtBits, 2 * n);
881 
882   moduliQ[0] = PreviousPrime<NativeInteger>(firstInteger, 2 * n);
883   rootsQ[0] = RootOfUnity<NativeInteger>(2 * n, moduliQ[0]);
884 
885   for (size_t i = 1; i < sizeQ; i++) {
886     moduliQ[i] = PreviousPrime<NativeInteger>(moduliQ[i - 1], 2 * n);
887     rootsQ[i] = RootOfUnity<NativeInteger>(2 * n, moduliQ[i]);
888   }
889 
890   auto params =
891       std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQ, rootsQ);
892 
893   ChineseRemainderTransformFTT<NativeVector>::PreCompute(rootsQ, 2 * n,
894                                                          moduliQ);
895 
896   cryptoParamsBFVrnsB->SetElementParams(params);
897 
898   const EncodingParams encodingParams = cryptoParamsBFVrnsB->GetEncodingParams();
899   if (encodingParams->GetBatchSize() > n)
900     PALISADE_THROW(config_error,
901                    "The batch size cannot be larger than the ring dimension.");
902 
903   // if no batch size was specified, we set batchSize = n by default (for full
904   // packing)
905   if (encodingParams->GetBatchSize() == 0) {
906     uint32_t batchSize = n;
907     EncodingParams encodingParamsNew(std::make_shared<EncodingParamsImpl>(
908         encodingParams->GetPlaintextModulus(), batchSize));
909     cryptoParamsBFVrnsB->SetEncodingParams(encodingParamsNew);
910   }
911 
912   return cryptoParamsBFVrnsB->PrecomputeCRTTables();
913 }
914 
915 template <>
Encrypt(const LPPublicKey<DCRTPoly> publicKey,DCRTPoly ptxt) const916 Ciphertext<DCRTPoly> LPAlgorithmBFVrnsB<DCRTPoly>::Encrypt(
917     const LPPublicKey<DCRTPoly> publicKey, DCRTPoly ptxt) const {
918   Ciphertext<DCRTPoly> ciphertext(
919       std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
920 
921   const auto cryptoParams =
922       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
923           publicKey->GetCryptoParameters());
924 
925   const shared_ptr<ParmType> elementParams = cryptoParams->GetElementParams();
926 
927   ptxt.SetFormat(Format::EVALUATION);
928 
929   const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
930 
931   const DggType &dgg = cryptoParams->GetDiscreteGaussianGenerator();
932   TugType tug;
933 
934   const DCRTPoly &p0 = publicKey->GetPublicElements().at(0);
935   const DCRTPoly &p1 = publicKey->GetPublicElements().at(1);
936 
937   DCRTPoly u;
938 
939   // Supports both discrete Gaussian (RLWE) and ternary uniform distribution
940   // (OPTIMIZED) cases
941   if (cryptoParams->GetMode() == RLWE)
942     u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
943   else
944     u = DCRTPoly(tug, elementParams, Format::EVALUATION);
945 
946   DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
947   DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
948 
949   DCRTPoly c0(elementParams);
950   DCRTPoly c1(elementParams);
951 
952   c0 = p0 * u + e1 + ptxt.Times(delta);
953 
954   c1 = p1 * u + e2;
955 
956   ciphertext->SetElements({std::move(c0), std::move(c1)});
957 
958   return ciphertext;
959 }
960 
961 template <>
Decrypt(const LPPrivateKey<DCRTPoly> privateKey,ConstCiphertext<DCRTPoly> ciphertext,NativePoly * plaintext) const962 DecryptResult LPAlgorithmBFVrnsB<DCRTPoly>::Decrypt(
963     const LPPrivateKey<DCRTPoly> privateKey,
964     ConstCiphertext<DCRTPoly> ciphertext, NativePoly *plaintext) const {
965   // TimeVar t_total;
966 
967   // TIC(t_total);
968 
969   const auto cryptoParamsBFVrnsB =
970       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
971           privateKey->GetCryptoParameters());
972   const shared_ptr<ParmType> elementParams =
973       cryptoParamsBFVrnsB->GetElementParams();
974 
975   const std::vector<DCRTPoly> &c = ciphertext->GetElements();
976 
977   const DCRTPoly &s = privateKey->GetPrivateElement();
978   DCRTPoly sPower = s;
979 
980   DCRTPoly b = c[0];
981   b.SetFormat(Format::EVALUATION);
982 
983   DCRTPoly cTemp;
984   for (size_t i = 1; i <= ciphertext->GetDepth(); i++) {
985     cTemp = c[i];
986     cTemp.SetFormat(Format::EVALUATION);
987 
988     b += sPower * cTemp;
989     sPower *= s;
990   }
991 
992   // Converts back to Format::COEFFICIENT representation
993   b.SetFormat(Format::COEFFICIENT);
994 
995   auto &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
996   auto &tgamma = cryptoParamsBFVrnsB->Gettgamma();
997   const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
998   const std::vector<NativeInteger> &tgammaQHatInvModq =
999       cryptoParamsBFVrnsB->GettgammaQHatInvModq();
1000   const std::vector<NativeInteger> &tgammaQHatInvModqPrecon =
1001       cryptoParamsBFVrnsB->GettgammaQHatInvModqPrecon();
1002   const std::vector<NativeInteger> &negInvqModtgamma =
1003       cryptoParamsBFVrnsB->GetNegInvqModtgamma();
1004   const std::vector<NativeInteger> &negInvqModtgammaPrecon =
1005       cryptoParamsBFVrnsB->GetNegInvqModtgammaPrecon();
1006 
1007   // this is the resulting vector of coefficients;
1008   *plaintext = b.ScaleAndRound(moduliQ, t, tgamma, tgammaQHatInvModq,
1009                                tgammaQHatInvModqPrecon, negInvqModtgamma,
1010                                negInvqModtgammaPrecon);
1011 
1012   // std::cout << "Decryption time (internal): " << TOC_US(t_total) << " us" <<
1013   // std::endl;
1014 
1015   return DecryptResult(plaintext->GetLength());
1016 }
1017 
1018 template <>
Encrypt(const LPPrivateKey<DCRTPoly> privateKey,DCRTPoly ptxt) const1019 Ciphertext<DCRTPoly> LPAlgorithmBFVrnsB<DCRTPoly>::Encrypt(
1020     const LPPrivateKey<DCRTPoly> privateKey, DCRTPoly ptxt) const {
1021   Ciphertext<DCRTPoly> ciphertext(
1022       std::make_shared<CiphertextImpl<DCRTPoly>>(privateKey));
1023 
1024   const auto cryptoParams =
1025       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1026           privateKey->GetCryptoParameters());
1027 
1028   const shared_ptr<ParmType> elementParams = cryptoParams->GetElementParams();
1029 
1030   ptxt.SwitchFormat();
1031 
1032   const DggType &dgg = cryptoParams->GetDiscreteGaussianGenerator();
1033   DugType dug;
1034 
1035   const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1036 
1037   DCRTPoly a(dug, elementParams, Format::EVALUATION);
1038   const DCRTPoly &s = privateKey->GetPrivateElement();
1039   DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1040 
1041   DCRTPoly c0(a * s + e + ptxt.Times(delta));
1042   DCRTPoly c1(elementParams, Format::EVALUATION, true);
1043   c1 -= a;
1044 
1045   ciphertext->SetElements({std::move(c0), std::move(c1)});
1046 
1047   return ciphertext;
1048 }
1049 
1050 template <>
EvalAdd(ConstCiphertext<DCRTPoly> ciphertext,ConstPlaintext plaintext) const1051 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalAdd(
1052     ConstCiphertext<DCRTPoly> ciphertext, ConstPlaintext plaintext) const {
1053   Ciphertext<DCRTPoly> newCiphertext = ciphertext->CloneEmpty();
1054   newCiphertext->SetDepth(ciphertext->GetDepth());
1055 
1056   const std::vector<DCRTPoly> &cipherTextElements = ciphertext->GetElements();
1057 
1058   const DCRTPoly &ptElement = plaintext->GetElement<DCRTPoly>();
1059 
1060   std::vector<DCRTPoly> c(cipherTextElements.size());
1061 
1062   const auto cryptoParams =
1063       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1064           ciphertext->GetCryptoParameters());
1065 
1066   const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1067 
1068   c[0] = cipherTextElements[0] + ptElement.Times(delta);
1069 
1070   for (size_t i = 1; i < cipherTextElements.size(); i++) {
1071     c[i] = cipherTextElements[i];
1072   }
1073 
1074   newCiphertext->SetElements(std::move(c));
1075 
1076   return newCiphertext;
1077 }
1078 
1079 template <>
EvalSub(ConstCiphertext<DCRTPoly> ciphertext,ConstPlaintext plaintext) const1080 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalSub(
1081     ConstCiphertext<DCRTPoly> ciphertext, ConstPlaintext plaintext) const {
1082   Ciphertext<DCRTPoly> newCiphertext = ciphertext->CloneEmpty();
1083   newCiphertext->SetDepth(ciphertext->GetDepth());
1084 
1085   const std::vector<DCRTPoly> &cipherTextElements = ciphertext->GetElements();
1086 
1087   plaintext->SetFormat(Format::EVALUATION);
1088   const DCRTPoly &ptElement = plaintext->GetElement<DCRTPoly>();
1089 
1090   std::vector<DCRTPoly> c(cipherTextElements.size());
1091 
1092   const auto cryptoParams =
1093       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1094           ciphertext->GetCryptoParameters());
1095 
1096   const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1097 
1098   c[0] = cipherTextElements[0] - ptElement.Times(delta);
1099 
1100   for (size_t i = 1; i < cipherTextElements.size(); i++) {
1101     c[i] = cipherTextElements[i];
1102   }
1103 
1104   newCiphertext->SetElements(std::move(c));
1105 
1106   return newCiphertext;
1107 }
1108 
1109 template <>
EvalMult(ConstCiphertext<DCRTPoly> ciphertext1,ConstCiphertext<DCRTPoly> ciphertext2) const1110 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalMult(
1111     ConstCiphertext<DCRTPoly> ciphertext1,
1112     ConstCiphertext<DCRTPoly> ciphertext2) const {
1113   if (!(ciphertext1->GetCryptoParameters() ==
1114         ciphertext2->GetCryptoParameters())) {
1115     std::string errMsg =
1116         "LPAlgorithmSHEBFVrnsB::EvalMult crypto parameters are not the same";
1117     PALISADE_THROW(config_error, errMsg);
1118   }
1119 
1120   Ciphertext<DCRTPoly> newCiphertext = ciphertext1->CloneEmpty();
1121 
1122   const auto cryptoParamsBFVrnsB =
1123       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1124           ciphertext1->GetCryptoContext()->GetCryptoParameters());
1125 
1126   // Get the ciphertext elements
1127   std::vector<DCRTPoly> cipherText1Elements = ciphertext1->GetElements();
1128   std::vector<DCRTPoly> cipherText2Elements = ciphertext2->GetElements();
1129 
1130   size_t cipherText1ElementsSize = cipherText1Elements.size();
1131   size_t cipherText2ElementsSize = cipherText2Elements.size();
1132   size_t cipherTextRElementsSize =
1133       cipherText1ElementsSize + cipherText2ElementsSize - 1;
1134 
1135   std::vector<DCRTPoly> c(cipherTextRElementsSize);
1136 
1137   const shared_ptr<ParmType> elementParams =
1138       cryptoParamsBFVrnsB->GetElementParams();
1139   const shared_ptr<ILDCRTParams<BigInteger>> paramsBsk =
1140       cryptoParamsBFVrnsB->GetParamsBsk();
1141   const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
1142   const std::vector<DoubleNativeInt> &modqBarrettMu =
1143       cryptoParamsBFVrnsB->GetModqBarrettMu();
1144   const std::vector<NativeInteger> &moduliBsk =
1145       cryptoParamsBFVrnsB->GetModuliBsk();
1146   const std::vector<DoubleNativeInt> &modbskBarrettMu =
1147       cryptoParamsBFVrnsB->GetModbskBarrettMu();
1148   const std::vector<NativeInteger> &mtildeQHatInvModq =
1149       cryptoParamsBFVrnsB->GetmtildeQHatInvModq();
1150   const std::vector<NativeInteger> &mtildeQHatInvModqPrecon =
1151       cryptoParamsBFVrnsB->GetmtildeQHatInvModqPrecon();
1152   const std::vector<std::vector<NativeInteger>> &QHatModbsk =
1153       cryptoParamsBFVrnsB->GetQHatModbsk();
1154   const std::vector<uint16_t> &QHatModmtilde =
1155       cryptoParamsBFVrnsB->GetQHatModmtilde();
1156   const std::vector<NativeInteger> &QModbsk = cryptoParamsBFVrnsB->GetQModbsk();
1157   const std::vector<NativeInteger> &QModbskPrecon =
1158       cryptoParamsBFVrnsB->GetQModbskPrecon();
1159   const uint16_t &negQInvModmtilde = cryptoParamsBFVrnsB->GetNegQInvModmtilde();
1160   // const NativeInteger &negQInvModmtildePrecon =
1161   //     cryptoParamsBFVrnsB->GetNegQInvModmtildePrecon();
1162   const std::vector<NativeInteger> &mtildeInvModbsk =
1163       cryptoParamsBFVrnsB->GetmtildeInvModbsk();
1164   const std::vector<NativeInteger> &mtildeInvModbskPrecon =
1165       cryptoParamsBFVrnsB->GetmtildeInvModbskPrecon();
1166 
1167   // Expands the CRT basis to q*Bsk; Outputs the polynomials in coeff
1168   // representation
1169 
1170   for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1171     cipherText1Elements[i].FastBaseConvqToBskMontgomery(
1172         paramsBsk, moduliQ, moduliBsk, modbskBarrettMu, mtildeQHatInvModq,
1173         mtildeQHatInvModqPrecon, QHatModbsk, QHatModmtilde, QModbsk,
1174         QModbskPrecon, negQInvModmtilde, mtildeInvModbsk,
1175         mtildeInvModbskPrecon);
1176 
1177     cipherText1Elements[i].SetFormat(Format::EVALUATION);
1178   }
1179 
1180   for (size_t i = 0; i < cipherText2ElementsSize; i++) {
1181     cipherText2Elements[i].FastBaseConvqToBskMontgomery(
1182         paramsBsk, moduliQ, moduliBsk, modbskBarrettMu, mtildeQHatInvModq,
1183         mtildeQHatInvModqPrecon, QHatModbsk, QHatModmtilde, QModbsk,
1184         QModbskPrecon, negQInvModmtilde, mtildeInvModbsk,
1185         mtildeInvModbskPrecon);
1186 
1187     cipherText2Elements[i].SetFormat(Format::EVALUATION);
1188   }
1189 
1190   // Performs the multiplication itself
1191 
1192 #ifdef USE_KARATSUBA
1193 
1194   if (cipherText1ElementsSize == 2 && cipherText2ElementsSize == 2) {
1195     // size of each ciphertxt = 2, use Karatsuba
1196     c[0] = cipherText1Elements[0] * cipherText2Elements[0];  // a
1197     c[2] = cipherText1Elements[1] * cipherText2Elements[1];  // b
1198 
1199     c[1] = cipherText1Elements[0] + cipherText1Elements[1];
1200     c[1] *= (cipherText2Elements[0] + cipherText2Elements[1]);
1201     c[1] -= c[2];
1202     c[1] -= c[0];
1203 
1204   } else {  // if size of any of the ciphertexts > 2
1205     bool *isFirstAdd = new bool[cipherTextRElementsSize];
1206     std::fill_n(isFirstAdd, cipherTextRElementsSize, true);
1207 
1208     for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1209       for (size_t j = 0; j < cipherText2ElementsSize; j++) {
1210         if (isFirstAdd[i + j] == true) {
1211           c[i + j] = cipherText1Elements[i] * cipherText2Elements[j];
1212           isFirstAdd[i + j] = false;
1213         } else {
1214           c[i + j] += cipherText1Elements[i] * cipherText2Elements[j];
1215         }
1216       }
1217     }
1218 
1219     delete[] isFirstAdd;
1220   }
1221 
1222 #else
1223   bool *isFirstAdd = new bool[cipherTextRElementsSize];
1224   std::fill_n(isFirstAdd, cipherTextRElementsSize, true);
1225 
1226   for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1227     for (size_t j = 0; j < cipherText2ElementsSize; j++) {
1228       if (isFirstAdd[i + j] == true) {
1229         c[i + j] = cipherText1Elements[i] * cipherText2Elements[j];
1230         isFirstAdd[i + j] = false;
1231       } else {
1232         c[i + j] += cipherText1Elements[i] * cipherText2Elements[j];
1233       }
1234     }
1235   }
1236 
1237   delete[] isFirstAdd;
1238 #endif
1239 
1240   // perfrom RNS approximate Flooring
1241   const NativeInteger &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
1242   const std::vector<NativeInteger> &tQHatInvModq =
1243       cryptoParamsBFVrnsB->GettQHatInvModq();
1244   const std::vector<NativeInteger> &tQHatInvModqPrecon =
1245       cryptoParamsBFVrnsB->GettQHatInvModqPrecon();
1246   const std::vector<std::vector<NativeInteger>> &qInvModbsk =
1247       cryptoParamsBFVrnsB->GetqInvModbsk();
1248   const std::vector<NativeInteger> &tQInvModbsk =
1249       cryptoParamsBFVrnsB->GettQInvModbsk();
1250   const std::vector<NativeInteger> &tQInvModbskPrecon =
1251       cryptoParamsBFVrnsB->GettQInvModbskPrecon();
1252 
1253   // perform FastBaseConvSK
1254   const std::vector<NativeInteger> &BHatInvModb =
1255       cryptoParamsBFVrnsB->GetBHatInvModb();
1256   const std::vector<NativeInteger> &BHatInvModbPrecon =
1257       cryptoParamsBFVrnsB->GetBHatInvModbPrecon();
1258   const std::vector<NativeInteger> &BHatModmsk =
1259       cryptoParamsBFVrnsB->GetBHatModmsk();
1260   const NativeInteger &BInvModmsk = cryptoParamsBFVrnsB->GetBInvModmsk();
1261   const NativeInteger &BInvModmskPrecon =
1262       cryptoParamsBFVrnsB->GetBInvModmskPrecon();
1263   const std::vector<std::vector<NativeInteger>> &BHatModq =
1264       cryptoParamsBFVrnsB->GetBHatModq();
1265   const std::vector<NativeInteger> &BModq = cryptoParamsBFVrnsB->GetBModq();
1266   const std::vector<NativeInteger> &BModqPrecon =
1267       cryptoParamsBFVrnsB->GetBModqPrecon();
1268 
1269   for (size_t i = 0; i < cipherTextRElementsSize; i++) {
1270     // converts to Format::COEFFICIENT representation before rounding
1271     c[i].SetFormat(Format::COEFFICIENT);
1272     // Performs the scaling by t/Q followed by rounding; the result is in the
1273     // CRT basis {Bsk}
1274     c[i].FastRNSFloorq(t, moduliQ, moduliBsk, modbskBarrettMu, tQHatInvModq,
1275                        tQHatInvModqPrecon, QHatModbsk, qInvModbsk, tQInvModbsk,
1276                        tQInvModbskPrecon);
1277 
1278     // Converts from the CRT basis {Bsk} to {Q}
1279     c[i].FastBaseConvSK(moduliQ, modqBarrettMu, moduliBsk, modbskBarrettMu,
1280                         BHatInvModb, BHatInvModbPrecon, BHatModmsk, BInvModmsk,
1281                         BInvModmskPrecon, BHatModq, BModq, BModqPrecon);
1282   }
1283 
1284   newCiphertext->SetElements(std::move(c));
1285   newCiphertext->SetDepth((ciphertext1->GetDepth() + ciphertext2->GetDepth()));
1286 
1287   return newCiphertext;
1288 }
1289 
1290 template <>
KeySwitchGen(const LPPrivateKey<DCRTPoly> originalPrivateKey,const LPPrivateKey<DCRTPoly> newPrivateKey) const1291 LPEvalKey<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::KeySwitchGen(
1292     const LPPrivateKey<DCRTPoly> originalPrivateKey,
1293     const LPPrivateKey<DCRTPoly> newPrivateKey) const {
1294   LPEvalKey<DCRTPoly> ek(std::make_shared<LPEvalKeyRelinImpl<DCRTPoly>>(
1295       newPrivateKey->GetCryptoContext()));
1296 
1297   const auto cryptoParamsLWE =
1298       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1299           newPrivateKey->GetCryptoParameters());
1300   const shared_ptr<ParmType> elementParams =
1301       cryptoParamsLWE->GetElementParams();
1302   const DCRTPoly &s = newPrivateKey->GetPrivateElement();
1303 
1304   const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1305   DugType dug;
1306 
1307   const DCRTPoly &oldKey = originalPrivateKey->GetPrivateElement();
1308 
1309   std::vector<DCRTPoly> evalKeyElements;
1310   std::vector<DCRTPoly> evalKeyElementsGenerated;
1311 
1312   uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1313 
1314   for (usint i = 0; i < oldKey.GetNumOfElements(); i++) {
1315     if (relinWindow > 0) {
1316       vector<typename DCRTPoly::PolyType> decomposedKeyElements =
1317           oldKey.GetElementAtIndex(i).PowersOfBase(relinWindow);
1318 
1319       for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1320         // Creates an element with all zeroes
1321         DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1322 
1323         filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1324 
1325         // Generate a_i vectors
1326         DCRTPoly a(dug, elementParams, Format::EVALUATION);
1327         evalKeyElementsGenerated.push_back(a);
1328 
1329         // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1330         DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1331         evalKeyElements.push_back(filtered - (a * s + e));
1332       }
1333     } else {
1334       // Creates an element with all zeroes
1335       DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1336 
1337       filtered.SetElementAtIndex(i, oldKey.GetElementAtIndex(i));
1338 
1339       // Generate a_i vectors
1340       DCRTPoly a(dug, elementParams, Format::EVALUATION);
1341       evalKeyElementsGenerated.push_back(a);
1342 
1343       // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1344       DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1345       evalKeyElements.push_back(filtered - (a * s + e));
1346     }
1347   }
1348 
1349   ek->SetAVector(std::move(evalKeyElements));
1350   ek->SetBVector(std::move(evalKeyElementsGenerated));
1351 
1352   return ek;
1353 }
1354 
1355 template <>
MultiKeySwitchGen(const LPPrivateKey<DCRTPoly> originalPrivateKey,const LPPrivateKey<DCRTPoly> newPrivateKey,const LPEvalKey<DCRTPoly> ek) const1356 LPEvalKey<DCRTPoly> LPAlgorithmMultipartyBFVrnsB<DCRTPoly>::MultiKeySwitchGen(
1357     const LPPrivateKey<DCRTPoly> originalPrivateKey,
1358     const LPPrivateKey<DCRTPoly> newPrivateKey,
1359     const LPEvalKey<DCRTPoly> ek) const {
1360   LPEvalKeyRelin<DCRTPoly> keySwitchHintRelin(
1361       new LPEvalKeyRelinImpl<DCRTPoly>(newPrivateKey->GetCryptoContext()));
1362 
1363   const shared_ptr<LPCryptoParametersRLWE<DCRTPoly>> cryptoParamsLWE =
1364       std::dynamic_pointer_cast<LPCryptoParametersRLWE<DCRTPoly>>(
1365           newPrivateKey->GetCryptoParameters());
1366   const shared_ptr<ParmType> elementParams =
1367       cryptoParamsLWE->GetElementParams();
1368 
1369   // Getting a reference to the polynomials of new private key.
1370   const DCRTPoly &sNew = newPrivateKey->GetPrivateElement();
1371 
1372   // Getting a reference to the polynomials of original private key.
1373   const DCRTPoly &s = originalPrivateKey->GetPrivateElement();
1374 
1375   const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1376   DugType dug;
1377 
1378   std::vector<DCRTPoly> evalKeyElements;
1379   std::vector<DCRTPoly> evalKeyElementsGenerated;
1380 
1381   uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1382 
1383   const std::vector<DCRTPoly> &a = ek->GetBVector();
1384 
1385   for (usint i = 0; i < s.GetNumOfElements(); i++) {
1386     if (relinWindow > 0) {
1387       vector<typename DCRTPoly::PolyType> decomposedKeyElements =
1388           s.GetElementAtIndex(i).PowersOfBase(relinWindow);
1389 
1390       for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1391         // Creates an element with all zeroes
1392         DCRTPoly filtered(elementParams, EVALUATION, true);
1393 
1394         filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1395 
1396         // Generate a_i vectors
1397         evalKeyElementsGenerated.push_back(
1398             a[i * decomposedKeyElements.size() + k]);
1399 
1400         // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1401         DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1402         evalKeyElements.push_back(
1403             filtered - (a[i * decomposedKeyElements.size() + k] * sNew + e));
1404       }
1405     } else {
1406       // Creates an element with all zeroes
1407       DCRTPoly filtered(elementParams, EVALUATION, true);
1408 
1409       filtered.SetElementAtIndex(i, s.GetElementAtIndex(i));
1410 
1411       // Generate a_i vectors
1412       evalKeyElementsGenerated.push_back(a[i]);
1413 
1414       // Generate  [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi) - (a_i * s + e)
1415       DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1416       evalKeyElements.push_back(filtered - (a[i] * sNew + e));
1417     }
1418   }
1419 
1420   keySwitchHintRelin->SetAVector(std::move(evalKeyElements));
1421   keySwitchHintRelin->SetBVector(std::move(evalKeyElementsGenerated));
1422 
1423   return keySwitchHintRelin;
1424 }
1425 
1426 template <>
KeySwitchInPlace(const LPEvalKey<DCRTPoly> ek,Ciphertext<DCRTPoly> & cipherText) const1427 void LPAlgorithmSHEBFVrnsB<DCRTPoly>::KeySwitchInPlace(
1428     const LPEvalKey<DCRTPoly> ek, Ciphertext<DCRTPoly>& cipherText) const {
1429 
1430   const auto cryptoParamsLWE =
1431       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1432           ek->GetCryptoParameters());
1433 
1434   LPEvalKeyRelin<DCRTPoly> evalKey =
1435       std::static_pointer_cast<LPEvalKeyRelinImpl<DCRTPoly>>(ek);
1436 
1437   std::vector<DCRTPoly> &c = cipherText->GetElements();
1438 
1439   const std::vector<DCRTPoly> &b = evalKey->GetAVector();
1440   const std::vector<DCRTPoly> &a = evalKey->GetBVector();
1441 
1442   uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1443 
1444   std::vector<DCRTPoly> digitsC2;
1445 
1446 
1447   // in the case of EvalMult, c[0] is initially in Format::COEFFICIENT format
1448   // and needs to be switched to Format::EVALUATION format
1449   if (c.size() > 2) c[0].SetFormat(Format::EVALUATION);
1450 
1451   if (c.size() == 2) {  // case of automorphism
1452     digitsC2 = c[1].CRTDecompose(relinWindow);
1453     c[1] = digitsC2[0] * a[0];
1454   } else {  // case of EvalMult
1455     digitsC2 = c[2].CRTDecompose(relinWindow);
1456     c[1].SetFormat(Format::EVALUATION);
1457     c[1] += digitsC2[0] * a[0];
1458   }
1459 
1460   c[0] += digitsC2[0] * b[0];
1461 
1462   for (usint i = 1; i < digitsC2.size(); ++i) {
1463     c[0] += digitsC2[i] * b[i];
1464     c[1] += digitsC2[i] * a[i];
1465   }
1466 
1467   Ciphertext<DCRTPoly> newCiphertext = cipherText->CloneEmpty();
1468   newCiphertext->SetElements({std::move(c[0]), std::move(c[1])});
1469   cipherText = std::move(newCiphertext);
1470 }
1471 
1472 template <>
EvalMultAndRelinearize(ConstCiphertext<DCRTPoly> ciphertext1,ConstCiphertext<DCRTPoly> ciphertext2,const vector<LPEvalKey<DCRTPoly>> & ek) const1473 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalMultAndRelinearize(
1474     ConstCiphertext<DCRTPoly> ciphertext1,
1475     ConstCiphertext<DCRTPoly> ciphertext2,
1476     const vector<LPEvalKey<DCRTPoly>> &ek) const {
1477   Ciphertext<DCRTPoly> cipherText = this->EvalMult(ciphertext1, ciphertext2);
1478 
1479   const auto cryptoParamsLWE =
1480       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1481           ek[0]->GetCryptoParameters());
1482 
1483   Ciphertext<DCRTPoly> newCiphertext = cipherText->CloneEmpty();
1484 
1485   std::vector<DCRTPoly> c = cipherText->GetElements();
1486   for (size_t i = 0; i < c.size(); i++) c[i].SetFormat(Format::EVALUATION);
1487 
1488   DCRTPoly ct0(c[0]);
1489   DCRTPoly ct1(c[1]);
1490   // Perform a keyswitching operation to result of the multiplication. It does
1491   // it until it reaches to 2 elements.
1492   // TODO: Maybe we can change the number of keyswitching and terminate early.
1493   // For instance; perform keyswitching until 4 elements left.
1494   for (size_t j = 0; j <= cipherText->GetDepth() - 2; j++) {
1495     size_t index = cipherText->GetDepth() - 2 - j;
1496     LPEvalKeyRelin<DCRTPoly> evalKey =
1497         std::static_pointer_cast<LPEvalKeyRelinImpl<DCRTPoly>>(ek[index]);
1498 
1499     const std::vector<DCRTPoly> &b = evalKey->GetAVector();
1500     const std::vector<DCRTPoly> &a = evalKey->GetBVector();
1501 
1502     std::vector<DCRTPoly> digitsC2 = c[index + 2].CRTDecompose();
1503 
1504     for (usint i = 0; i < digitsC2.size(); ++i) {
1505       ct0 += digitsC2[i] * b[i];
1506       ct1 += digitsC2[i] * a[i];
1507     }
1508   }
1509 
1510   newCiphertext->SetElements({std::move(ct0), std::move(ct1)});
1511 
1512   return newCiphertext;
1513 }
1514 
1515 template <>
ReKeyGen(const LPPublicKey<DCRTPoly> newPK,const LPPrivateKey<DCRTPoly> origPrivateKey) const1516 LPEvalKey<DCRTPoly> LPAlgorithmPREBFVrnsB<DCRTPoly>::ReKeyGen(
1517     const LPPublicKey<DCRTPoly> newPK,
1518     const LPPrivateKey<DCRTPoly> origPrivateKey) const {
1519   // Get crypto context of new public key.
1520   auto cc = newPK->GetCryptoContext();
1521 
1522   // Create an Format::EVALUATION key that will contain all the re-encryption
1523   // key elements.
1524   LPEvalKeyRelin<DCRTPoly> ek(
1525       std::make_shared<LPEvalKeyRelinImpl<DCRTPoly>>(cc));
1526 
1527   const auto cryptoParamsLWE =
1528       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1529           newPK->GetCryptoParameters());
1530   const shared_ptr<DCRTPoly::Params> elementParams =
1531       cryptoParamsLWE->GetElementParams();
1532 
1533   const DCRTPoly::DggType &dgg =
1534       cryptoParamsLWE->GetDiscreteGaussianGenerator();
1535   DCRTPoly::DugType dug;
1536   DCRTPoly::TugType tug;
1537 
1538   const DCRTPoly &oldKey = origPrivateKey->GetPrivateElement();
1539 
1540   std::vector<DCRTPoly> evalKeyElements;
1541   std::vector<DCRTPoly> evalKeyElementsGenerated;
1542 
1543   uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1544 
1545   const DCRTPoly &p0 = newPK->GetPublicElements().at(0);
1546   const DCRTPoly &p1 = newPK->GetPublicElements().at(1);
1547 
1548   for (usint i = 0; i < oldKey.GetNumOfElements(); i++) {
1549     if (relinWindow > 0) {
1550       vector<DCRTPoly::PolyType> decomposedKeyElements =
1551           oldKey.GetElementAtIndex(i).PowersOfBase(relinWindow);
1552 
1553       for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1554         // Creates an element with all zeroes
1555         DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1556 
1557         filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1558 
1559         DCRTPoly u;
1560 
1561         if (cryptoParamsLWE->GetMode() == RLWE)
1562           u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1563         else
1564           u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1565 
1566         DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1567         DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1568 
1569         DCRTPoly c0(elementParams);
1570         DCRTPoly c1(elementParams);
1571 
1572         c0 = p0 * u + e1 + filtered;
1573 
1574         c1 = p1 * u + e2;
1575 
1576         DCRTPoly a(dug, elementParams, Format::EVALUATION);
1577         evalKeyElementsGenerated.push_back(c1);
1578 
1579         DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1580         evalKeyElements.push_back(c0);
1581       }
1582     } else {
1583       // Creates an element with all zeroes
1584       DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1585 
1586       filtered.SetElementAtIndex(i, oldKey.GetElementAtIndex(i));
1587 
1588       DCRTPoly u;
1589 
1590       if (cryptoParamsLWE->GetMode() == RLWE)
1591         u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1592       else
1593         u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1594 
1595       DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1596       DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1597 
1598       DCRTPoly c0(elementParams);
1599       DCRTPoly c1(elementParams);
1600 
1601       c0 = p0 * u + e1 + filtered;
1602 
1603       c1 = p1 * u + e2;
1604 
1605       DCRTPoly a(dug, elementParams, Format::EVALUATION);
1606       evalKeyElementsGenerated.push_back(c1);
1607 
1608       DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1609       evalKeyElements.push_back(c0);
1610     }
1611   }
1612 
1613   ek->SetAVector(std::move(evalKeyElements));
1614   ek->SetBVector(std::move(evalKeyElementsGenerated));
1615 
1616   return ek;
1617 }
1618 
1619 template <>
ReEncrypt(const LPEvalKey<DCRTPoly> ek,ConstCiphertext<DCRTPoly> ciphertext,const LPPublicKey<DCRTPoly> publicKey) const1620 Ciphertext<DCRTPoly> LPAlgorithmPREBFVrnsB<DCRTPoly>::ReEncrypt(
1621     const LPEvalKey<DCRTPoly> ek, ConstCiphertext<DCRTPoly> ciphertext,
1622     const LPPublicKey<DCRTPoly> publicKey) const {
1623   if (publicKey == nullptr) {  // Sender PK is not provided - CPA-secure PRE
1624     return ciphertext->GetCryptoContext()->KeySwitch(ek, ciphertext);
1625   }
1626 
1627   // Sender PK provided - HRA-secure PRE
1628   const auto cryptoParamsLWE =
1629       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1630           ek->GetCryptoParameters());
1631 
1632   // Get crypto and elements parameters
1633   const shared_ptr<ParmType> elementParams =
1634       cryptoParamsLWE->GetElementParams();
1635 
1636   const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1637   TugType tug;
1638 
1639   PlaintextEncodings encType = ciphertext->GetEncodingType();
1640 
1641   Ciphertext<DCRTPoly> zeroCiphertext(
1642       std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
1643   zeroCiphertext->SetEncodingType(encType);
1644 
1645   const DCRTPoly &p0 = publicKey->GetPublicElements().at(0);
1646   const DCRTPoly &p1 = publicKey->GetPublicElements().at(1);
1647 
1648   DCRTPoly u;
1649 
1650   if (cryptoParamsLWE->GetMode() == RLWE)
1651     u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1652   else
1653     u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1654 
1655   DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1656   DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1657 
1658   DCRTPoly c0 = p0 * u + e1;
1659   DCRTPoly c1 = p1 * u + e2;
1660 
1661   zeroCiphertext->SetElements({std::move(c0), std::move(c1)});
1662 
1663   // Add the encryption of zero for re-randomization purposes
1664   auto c = ciphertext->GetCryptoContext()->GetEncryptionAlgorithm()->EvalAdd(
1665       ciphertext, zeroCiphertext);
1666 
1667   ciphertext->GetCryptoContext()->KeySwitchInPlace(ek, c);
1668   return c;
1669 }
1670 
1671 template <>
MultipartyDecryptFusion(const vector<Ciphertext<DCRTPoly>> & ciphertextVec,NativePoly * plaintext) const1672 DecryptResult LPAlgorithmMultipartyBFVrnsB<DCRTPoly>::MultipartyDecryptFusion(
1673     const vector<Ciphertext<DCRTPoly>> &ciphertextVec,
1674     NativePoly *plaintext) const {
1675   const auto cryptoParamsBFVrnsB =
1676       std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1677           ciphertextVec[0]->GetCryptoParameters());
1678   const shared_ptr<ParmType> elementParams =
1679       cryptoParamsBFVrnsB->GetElementParams();
1680 
1681   const std::vector<DCRTPoly> &cElem = ciphertextVec[0]->GetElements();
1682   DCRTPoly b = cElem[0];
1683 
1684   size_t numCipher = ciphertextVec.size();
1685   for (size_t i = 1; i < numCipher; i++) {
1686     const std::vector<DCRTPoly> &c2 = ciphertextVec[i]->GetElements();
1687     b += c2[0];
1688   }
1689 
1690   auto &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
1691   auto &tgamma = cryptoParamsBFVrnsB->Gettgamma();
1692 
1693   // Invoke BFVrnsB DecRNS
1694 
1695   const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
1696   const std::vector<NativeInteger> &tgammaQHatInvModq =
1697       cryptoParamsBFVrnsB->GettgammaQHatInvModq();
1698   const std::vector<NativeInteger> &tgammaQHatInvModqPrecon =
1699       cryptoParamsBFVrnsB->GettgammaQHatInvModqPrecon();
1700   const std::vector<NativeInteger> &negInvqModtgamma =
1701       cryptoParamsBFVrnsB->GetNegInvqModtgamma();
1702   const std::vector<NativeInteger> &negInvqModtgammaPrecon =
1703       cryptoParamsBFVrnsB->GetNegInvqModtgammaPrecon();
1704 
1705   // this is the resulting vector of coefficients;
1706   *plaintext = b.ScaleAndRound(moduliQ, t, tgamma, tgammaQHatInvModq,
1707                                tgammaQHatInvModqPrecon, negInvqModtgamma,
1708                                negInvqModtgammaPrecon);
1709 
1710   return DecryptResult(plaintext->GetLength());
1711 }
1712 
1713 template class LPCryptoParametersBFVrnsB<DCRTPoly>;
1714 template class LPPublicKeyEncryptionSchemeBFVrnsB<DCRTPoly>;
1715 template class LPAlgorithmBFVrnsB<DCRTPoly>;
1716 template class LPAlgorithmSHEBFVrnsB<DCRTPoly>;
1717 template class LPAlgorithmMultipartyBFVrnsB<DCRTPoly>;
1718 template class LPAlgorithmParamsGenBFVrnsB<DCRTPoly>;
1719 
1720 }  // namespace lbcrypto
1721