1 // @file bfvrnsB-impl.cpp - template instantiations and methods for the BFVrnsB
2 // scheme
3 // @author TPOC: contact@palisade-crypto.org
4 //
5 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
6 // All rights reserved.
7 // Redistribution and use in source and binary forms, with or without
8 // modification, are permitted provided that the following conditions are met:
9 // 1. Redistributions of source code must retain the above copyright notice,
10 // this list of conditions and the following disclaimer.
11 // 2. Redistributions in binary form must reproduce the above copyright notice,
12 // this list of conditions and the following disclaimer in the documentation
13 // and/or other materials provided with the distribution. THIS SOFTWARE IS
14 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
15 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
16 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
17 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
18 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 /*
25 Description:
26
This code implements an RNS variant of the Brakerski-Fan-Vercauteren (BFV)
28 homomorphic encryption scheme. This scheme is also referred to as the FV
29 scheme.
30
31 The BFV scheme is introduced in the following papers:
32 - Zvika Brakerski (2012). Fully Homomorphic Encryption without Modulus
33 Switching from Classical GapSVP. Cryptology ePrint Archive, Report 2012/078.
34 (https://eprint.iacr.org/2012/078)
35 - Junfeng Fan and Frederik Vercauteren (2012). Somewhat Practical Fully
36 Homomorphic Encryption. Cryptology ePrint Archive, Report 2012/144.
37 (https://eprint.iacr.org/2012/144.pdf)
38
39 Our implementation builds from the designs here:
40 - Lepoint T., Naehrig M. (2014) A Comparison of the Homomorphic Encryption
41 Schemes FV and YASHE. In: Pointcheval D., Vergnaud D. (eds) Progress in
42 Cryptology – AFRICACRYPT 2014. AFRICACRYPT 2014. Lecture Notes in Computer
43 Science, vol 8469. Springer, Cham. (https://eprint.iacr.org/2014/062.pdf)
44 - Jean-Claude Bajard and Julien Eynard and Anwar Hasan and Vincent Zucca
45 (2016). A Full RNS Variant of FV like Somewhat Homomorphic Encryption Schemes.
46 Cryptology ePrint Archive, Report 2016/510. (https://eprint.iacr.org/2016/510)
47 - Ahmad Al Badawi and Yuriy Polyakov and Khin Mi Mi Aung and Bharadwaj
48 Veeravalli and Kurt Rohloff (2018). Implementation and Performance Evaluation of
49 RNS Variants of the BFV Homomorphic Encryption Scheme. Cryptology ePrint
Archive, Report 2018/589. (https://eprint.iacr.org/2018/589)
51 */
52
53 #include "bfvrnsB.cpp"
54 #include "cryptocontext.h"
55
56 // #define USE_KARATSUBA
57
58 namespace lbcrypto {
59
// Bodies for the Poly / NativePoly specializations below, which BFVrnsB does
// not implement: each macro raises not_implemented_error via PALISADE_THROW
// with a message pointing the caller to DCRTPoly.
// Each expansion is one self-contained compound statement, so its temporary
// message string lives in a private scope and cannot collide with a name
// already declared in the enclosing function. Use as a bare statement
// (`NOPOLY` or `NOPOLY;` are both valid).
#define NOPOLY                                                         \
  {                                                                    \
    std::string errMsg =                                               \
        "BFVrnsB does not support Poly. Use DCRTPoly instead.";        \
    PALISADE_THROW(not_implemented_error, errMsg);                     \
  }

#define NONATIVEPOLY                                                   \
  {                                                                    \
    std::string errMsg =                                               \
        "BFVrnsB does not support NativePoly. Use DCRTPoly instead.";  \
    PALISADE_THROW(not_implemented_error, errMsg);                     \
  }
68
69 template <>
LPCryptoParametersBFVrnsB()70 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB()
71 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
72 NOPOLY
73 }
74
75 template <>
LPCryptoParametersBFVrnsB()76 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB()
77 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
78 NONATIVEPOLY
79 }
80
81 template <>
LPCryptoParametersBFVrnsB(const LPCryptoParametersBFVrnsB & rhs)82 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
83 const LPCryptoParametersBFVrnsB &rhs)
84 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
85 NOPOLY
86 }
87
88 template <>
LPCryptoParametersBFVrnsB(const LPCryptoParametersBFVrnsB & rhs)89 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
90 const LPCryptoParametersBFVrnsB &rhs)
91 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
92 NONATIVEPOLY
93 }
94
95 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,const PlaintextModulus & plaintextModulus,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)96 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
97 shared_ptr<ParmType> params, const PlaintextModulus &plaintextModulus,
98 float distributionParameter, float assuranceMeasure, float securityLevel,
99 usint relinWindow, MODE mode, int depth, int maxDepth)
100 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
101 NOPOLY
102 }
103
104 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,const PlaintextModulus & plaintextModulus,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)105 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
106 shared_ptr<ParmType> params, const PlaintextModulus &plaintextModulus,
107 float distributionParameter, float assuranceMeasure, float securityLevel,
108 usint relinWindow, MODE mode, int depth, int maxDepth)
109 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
110 NONATIVEPOLY
111 }
112
113 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,EncodingParams encodingParams,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)114 LPCryptoParametersBFVrnsB<Poly>::LPCryptoParametersBFVrnsB(
115 shared_ptr<ParmType> params, EncodingParams encodingParams,
116 float distributionParameter, float assuranceMeasure, float securityLevel,
117 usint relinWindow, MODE mode, int depth, int maxDepth)
118 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
119 NOPOLY
120 }
121
122 template <>
LPCryptoParametersBFVrnsB(shared_ptr<ParmType> params,EncodingParams encodingParams,float distributionParameter,float assuranceMeasure,float securityLevel,usint relinWindow,MODE mode,int depth,int maxDepth)123 LPCryptoParametersBFVrnsB<NativePoly>::LPCryptoParametersBFVrnsB(
124 shared_ptr<ParmType> params, EncodingParams encodingParams,
125 float distributionParameter, float assuranceMeasure, float securityLevel,
126 usint relinWindow, MODE mode, int depth, int maxDepth)
127 : m_numq(0), m_numb(0), m_negQInvModmtilde(0) {
128 NONATIVEPOLY
129 }
130
131 // Parameter generation for BFV-RNS
132 template <>
PrecomputeCRTTables()133 bool LPCryptoParametersBFVrnsB<Poly>::PrecomputeCRTTables() {
134 NOPOLY
135 }
136
137 template <>
PrecomputeCRTTables()138 bool LPCryptoParametersBFVrnsB<NativePoly>::PrecomputeCRTTables() {
139 NONATIVEPOLY
140 }
141
142 template <>
LPPublicKeyEncryptionSchemeBFVrnsB()143 LPPublicKeyEncryptionSchemeBFVrnsB<Poly>::LPPublicKeyEncryptionSchemeBFVrnsB() {
144 NOPOLY
145 }
146
147 template <>
148 LPPublicKeyEncryptionSchemeBFVrnsB<
LPPublicKeyEncryptionSchemeBFVrnsB()149 NativePoly>::LPPublicKeyEncryptionSchemeBFVrnsB() {
150 NONATIVEPOLY
151 }
152
153 template <>
ParamsGen(shared_ptr<LPCryptoParameters<Poly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t n) const154 bool LPAlgorithmParamsGenBFVrnsB<Poly>::ParamsGen(
155 shared_ptr<LPCryptoParameters<Poly>> cryptoParams, int32_t evalAddCount,
156 int32_t evalMultCount, int32_t keySwitchCount, size_t dcrtBits,
157 uint32_t n) const {
158 NOPOLY
159 }
160
161 template <>
ParamsGen(shared_ptr<LPCryptoParameters<NativePoly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t n) const162 bool LPAlgorithmParamsGenBFVrnsB<NativePoly>::ParamsGen(
163 shared_ptr<LPCryptoParameters<NativePoly>> cryptoParams,
164 int32_t evalAddCount, int32_t evalMultCount, int32_t keySwitchCount,
165 size_t dcrtBits, uint32_t n) const {
166 NONATIVEPOLY
167 }
168
169 template <>
Encrypt(const LPPublicKey<Poly> publicKey,Poly ptxt) const170 Ciphertext<Poly> LPAlgorithmBFVrnsB<Poly>::Encrypt(
171 const LPPublicKey<Poly> publicKey, Poly ptxt) const {
172 NOPOLY
173 }
174
175 template <>
Encrypt(const LPPublicKey<NativePoly> publicKey,NativePoly ptxt) const176 Ciphertext<NativePoly> LPAlgorithmBFVrnsB<NativePoly>::Encrypt(
177 const LPPublicKey<NativePoly> publicKey, NativePoly ptxt) const {
178 NONATIVEPOLY
179 }
180
181 template <>
Decrypt(const LPPrivateKey<Poly> privateKey,ConstCiphertext<Poly> ciphertext,NativePoly * plaintext) const182 DecryptResult LPAlgorithmBFVrnsB<Poly>::Decrypt(
183 const LPPrivateKey<Poly> privateKey, ConstCiphertext<Poly> ciphertext,
184 NativePoly *plaintext) const {
185 NOPOLY
186 }
187
188 template <>
Decrypt(const LPPrivateKey<NativePoly> privateKey,ConstCiphertext<NativePoly> ciphertext,NativePoly * plaintext) const189 DecryptResult LPAlgorithmBFVrnsB<NativePoly>::Decrypt(
190 const LPPrivateKey<NativePoly> privateKey,
191 ConstCiphertext<NativePoly> ciphertext, NativePoly *plaintext) const {
192 NONATIVEPOLY
193 }
194
195 template <>
Encrypt(const LPPrivateKey<Poly> privateKey,Poly ptxt) const196 Ciphertext<Poly> LPAlgorithmBFVrnsB<Poly>::Encrypt(
197 const LPPrivateKey<Poly> privateKey, Poly ptxt) const {
198 NOPOLY
199 }
200
201 template <>
Encrypt(const LPPrivateKey<NativePoly> privateKey,NativePoly ptxt) const202 Ciphertext<NativePoly> LPAlgorithmBFVrnsB<NativePoly>::Encrypt(
203 const LPPrivateKey<NativePoly> privateKey, NativePoly ptxt) const {
204 NONATIVEPOLY
205 }
206
207 template <>
EvalMult(ConstCiphertext<Poly> ciphertext1,ConstCiphertext<Poly> ciphertext2) const208 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalMult(
209 ConstCiphertext<Poly> ciphertext1,
210 ConstCiphertext<Poly> ciphertext2) const {
211 NOPOLY
212 }
213
214 template <>
EvalMult(ConstCiphertext<NativePoly> ciphertext1,ConstCiphertext<NativePoly> ciphertext2) const215 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalMult(
216 ConstCiphertext<NativePoly> ciphertext1,
217 ConstCiphertext<NativePoly> ciphertext2) const {
218 NONATIVEPOLY
219 }
220
221 template <>
EvalAdd(ConstCiphertext<Poly> ct,ConstPlaintext pt) const222 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalAdd(ConstCiphertext<Poly> ct,
223 ConstPlaintext pt) const {
224 NOPOLY
225 }
226
227 template <>
EvalAdd(ConstCiphertext<NativePoly> ct,ConstPlaintext pt) const228 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalAdd(
229 ConstCiphertext<NativePoly> ct, ConstPlaintext pt) const {
230 NONATIVEPOLY
231 }
232
233 template <>
EvalSub(ConstCiphertext<Poly> ct,ConstPlaintext pt) const234 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalSub(ConstCiphertext<Poly> ct,
235 ConstPlaintext pt) const {
236 NOPOLY
237 }
238
239 template <>
EvalSub(ConstCiphertext<NativePoly> ct,ConstPlaintext pt) const240 Ciphertext<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::EvalSub(
241 ConstCiphertext<NativePoly> ct, ConstPlaintext pt) const {
242 NONATIVEPOLY
243 }
244
245 template <>
KeySwitchGen(const LPPrivateKey<Poly> originalPrivateKey,const LPPrivateKey<Poly> newPrivateKey) const246 LPEvalKey<Poly> LPAlgorithmSHEBFVrnsB<Poly>::KeySwitchGen(
247 const LPPrivateKey<Poly> originalPrivateKey,
248 const LPPrivateKey<Poly> newPrivateKey) const {
249 NOPOLY
250 }
251
252 template <>
KeySwitchGen(const LPPrivateKey<NativePoly> originalPrivateKey,const LPPrivateKey<NativePoly> newPrivateKey) const253 LPEvalKey<NativePoly> LPAlgorithmSHEBFVrnsB<NativePoly>::KeySwitchGen(
254 const LPPrivateKey<NativePoly> originalPrivateKey,
255 const LPPrivateKey<NativePoly> newPrivateKey) const {
256 NONATIVEPOLY
257 }
258
259 template <>
KeySwitchInPlace(const LPEvalKey<Poly> keySwitchHint,Ciphertext<Poly> & cipherText) const260 void LPAlgorithmSHEBFVrnsB<Poly>::KeySwitchInPlace(
261 const LPEvalKey<Poly> keySwitchHint,
262 Ciphertext<Poly>& cipherText) const {
263 NOPOLY
264 }
265
266 template <>
KeySwitchInPlace(const LPEvalKey<NativePoly> keySwitchHint,Ciphertext<NativePoly> & cipherText) const267 void LPAlgorithmSHEBFVrnsB<NativePoly>::KeySwitchInPlace(
268 const LPEvalKey<NativePoly> keySwitchHint,
269 Ciphertext<NativePoly>& cipherText) const {
270 NONATIVEPOLY
271 }
272
273 template <>
EvalMultAndRelinearize(ConstCiphertext<Poly> ct1,ConstCiphertext<Poly> ct,const vector<LPEvalKey<Poly>> & ek) const274 Ciphertext<Poly> LPAlgorithmSHEBFVrnsB<Poly>::EvalMultAndRelinearize(
275 ConstCiphertext<Poly> ct1, ConstCiphertext<Poly> ct,
276 const vector<LPEvalKey<Poly>> &ek) const {
277 NOPOLY
278 }
279
280 template <>
281 Ciphertext<NativePoly>
EvalMultAndRelinearize(ConstCiphertext<NativePoly> ct1,ConstCiphertext<NativePoly> ct,const vector<LPEvalKey<NativePoly>> & ek) const282 LPAlgorithmSHEBFVrnsB<NativePoly>::EvalMultAndRelinearize(
283 ConstCiphertext<NativePoly> ct1, ConstCiphertext<NativePoly> ct,
284 const vector<LPEvalKey<NativePoly>> &ek) const {
285 NONATIVEPOLY
286 }
287
288 template <>
MultipartyDecryptFusion(const vector<Ciphertext<Poly>> & ciphertextVec,NativePoly * plaintext) const289 DecryptResult LPAlgorithmMultipartyBFVrnsB<Poly>::MultipartyDecryptFusion(
290 const vector<Ciphertext<Poly>> &ciphertextVec,
291 NativePoly *plaintext) const {
292 NOPOLY
293 }
294
295 template <>
MultipartyDecryptFusion(const vector<Ciphertext<NativePoly>> & ciphertextVec,NativePoly * plaintext) const296 DecryptResult LPAlgorithmMultipartyBFVrnsB<NativePoly>::MultipartyDecryptFusion(
297 const vector<Ciphertext<NativePoly>> &ciphertextVec,
298 NativePoly *plaintext) const {
299 NONATIVEPOLY
300 }
301
302 template <>
MultiKeySwitchGen(const LPPrivateKey<Poly> originalPrivateKey,const LPPrivateKey<Poly> newPrivateKey,const LPEvalKey<Poly> ek) const303 LPEvalKey<Poly> LPAlgorithmMultipartyBFVrnsB<Poly>::MultiKeySwitchGen(
304 const LPPrivateKey<Poly> originalPrivateKey,
305 const LPPrivateKey<Poly> newPrivateKey, const LPEvalKey<Poly> ek) const {
306 NOPOLY
307 }
308
309 template <>
310 LPEvalKey<NativePoly>
MultiKeySwitchGen(const LPPrivateKey<NativePoly> originalPrivateKey,const LPPrivateKey<NativePoly> newPrivateKey,const LPEvalKey<NativePoly> ek) const311 LPAlgorithmMultipartyBFVrnsB<NativePoly>::MultiKeySwitchGen(
312 const LPPrivateKey<NativePoly> originalPrivateKey,
313 const LPPrivateKey<NativePoly> newPrivateKey,
314 const LPEvalKey<NativePoly> ek) const {
315 NONATIVEPOLY
316 }
317
318 template class LPCryptoParametersBFVrnsB<Poly>;
319 template class LPPublicKeyEncryptionSchemeBFVrnsB<Poly>;
320 template class LPAlgorithmBFVrnsB<Poly>;
321 template class LPAlgorithmSHEBFVrnsB<Poly>;
322 template class LPAlgorithmMultipartyBFVrnsB<Poly>;
323 template class LPAlgorithmParamsGenBFVrnsB<Poly>;
324
325 template class LPCryptoParametersBFVrnsB<NativePoly>;
326 template class LPPublicKeyEncryptionSchemeBFVrnsB<NativePoly>;
327 template class LPAlgorithmBFVrnsB<NativePoly>;
328 template class LPAlgorithmSHEBFVrnsB<NativePoly>;
329 template class LPAlgorithmMultipartyBFVrnsB<NativePoly>;
330 template class LPAlgorithmParamsGenBFVrnsB<NativePoly>;
331
332 #undef NOPOLY
333 #undef NONATIVEPOLY
334
335 // Precomputation of CRT tables encryption, decryption, and homomorphic
336 // multiplication
337 template <>
PrecomputeCRTTables()338 bool LPCryptoParametersBFVrnsB<DCRTPoly>::PrecomputeCRTTables() {
339 // read values for the CRT basis
340
341 size_t sizeQ = GetElementParams()->GetParams().size();
342 auto ringDim = GetElementParams()->GetRingDimension();
343
344 vector<NativeInteger> moduliQ(sizeQ);
345 vector<NativeInteger> rootsQ(sizeQ);
346
347 const BigInteger BarrettBase128Bit(
348 "340282366920938463463374607431768211456"); // 2^128
349 const BigInteger TwoPower64("18446744073709551616"); // 2^64
350
351 m_moduliQ.resize(sizeQ);
352 for (size_t i = 0; i < sizeQ; i++) {
353 moduliQ[i] = GetElementParams()->GetParams()[i]->GetModulus();
354 rootsQ[i] = GetElementParams()->GetParams()[i]->GetRootOfUnity();
355 m_moduliQ[i] = moduliQ[i];
356 }
357
358 // compute the CRT delta table floor(Q/p) mod qi - used for encryption
359
360 const BigInteger modulusQ = GetElementParams()->GetModulus();
361
362 const BigInteger QDivt = modulusQ.DividedBy(GetPlaintextModulus());
363
364 std::vector<NativeInteger> QDivtModq(sizeQ);
365
366 for (size_t i = 0; i < sizeQ; i++) {
367 BigInteger qi = BigInteger(moduliQ[i].ConvertToInt());
368 BigInteger QDivtModqi = QDivt.Mod(qi);
369 QDivtModq[i] = NativeInteger(QDivtModqi.ConvertToInt());
370 }
371
372 m_QDivtModq = QDivtModq;
373
374 m_modqBarrettMu.resize(sizeQ);
375 for (uint32_t i = 0; i < m_modqBarrettMu.size(); i++) {
376 BigInteger mu = BarrettBase128Bit / BigInteger(m_moduliQ[i]);
377 uint64_t val[2];
378 val[0] = (mu % TwoPower64).ConvertToInt();
379 val[1] = mu.RShift(64).ConvertToInt();
380
381 memcpy(&m_modqBarrettMu[i], val, sizeof(DoubleNativeInt));
382 }
383
384 ChineseRemainderTransformFTT<NativeVector>::PreCompute(rootsQ, 2 * ringDim,
385 moduliQ);
386
387 // Compute Bajard's et al. RNS variant lookup tables
388
389 // Populate EvalMulrns tables
390 // find the a suitable size of B
391 m_numq = sizeQ;
392
393 BigInteger t = BigInteger(GetPlaintextModulus());
394 BigInteger Q(GetElementParams()->GetModulus());
395
396 BigInteger B = 1;
397 BigInteger maxConvolutionValue =
398 BigInteger(2) * BigInteger(ringDim) * Q * Q * t;
399
400 m_moduliB.push_back(
401 PreviousPrime<NativeInteger>(moduliQ[m_numq - 1], 2 * ringDim));
402 m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_moduliB[0]));
403 B = B * BigInteger(m_moduliB[0]);
404
405 for (usint i = 1; i < m_numq; i++) { // we already added one prime
406 m_moduliB.push_back(
407 PreviousPrime<NativeInteger>(m_moduliB[i - 1], 2 * ringDim));
408 m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_moduliB[i]));
409
410 B = B * BigInteger(m_moduliB[i]);
411 }
412
413 m_numb = m_numq;
414
415 m_msk = PreviousPrime<NativeInteger>(m_moduliB[m_numq - 1], 2 * ringDim);
416
417 usint s = 0;
418 NativeInteger tmp = m_msk;
419 while (tmp > 0) {
420 tmp >>= 1;
421 s++;
422 }
423
424 // check msk is large enough
425 while (Q * B * BigInteger(m_msk) < maxConvolutionValue) {
426 NativeInteger firstInteger = FirstPrime<NativeInteger>(s + 1, 2 * ringDim);
427
428 m_msk = NextPrime<NativeInteger>(firstInteger, 2 * ringDim);
429 s++;
430 if (s >= 60) PALISADE_THROW(math_error, "msk is larger than 60 bits");
431 }
432 m_rootsBsk.push_back(RootOfUnity<NativeInteger>(2 * ringDim, m_msk));
433
434 m_moduliBsk = m_moduliB;
435 m_moduliBsk.push_back(m_msk);
436
437 m_paramsBsk = std::make_shared<ILDCRTParams<BigInteger>>(
438 2 * ringDim, m_moduliBsk, m_rootsBsk);
439
440 ChineseRemainderTransformFTT<NativeVector>::PreCompute(
441 m_rootsBsk, 2 * ringDim, m_moduliBsk);
442
443 // populate Barrett constant for m_BskModuli
444 m_modbskBarrettMu.resize(m_moduliBsk.size());
445 for (uint32_t i = 0; i < m_modbskBarrettMu.size(); i++) {
446 BigInteger mu = BarrettBase128Bit / BigInteger(m_moduliBsk[i]);
447 uint64_t val[2];
448 val[0] = (mu % TwoPower64).ConvertToInt();
449 val[1] = mu.RShift(64).ConvertToInt();
450
451 memcpy(&m_modbskBarrettMu[i], val, sizeof(DoubleNativeInt));
452 }
453
454 // Populate [(Q/q_i)^-1]_{q_i}
455 m_QHatInvModq.resize(m_numq);
456 for (uint32_t i = 0; i < m_QHatInvModq.size(); i++) {
457 BigInteger QHatInvModqi;
458 QHatInvModqi = Q.DividedBy(moduliQ[i]);
459 QHatInvModqi = QHatInvModqi.Mod(moduliQ[i]);
460 QHatInvModqi = QHatInvModqi.ModInverse(moduliQ[i]);
461 m_QHatInvModq[i] = QHatInvModqi.ConvertToInt();
462 }
463
464 // Populate [t*(Q/q_i)^-1]_{q_i}
465 m_tQHatInvModq.resize(m_numq);
466 m_tQHatInvModqPrecon.resize(m_numq);
467 for (uint32_t i = 0; i < m_tQHatInvModq.size(); i++) {
468 BigInteger tQHatInvModqi;
469 tQHatInvModqi = Q.DividedBy(moduliQ[i]);
470 tQHatInvModqi = tQHatInvModqi.Mod(moduliQ[i]);
471 tQHatInvModqi = tQHatInvModqi.ModInverse(moduliQ[i]);
472 tQHatInvModqi = tQHatInvModqi.ModMul(t.ConvertToInt(), moduliQ[i]);
473 m_tQHatInvModq[i] = tQHatInvModqi.ConvertToInt();
474 m_tQHatInvModqPrecon[i] = m_tQHatInvModq[i].PrepModMulConst(moduliQ[i]);
475 }
476
477 // Populate [Q/q_i]_{bsk_j, mtilde}
478 m_QHatModbsk.resize(m_numq);
479 m_QHatModmtilde.resize(m_numq);
480 for (uint32_t i = 0; i < m_QHatModbsk.size(); i++) {
481 m_QHatModbsk[i].resize(m_numb + 1);
482
483 BigInteger QHati = Q.DividedBy(moduliQ[i]);
484 for (uint32_t j = 0; j < m_QHatModbsk[i].size(); j++) {
485 BigInteger QHatiModbskj = QHati.Mod(m_moduliBsk[j]);
486 m_QHatModbsk[i][j] = QHatiModbskj.ConvertToInt();
487 }
488 m_QHatModmtilde[i] = QHati.Mod(m_mtilde).ConvertToInt();
489 }
490
491 // Populate [1/q_i]_{bsk_j}
492 m_qInvModbsk.resize(m_numq);
493 for (uint32_t i = 0; i < m_qInvModbsk.size(); i++) {
494 m_qInvModbsk[i].resize(m_numb + 1);
495 for (uint32_t j = 0; j < m_qInvModbsk[i].size(); j++)
496 m_qInvModbsk[i][j] = moduliQ[i].ModInverse(m_moduliBsk[j]);
497 }
498
499 // Populate [mtilde*(Q/q_i)^{-1}]_{q_i}
500 m_mtildeQHatInvModq.resize(m_numq);
501 m_mtildeQHatInvModqPrecon.resize(m_numq);
502
503 BigInteger bmtilde(m_mtilde);
504 for (uint32_t i = 0; i < m_mtildeQHatInvModq.size(); i++) {
505 BigInteger mtildeQHatInvModqi = Q.DividedBy(moduliQ[i]);
506 mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
507 mtildeQHatInvModqi = mtildeQHatInvModqi.ModInverse(moduliQ[i]);
508 mtildeQHatInvModqi = mtildeQHatInvModqi * bmtilde;
509 mtildeQHatInvModqi = mtildeQHatInvModqi.Mod(moduliQ[i]);
510 m_mtildeQHatInvModq[i] = mtildeQHatInvModqi.ConvertToInt();
511 m_mtildeQHatInvModqPrecon[i] =
512 m_mtildeQHatInvModq[i].PrepModMulConst(moduliQ[i]);
513 }
514
515 // Populate [-Q^{-1}]_{mtilde}
516 BigInteger negQInvModmtilde =
517 (BigInteger(m_mtilde - 1) * Q.ModInverse(m_mtilde));
518 negQInvModmtilde = negQInvModmtilde.Mod(m_mtilde);
519 m_negQInvModmtilde = negQInvModmtilde.ConvertToInt();
520
521 // Populate [Q]_{bski_j}
522 m_QModbsk.resize(m_numq + 1);
523 m_QModbskPrecon.resize(m_numq + 1);
524
525 for (uint32_t j = 0; j < m_QModbsk.size(); j++) {
526 BigInteger QModbskij = Q.Mod(m_moduliBsk[j]);
527 m_QModbsk[j] = QModbskij.ConvertToInt();
528 m_QModbskPrecon[j] = m_QModbsk[j].PrepModMulConst(m_moduliBsk[j]);
529 }
530
531 // Populate [mtilde^{-1}]_{bsk_j}
532 m_mtildeInvModbsk.resize(m_numb + 1);
533 m_mtildeInvModbskPrecon.resize(m_numb + 1);
534 for (uint32_t j = 0; j < m_mtildeInvModbsk.size(); j++) {
535 BigInteger mtildeInvModbskij = m_mtilde % m_moduliBsk[j];
536 mtildeInvModbskij = mtildeInvModbskij.ModInverse(m_moduliBsk[j]);
537 m_mtildeInvModbsk[j] = mtildeInvModbskij.ConvertToInt();
538 m_mtildeInvModbskPrecon[j] =
539 m_mtildeInvModbsk[j].PrepModMulConst(m_moduliBsk[j]);
540 }
541
542 // Populate {t/Q}_{bsk_j}
543 m_tQInvModbsk.resize(m_numb + 1);
544 m_tQInvModbskPrecon.resize(m_numb + 1);
545
546 for (uint32_t i = 0; i < m_tQInvModbsk.size(); i++) {
547 BigInteger tDivqModBski = Q.ModInverse(m_moduliBsk[i]);
548 tDivqModBski.ModMulEq(t.ConvertToInt(), m_moduliBsk[i]);
549 m_tQInvModbsk[i] = tDivqModBski.ConvertToInt();
550 m_tQInvModbskPrecon[i] = m_tQInvModbsk[i].PrepModMulConst(m_moduliBsk[i]);
551 }
552
553 // Populate [(B/b_j)^{-1}]_{b_j}
554 m_BHatInvModb.resize(m_numb);
555 m_BHatInvModbPrecon.resize(m_numb);
556
557 for (uint32_t i = 0; i < m_BHatInvModb.size(); i++) {
558 BigInteger BDivBi;
559 BDivBi = B.DividedBy(m_moduliB[i]);
560 BDivBi = BDivBi.Mod(m_moduliB[i]);
561 BDivBi = BDivBi.ModInverse(m_moduliB[i]);
562 m_BHatInvModb[i] = BDivBi.ConvertToInt();
563 m_BHatInvModbPrecon[i] = m_BHatInvModb[i].PrepModMulConst(m_moduliB[i]);
564 }
565
566 // Populate [B/b_j]_{q_i}
567 m_BHatModq.resize(m_numb);
568 for (uint32_t i = 0; i < m_BHatModq.size(); i++) {
569 m_BHatModq[i].resize(m_numq);
570 BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
571 for (uint32_t j = 0; j < m_BHatModq[i].size(); j++) {
572 BigInteger BDivBiModqj = BDivBi.Mod(moduliQ[j]);
573 m_BHatModq[i][j] = BDivBiModqj.ConvertToInt();
574 }
575 }
576
577 // Populate [B/b_j]_{msk}
578 m_BHatModmsk.resize(m_numb);
579 for (uint32_t i = 0; i < m_BHatModmsk.size(); i++) {
580 BigInteger BDivBi = B.DividedBy(m_moduliB[i]);
581 m_BHatModmsk[i] = (BDivBi.Mod(m_msk)).ConvertToInt();
582 }
583
584 // Populate [B^{-1}]_{msk}
585 m_BInvModmsk = (B.ModInverse(m_msk)).ConvertToInt();
586 m_BInvModmskPrecon = m_BInvModmsk.PrepModMulConst(m_msk);
587
588 // Populate [B]_{q_i}
589 m_BModq.resize(m_numq);
590 m_BModqPrecon.resize(m_numq);
591 for (uint32_t i = 0; i < m_BModq.size(); i++) {
592 m_BModq[i] = (B.Mod(moduliQ[i])).ConvertToInt();
593 m_BModqPrecon[i] = m_BModq[i].PrepModMulConst(moduliQ[i]);
594 }
595
596 // Populate Decrns lookup tables
597
598 NativeInteger tgamma = NativeInteger(t.ConvertToInt() * m_gamma); // t*gamma
599
600 m_tgamma = tgamma;
601
602 // Populate [-1/q_i]_{t*gamma} (t*gamma < 2^58)
603 m_negInvqModtgamma.resize(m_numq);
604 m_negInvqModtgammaPrecon.resize(m_numq);
605 for (uint32_t i = 0; i < m_negInvqModtgamma.size(); i++) {
606 BigInteger imod(moduliQ[i]);
607 BigInteger negInvqi = BigInteger((tgamma - 1)) * imod.ModInverse(tgamma);
608
609 BigInteger negInvqiModtgamma = negInvqi.Mod(tgamma);
610 m_negInvqModtgamma[i] = negInvqiModtgamma.ConvertToInt();
611 m_negInvqModtgammaPrecon[i] = m_negInvqModtgamma[i].PrepModMulConst(tgamma);
612 }
613
614 // populate [t*gamma*(Q/q_i)^(-1)]_{q_i}
615 m_tgammaQHatInvModq.resize(m_numq);
616 m_tgammaQHatInvModqPrecon.resize(m_numq);
617
618 BigInteger bmgamma(m_gamma);
619 for (uint32_t i = 0; i < m_tgammaQHatInvModq.size(); i++) {
620 BigInteger qDivqi = Q.DividedBy(moduliQ[i]);
621 BigInteger imod(moduliQ[i]);
622 qDivqi = qDivqi.ModInverse(moduliQ[i]);
623 BigInteger gammaqDivqi = (qDivqi * bmgamma) % imod;
624 BigInteger tgammaqDivqi = (gammaqDivqi * t) % imod;
625 m_tgammaQHatInvModq[i] = tgammaqDivqi.ConvertToInt();
626 m_tgammaQHatInvModqPrecon[i] =
627 m_tgammaQHatInvModq[i].PrepModMulConst(moduliQ[i]);
628 }
629
630 return true;
631 }
632
633 // Parameter generation for BFV-RNS
634 template <>
ParamsGen(shared_ptr<LPCryptoParameters<DCRTPoly>> cryptoParams,int32_t evalAddCount,int32_t evalMultCount,int32_t keySwitchCount,size_t dcrtBits,uint32_t nCustom) const635 bool LPAlgorithmParamsGenBFVrnsB<DCRTPoly>::ParamsGen(
636 shared_ptr<LPCryptoParameters<DCRTPoly>> cryptoParams, int32_t evalAddCount,
637 int32_t evalMultCount, int32_t keySwitchCount, size_t dcrtBits,
638 uint32_t nCustom) const {
639 if (!cryptoParams)
640 PALISADE_THROW(not_available_error,
641 "No crypto parameters are supplied to BFVrns ParamsGen");
642
643 if ((dcrtBits < 30) || (dcrtBits > 60))
644 PALISADE_THROW(math_error,
645 "BFVrns.ParamsGen: Number of bits in CRT moduli should be "
646 "in the range from 30 to 60");
647
648 const auto cryptoParamsBFVrnsB =
649 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
650 cryptoParams);
651
652 double sigma = cryptoParamsBFVrnsB->GetDistributionParameter();
653 double alpha = cryptoParamsBFVrnsB->GetAssuranceMeasure();
654 double hermiteFactor = cryptoParamsBFVrnsB->GetSecurityLevel();
655 double p = static_cast<double>(cryptoParamsBFVrnsB->GetPlaintextModulus());
656 uint32_t relinWindow = cryptoParamsBFVrnsB->GetRelinWindow();
657 SecurityLevel stdLevel = cryptoParamsBFVrnsB->GetStdLevel();
658
659 // Bound of the Gaussian error polynomial
660 double Berr = sigma * sqrt(alpha);
661
662 // Bound of the key polynomial
663 double Bkey;
664
665 DistributionType distType;
666
667 // supports both discrete Gaussian (RLWE) and ternary uniform distribution
668 // (OPTIMIZED) cases
669 if (cryptoParamsBFVrnsB->GetMode() == RLWE) {
670 Bkey = sigma * sqrt(alpha);
671 distType = HEStd_error;
672 } else {
673 Bkey = 1;
674 distType = HEStd_ternary;
675 }
676
677 // expansion factor delta
678 auto delta = [](uint32_t n) -> double { return (2. * sqrt(n)); };
679
680 // norm of fresh ciphertext polynomial
681 auto Vnorm = [&](uint32_t n) -> double {
682 return Berr * (1. + 2. * delta(n) * Bkey);
683 };
684
685 // RLWE security constraint
686 auto nRLWE = [&](double logq) -> double {
687 if (stdLevel == HEStd_NotSet) {
688 return (logq - log(sigma)) / (4. * log(hermiteFactor));
689 } else {
690 return static_cast<double>(StdLatticeParm::FindRingDim(
691 distType, stdLevel, static_cast<long>(ceil(logq / log(2)))));
692 }
693 };
694
695 // initial values
696 uint32_t n = (nCustom > 0) ? nCustom : 512;
697
698 double logq = 0.;
699
700 // only public key encryption and EvalAdd (optional when evalAddCount = 0)
701 // operations are supported the correctness constraint from section 3.5 of
702 // https://eprint.iacr.org/2014/062.pdf is used
703 if ((evalMultCount == 0) && (keySwitchCount == 0)) {
704 // Correctness constraint
705 auto logqBFV = [&](uint32_t n) -> double {
706 return log(p *
707 (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount * p) + p));
708 };
709
710 // initial value
711 logq = logqBFV(n);
712
713 if ((nRLWE(logq) > n) && (nCustom > 0))
714 PALISADE_THROW(config_error,
715 "Ring dimension n specified by the user does not meet the "
716 "security requirement. Please increase it.");
717
718 while (nRLWE(logq) > n) {
719 n = 2 * n;
720 logq = logqBFV(n);
721 }
722
723 // this code updates n and q to account for the discrete size of CRT moduli
724 // = dcrtBits
725
726 int32_t k =
727 static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
728
729 double logqCeil = k * dcrtBits * log(2);
730
731 while (nRLWE(logqCeil) > n) {
732 n = 2 * n;
733 logq = logqBFV(n);
734 k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
735 logqCeil = k * dcrtBits * log(2);
736 }
737 } else if ((evalMultCount == 0) && (keySwitchCount > 0) &&
738 (evalAddCount == 0)) {
739 // this case supports automorphism w/o any other operations
740 // base for relinearization
741
742 double w = relinWindow == 0 ? pow(2, dcrtBits) : pow(2, relinWindow);
743
744 // Correctness constraint
745 auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
746 return log(
747 p * (4 * (Vnorm(n) + keySwitchCount * delta(n) *
748 (floor(logqPrev / (log(2) * dcrtBits)) + 1) *
749 w * Berr) +
750 p));
751 };
752
753 // initial values
754 double logqPrev = 6 * log(10);
755 logq = logqBFV(n, logqPrev);
756 logqPrev = logq;
757
758 if ((nRLWE(logq) > n) && (nCustom > 0))
759 PALISADE_THROW(config_error,
760 "Ring dimension n specified by the user does not meet the "
761 "security requirement. Please increase it.");
762
763 // this "while" condition is needed in case the iterative solution for q
764 // changes the requirement for n, which is rare but still theoretically
765 // possible
766 while (nRLWE(logq) > n) {
767 while (nRLWE(logq) > n) {
768 n = 2 * n;
769 logq = logqBFV(n, logqPrev);
770 logqPrev = logq;
771 }
772
773 logq = logqBFV(n, logqPrev);
774
775 while (fabs(logq - logqPrev) > log(1.001)) {
776 logqPrev = logq;
777 logq = logqBFV(n, logqPrev);
778 }
779
780 // this code updates n and q to account for the discrete size of CRT
781 // moduli = dcrtBits
782
783 int32_t k =
784 static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
785
786 double logqCeil = k * dcrtBits * log(2);
787 logqPrev = logqCeil;
788
789 while (nRLWE(logqCeil) > n) {
790 n = 2 * n;
791 logq = logqBFV(n, logqPrev);
792 k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
793 logqCeil = k * dcrtBits * log(2);
794 logqPrev = logqCeil;
795 }
796 }
797 } else if ((evalAddCount == 0) && (evalMultCount > 0) &&
798 (keySwitchCount == 0)) {
799 // Only EvalMult operations are used in the correctness constraint
800 // the correctness constraint from section 3.5 of
801 // https://eprint.iacr.org/2014/062.pdf is used
802
803 // base for relinearization
804 double w = relinWindow == 0 ? pow(2, dcrtBits) : pow(2, relinWindow);
805
806 // function used in the EvalMult constraint
807 auto epsilon1 = [&](uint32_t n) -> double { return 5 / (delta(n) * Bkey); };
808
809 // function used in the EvalMult constraint
810 auto C1 = [&](uint32_t n) -> double {
811 return (1 + epsilon1(n)) * delta(n) * delta(n) * p * Bkey;
812 };
813
814 // function used in the EvalMult constraint
815 auto C2 = [&](uint32_t n, double logqPrev) -> double {
816 return delta(n) * delta(n) * Bkey * ((1 + 0.5) * Bkey + p * p) +
817 delta(n) * (floor(logqPrev / (log(2) * dcrtBits)) + 1) * w * Berr;
818 };
819
820 // main correctness constraint
821 auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
822 return log(4 * p) + (evalMultCount - 1) * log(C1(n)) +
823 log(C1(n) * Vnorm(n) + evalMultCount * C2(n, logqPrev));
824 };
825
826 // initial values
827 double logqPrev = 6. * log(10);
828 logq = logqBFV(n, logqPrev);
829 logqPrev = logq;
830
831 if ((nRLWE(logq) > n) && (nCustom > 0))
832 PALISADE_THROW(config_error,
833 "Ring dimension n specified by the user does not meet the "
834 "security requirement. Please increase it.");
835
836 // this "while" condition is needed in case the iterative solution for q
837 // changes the requirement for n, which is rare but still theoretically
838 // possible
839 while (nRLWE(logq) > n) {
840 while (nRLWE(logq) > n) {
841 n = 2 * n;
842 logq = logqBFV(n, logqPrev);
843 logqPrev = logq;
844 }
845
846 logq = logqBFV(n, logqPrev);
847
848 while (fabs(logq - logqPrev) > log(1.001)) {
849 logqPrev = logq;
850 logq = logqBFV(n, logqPrev);
851 }
852
853 // this code updates n and q to account for the discrete size of CRT
854 // moduli = dcrtBits
855
856 int32_t k =
857 static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
858
859 double logqCeil = k * dcrtBits * log(2);
860 logqPrev = logqCeil;
861
862 while (nRLWE(logqCeil) > n) {
863 n = 2 * n;
864 logq = logqBFV(n, logqPrev);
865 k = static_cast<int32_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
866 logqCeil = k * dcrtBits * log(2);
867 logqPrev = logqCeil;
868 }
869 }
870 }
871
872 size_t sizeQ =
873 static_cast<size_t>(ceil((ceil(logq / log(2)) + 1.0) / dcrtBits));
874
875 vector<NativeInteger> moduliQ(sizeQ);
876 vector<NativeInteger> rootsQ(sizeQ);
877
878 // makes sure the first integer is less than 2^60-1 to take advantage of NTL
879 // optimizations
880 NativeInteger firstInteger = FirstPrime<NativeInteger>(dcrtBits, 2 * n);
881
882 moduliQ[0] = PreviousPrime<NativeInteger>(firstInteger, 2 * n);
883 rootsQ[0] = RootOfUnity<NativeInteger>(2 * n, moduliQ[0]);
884
885 for (size_t i = 1; i < sizeQ; i++) {
886 moduliQ[i] = PreviousPrime<NativeInteger>(moduliQ[i - 1], 2 * n);
887 rootsQ[i] = RootOfUnity<NativeInteger>(2 * n, moduliQ[i]);
888 }
889
890 auto params =
891 std::make_shared<ILDCRTParams<BigInteger>>(2 * n, moduliQ, rootsQ);
892
893 ChineseRemainderTransformFTT<NativeVector>::PreCompute(rootsQ, 2 * n,
894 moduliQ);
895
896 cryptoParamsBFVrnsB->SetElementParams(params);
897
898 const EncodingParams encodingParams = cryptoParamsBFVrnsB->GetEncodingParams();
899 if (encodingParams->GetBatchSize() > n)
900 PALISADE_THROW(config_error,
901 "The batch size cannot be larger than the ring dimension.");
902
903 // if no batch size was specified, we set batchSize = n by default (for full
904 // packing)
905 if (encodingParams->GetBatchSize() == 0) {
906 uint32_t batchSize = n;
907 EncodingParams encodingParamsNew(std::make_shared<EncodingParamsImpl>(
908 encodingParams->GetPlaintextModulus(), batchSize));
909 cryptoParamsBFVrnsB->SetEncodingParams(encodingParamsNew);
910 }
911
912 return cryptoParamsBFVrnsB->PrecomputeCRTTables();
913 }
914
915 template <>
Encrypt(const LPPublicKey<DCRTPoly> publicKey,DCRTPoly ptxt) const916 Ciphertext<DCRTPoly> LPAlgorithmBFVrnsB<DCRTPoly>::Encrypt(
917 const LPPublicKey<DCRTPoly> publicKey, DCRTPoly ptxt) const {
918 Ciphertext<DCRTPoly> ciphertext(
919 std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
920
921 const auto cryptoParams =
922 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
923 publicKey->GetCryptoParameters());
924
925 const shared_ptr<ParmType> elementParams = cryptoParams->GetElementParams();
926
927 ptxt.SetFormat(Format::EVALUATION);
928
929 const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
930
931 const DggType &dgg = cryptoParams->GetDiscreteGaussianGenerator();
932 TugType tug;
933
934 const DCRTPoly &p0 = publicKey->GetPublicElements().at(0);
935 const DCRTPoly &p1 = publicKey->GetPublicElements().at(1);
936
937 DCRTPoly u;
938
939 // Supports both discrete Gaussian (RLWE) and ternary uniform distribution
940 // (OPTIMIZED) cases
941 if (cryptoParams->GetMode() == RLWE)
942 u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
943 else
944 u = DCRTPoly(tug, elementParams, Format::EVALUATION);
945
946 DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
947 DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
948
949 DCRTPoly c0(elementParams);
950 DCRTPoly c1(elementParams);
951
952 c0 = p0 * u + e1 + ptxt.Times(delta);
953
954 c1 = p1 * u + e2;
955
956 ciphertext->SetElements({std::move(c0), std::move(c1)});
957
958 return ciphertext;
959 }
960
961 template <>
Decrypt(const LPPrivateKey<DCRTPoly> privateKey,ConstCiphertext<DCRTPoly> ciphertext,NativePoly * plaintext) const962 DecryptResult LPAlgorithmBFVrnsB<DCRTPoly>::Decrypt(
963 const LPPrivateKey<DCRTPoly> privateKey,
964 ConstCiphertext<DCRTPoly> ciphertext, NativePoly *plaintext) const {
965 // TimeVar t_total;
966
967 // TIC(t_total);
968
969 const auto cryptoParamsBFVrnsB =
970 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
971 privateKey->GetCryptoParameters());
972 const shared_ptr<ParmType> elementParams =
973 cryptoParamsBFVrnsB->GetElementParams();
974
975 const std::vector<DCRTPoly> &c = ciphertext->GetElements();
976
977 const DCRTPoly &s = privateKey->GetPrivateElement();
978 DCRTPoly sPower = s;
979
980 DCRTPoly b = c[0];
981 b.SetFormat(Format::EVALUATION);
982
983 DCRTPoly cTemp;
984 for (size_t i = 1; i <= ciphertext->GetDepth(); i++) {
985 cTemp = c[i];
986 cTemp.SetFormat(Format::EVALUATION);
987
988 b += sPower * cTemp;
989 sPower *= s;
990 }
991
992 // Converts back to Format::COEFFICIENT representation
993 b.SetFormat(Format::COEFFICIENT);
994
995 auto &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
996 auto &tgamma = cryptoParamsBFVrnsB->Gettgamma();
997 const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
998 const std::vector<NativeInteger> &tgammaQHatInvModq =
999 cryptoParamsBFVrnsB->GettgammaQHatInvModq();
1000 const std::vector<NativeInteger> &tgammaQHatInvModqPrecon =
1001 cryptoParamsBFVrnsB->GettgammaQHatInvModqPrecon();
1002 const std::vector<NativeInteger> &negInvqModtgamma =
1003 cryptoParamsBFVrnsB->GetNegInvqModtgamma();
1004 const std::vector<NativeInteger> &negInvqModtgammaPrecon =
1005 cryptoParamsBFVrnsB->GetNegInvqModtgammaPrecon();
1006
1007 // this is the resulting vector of coefficients;
1008 *plaintext = b.ScaleAndRound(moduliQ, t, tgamma, tgammaQHatInvModq,
1009 tgammaQHatInvModqPrecon, negInvqModtgamma,
1010 negInvqModtgammaPrecon);
1011
1012 // std::cout << "Decryption time (internal): " << TOC_US(t_total) << " us" <<
1013 // std::endl;
1014
1015 return DecryptResult(plaintext->GetLength());
1016 }
1017
1018 template <>
Encrypt(const LPPrivateKey<DCRTPoly> privateKey,DCRTPoly ptxt) const1019 Ciphertext<DCRTPoly> LPAlgorithmBFVrnsB<DCRTPoly>::Encrypt(
1020 const LPPrivateKey<DCRTPoly> privateKey, DCRTPoly ptxt) const {
1021 Ciphertext<DCRTPoly> ciphertext(
1022 std::make_shared<CiphertextImpl<DCRTPoly>>(privateKey));
1023
1024 const auto cryptoParams =
1025 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1026 privateKey->GetCryptoParameters());
1027
1028 const shared_ptr<ParmType> elementParams = cryptoParams->GetElementParams();
1029
1030 ptxt.SwitchFormat();
1031
1032 const DggType &dgg = cryptoParams->GetDiscreteGaussianGenerator();
1033 DugType dug;
1034
1035 const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1036
1037 DCRTPoly a(dug, elementParams, Format::EVALUATION);
1038 const DCRTPoly &s = privateKey->GetPrivateElement();
1039 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1040
1041 DCRTPoly c0(a * s + e + ptxt.Times(delta));
1042 DCRTPoly c1(elementParams, Format::EVALUATION, true);
1043 c1 -= a;
1044
1045 ciphertext->SetElements({std::move(c0), std::move(c1)});
1046
1047 return ciphertext;
1048 }
1049
1050 template <>
EvalAdd(ConstCiphertext<DCRTPoly> ciphertext,ConstPlaintext plaintext) const1051 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalAdd(
1052 ConstCiphertext<DCRTPoly> ciphertext, ConstPlaintext plaintext) const {
1053 Ciphertext<DCRTPoly> newCiphertext = ciphertext->CloneEmpty();
1054 newCiphertext->SetDepth(ciphertext->GetDepth());
1055
1056 const std::vector<DCRTPoly> &cipherTextElements = ciphertext->GetElements();
1057
1058 const DCRTPoly &ptElement = plaintext->GetElement<DCRTPoly>();
1059
1060 std::vector<DCRTPoly> c(cipherTextElements.size());
1061
1062 const auto cryptoParams =
1063 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1064 ciphertext->GetCryptoParameters());
1065
1066 const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1067
1068 c[0] = cipherTextElements[0] + ptElement.Times(delta);
1069
1070 for (size_t i = 1; i < cipherTextElements.size(); i++) {
1071 c[i] = cipherTextElements[i];
1072 }
1073
1074 newCiphertext->SetElements(std::move(c));
1075
1076 return newCiphertext;
1077 }
1078
1079 template <>
EvalSub(ConstCiphertext<DCRTPoly> ciphertext,ConstPlaintext plaintext) const1080 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalSub(
1081 ConstCiphertext<DCRTPoly> ciphertext, ConstPlaintext plaintext) const {
1082 Ciphertext<DCRTPoly> newCiphertext = ciphertext->CloneEmpty();
1083 newCiphertext->SetDepth(ciphertext->GetDepth());
1084
1085 const std::vector<DCRTPoly> &cipherTextElements = ciphertext->GetElements();
1086
1087 plaintext->SetFormat(Format::EVALUATION);
1088 const DCRTPoly &ptElement = plaintext->GetElement<DCRTPoly>();
1089
1090 std::vector<DCRTPoly> c(cipherTextElements.size());
1091
1092 const auto cryptoParams =
1093 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1094 ciphertext->GetCryptoParameters());
1095
1096 const std::vector<NativeInteger> &delta = cryptoParams->GetDelta();
1097
1098 c[0] = cipherTextElements[0] - ptElement.Times(delta);
1099
1100 for (size_t i = 1; i < cipherTextElements.size(); i++) {
1101 c[i] = cipherTextElements[i];
1102 }
1103
1104 newCiphertext->SetElements(std::move(c));
1105
1106 return newCiphertext;
1107 }
1108
1109 template <>
EvalMult(ConstCiphertext<DCRTPoly> ciphertext1,ConstCiphertext<DCRTPoly> ciphertext2) const1110 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalMult(
1111 ConstCiphertext<DCRTPoly> ciphertext1,
1112 ConstCiphertext<DCRTPoly> ciphertext2) const {
1113 if (!(ciphertext1->GetCryptoParameters() ==
1114 ciphertext2->GetCryptoParameters())) {
1115 std::string errMsg =
1116 "LPAlgorithmSHEBFVrnsB::EvalMult crypto parameters are not the same";
1117 PALISADE_THROW(config_error, errMsg);
1118 }
1119
1120 Ciphertext<DCRTPoly> newCiphertext = ciphertext1->CloneEmpty();
1121
1122 const auto cryptoParamsBFVrnsB =
1123 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1124 ciphertext1->GetCryptoContext()->GetCryptoParameters());
1125
1126 // Get the ciphertext elements
1127 std::vector<DCRTPoly> cipherText1Elements = ciphertext1->GetElements();
1128 std::vector<DCRTPoly> cipherText2Elements = ciphertext2->GetElements();
1129
1130 size_t cipherText1ElementsSize = cipherText1Elements.size();
1131 size_t cipherText2ElementsSize = cipherText2Elements.size();
1132 size_t cipherTextRElementsSize =
1133 cipherText1ElementsSize + cipherText2ElementsSize - 1;
1134
1135 std::vector<DCRTPoly> c(cipherTextRElementsSize);
1136
1137 const shared_ptr<ParmType> elementParams =
1138 cryptoParamsBFVrnsB->GetElementParams();
1139 const shared_ptr<ILDCRTParams<BigInteger>> paramsBsk =
1140 cryptoParamsBFVrnsB->GetParamsBsk();
1141 const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
1142 const std::vector<DoubleNativeInt> &modqBarrettMu =
1143 cryptoParamsBFVrnsB->GetModqBarrettMu();
1144 const std::vector<NativeInteger> &moduliBsk =
1145 cryptoParamsBFVrnsB->GetModuliBsk();
1146 const std::vector<DoubleNativeInt> &modbskBarrettMu =
1147 cryptoParamsBFVrnsB->GetModbskBarrettMu();
1148 const std::vector<NativeInteger> &mtildeQHatInvModq =
1149 cryptoParamsBFVrnsB->GetmtildeQHatInvModq();
1150 const std::vector<NativeInteger> &mtildeQHatInvModqPrecon =
1151 cryptoParamsBFVrnsB->GetmtildeQHatInvModqPrecon();
1152 const std::vector<std::vector<NativeInteger>> &QHatModbsk =
1153 cryptoParamsBFVrnsB->GetQHatModbsk();
1154 const std::vector<uint16_t> &QHatModmtilde =
1155 cryptoParamsBFVrnsB->GetQHatModmtilde();
1156 const std::vector<NativeInteger> &QModbsk = cryptoParamsBFVrnsB->GetQModbsk();
1157 const std::vector<NativeInteger> &QModbskPrecon =
1158 cryptoParamsBFVrnsB->GetQModbskPrecon();
1159 const uint16_t &negQInvModmtilde = cryptoParamsBFVrnsB->GetNegQInvModmtilde();
1160 // const NativeInteger &negQInvModmtildePrecon =
1161 // cryptoParamsBFVrnsB->GetNegQInvModmtildePrecon();
1162 const std::vector<NativeInteger> &mtildeInvModbsk =
1163 cryptoParamsBFVrnsB->GetmtildeInvModbsk();
1164 const std::vector<NativeInteger> &mtildeInvModbskPrecon =
1165 cryptoParamsBFVrnsB->GetmtildeInvModbskPrecon();
1166
1167 // Expands the CRT basis to q*Bsk; Outputs the polynomials in coeff
1168 // representation
1169
1170 for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1171 cipherText1Elements[i].FastBaseConvqToBskMontgomery(
1172 paramsBsk, moduliQ, moduliBsk, modbskBarrettMu, mtildeQHatInvModq,
1173 mtildeQHatInvModqPrecon, QHatModbsk, QHatModmtilde, QModbsk,
1174 QModbskPrecon, negQInvModmtilde, mtildeInvModbsk,
1175 mtildeInvModbskPrecon);
1176
1177 cipherText1Elements[i].SetFormat(Format::EVALUATION);
1178 }
1179
1180 for (size_t i = 0; i < cipherText2ElementsSize; i++) {
1181 cipherText2Elements[i].FastBaseConvqToBskMontgomery(
1182 paramsBsk, moduliQ, moduliBsk, modbskBarrettMu, mtildeQHatInvModq,
1183 mtildeQHatInvModqPrecon, QHatModbsk, QHatModmtilde, QModbsk,
1184 QModbskPrecon, negQInvModmtilde, mtildeInvModbsk,
1185 mtildeInvModbskPrecon);
1186
1187 cipherText2Elements[i].SetFormat(Format::EVALUATION);
1188 }
1189
1190 // Performs the multiplication itself
1191
1192 #ifdef USE_KARATSUBA
1193
1194 if (cipherText1ElementsSize == 2 && cipherText2ElementsSize == 2) {
1195 // size of each ciphertxt = 2, use Karatsuba
1196 c[0] = cipherText1Elements[0] * cipherText2Elements[0]; // a
1197 c[2] = cipherText1Elements[1] * cipherText2Elements[1]; // b
1198
1199 c[1] = cipherText1Elements[0] + cipherText1Elements[1];
1200 c[1] *= (cipherText2Elements[0] + cipherText2Elements[1]);
1201 c[1] -= c[2];
1202 c[1] -= c[0];
1203
1204 } else { // if size of any of the ciphertexts > 2
1205 bool *isFirstAdd = new bool[cipherTextRElementsSize];
1206 std::fill_n(isFirstAdd, cipherTextRElementsSize, true);
1207
1208 for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1209 for (size_t j = 0; j < cipherText2ElementsSize; j++) {
1210 if (isFirstAdd[i + j] == true) {
1211 c[i + j] = cipherText1Elements[i] * cipherText2Elements[j];
1212 isFirstAdd[i + j] = false;
1213 } else {
1214 c[i + j] += cipherText1Elements[i] * cipherText2Elements[j];
1215 }
1216 }
1217 }
1218
1219 delete[] isFirstAdd;
1220 }
1221
1222 #else
1223 bool *isFirstAdd = new bool[cipherTextRElementsSize];
1224 std::fill_n(isFirstAdd, cipherTextRElementsSize, true);
1225
1226 for (size_t i = 0; i < cipherText1ElementsSize; i++) {
1227 for (size_t j = 0; j < cipherText2ElementsSize; j++) {
1228 if (isFirstAdd[i + j] == true) {
1229 c[i + j] = cipherText1Elements[i] * cipherText2Elements[j];
1230 isFirstAdd[i + j] = false;
1231 } else {
1232 c[i + j] += cipherText1Elements[i] * cipherText2Elements[j];
1233 }
1234 }
1235 }
1236
1237 delete[] isFirstAdd;
1238 #endif
1239
1240 // perfrom RNS approximate Flooring
1241 const NativeInteger &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
1242 const std::vector<NativeInteger> &tQHatInvModq =
1243 cryptoParamsBFVrnsB->GettQHatInvModq();
1244 const std::vector<NativeInteger> &tQHatInvModqPrecon =
1245 cryptoParamsBFVrnsB->GettQHatInvModqPrecon();
1246 const std::vector<std::vector<NativeInteger>> &qInvModbsk =
1247 cryptoParamsBFVrnsB->GetqInvModbsk();
1248 const std::vector<NativeInteger> &tQInvModbsk =
1249 cryptoParamsBFVrnsB->GettQInvModbsk();
1250 const std::vector<NativeInteger> &tQInvModbskPrecon =
1251 cryptoParamsBFVrnsB->GettQInvModbskPrecon();
1252
1253 // perform FastBaseConvSK
1254 const std::vector<NativeInteger> &BHatInvModb =
1255 cryptoParamsBFVrnsB->GetBHatInvModb();
1256 const std::vector<NativeInteger> &BHatInvModbPrecon =
1257 cryptoParamsBFVrnsB->GetBHatInvModbPrecon();
1258 const std::vector<NativeInteger> &BHatModmsk =
1259 cryptoParamsBFVrnsB->GetBHatModmsk();
1260 const NativeInteger &BInvModmsk = cryptoParamsBFVrnsB->GetBInvModmsk();
1261 const NativeInteger &BInvModmskPrecon =
1262 cryptoParamsBFVrnsB->GetBInvModmskPrecon();
1263 const std::vector<std::vector<NativeInteger>> &BHatModq =
1264 cryptoParamsBFVrnsB->GetBHatModq();
1265 const std::vector<NativeInteger> &BModq = cryptoParamsBFVrnsB->GetBModq();
1266 const std::vector<NativeInteger> &BModqPrecon =
1267 cryptoParamsBFVrnsB->GetBModqPrecon();
1268
1269 for (size_t i = 0; i < cipherTextRElementsSize; i++) {
1270 // converts to Format::COEFFICIENT representation before rounding
1271 c[i].SetFormat(Format::COEFFICIENT);
1272 // Performs the scaling by t/Q followed by rounding; the result is in the
1273 // CRT basis {Bsk}
1274 c[i].FastRNSFloorq(t, moduliQ, moduliBsk, modbskBarrettMu, tQHatInvModq,
1275 tQHatInvModqPrecon, QHatModbsk, qInvModbsk, tQInvModbsk,
1276 tQInvModbskPrecon);
1277
1278 // Converts from the CRT basis {Bsk} to {Q}
1279 c[i].FastBaseConvSK(moduliQ, modqBarrettMu, moduliBsk, modbskBarrettMu,
1280 BHatInvModb, BHatInvModbPrecon, BHatModmsk, BInvModmsk,
1281 BInvModmskPrecon, BHatModq, BModq, BModqPrecon);
1282 }
1283
1284 newCiphertext->SetElements(std::move(c));
1285 newCiphertext->SetDepth((ciphertext1->GetDepth() + ciphertext2->GetDepth()));
1286
1287 return newCiphertext;
1288 }
1289
1290 template <>
KeySwitchGen(const LPPrivateKey<DCRTPoly> originalPrivateKey,const LPPrivateKey<DCRTPoly> newPrivateKey) const1291 LPEvalKey<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::KeySwitchGen(
1292 const LPPrivateKey<DCRTPoly> originalPrivateKey,
1293 const LPPrivateKey<DCRTPoly> newPrivateKey) const {
1294 LPEvalKey<DCRTPoly> ek(std::make_shared<LPEvalKeyRelinImpl<DCRTPoly>>(
1295 newPrivateKey->GetCryptoContext()));
1296
1297 const auto cryptoParamsLWE =
1298 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1299 newPrivateKey->GetCryptoParameters());
1300 const shared_ptr<ParmType> elementParams =
1301 cryptoParamsLWE->GetElementParams();
1302 const DCRTPoly &s = newPrivateKey->GetPrivateElement();
1303
1304 const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1305 DugType dug;
1306
1307 const DCRTPoly &oldKey = originalPrivateKey->GetPrivateElement();
1308
1309 std::vector<DCRTPoly> evalKeyElements;
1310 std::vector<DCRTPoly> evalKeyElementsGenerated;
1311
1312 uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1313
1314 for (usint i = 0; i < oldKey.GetNumOfElements(); i++) {
1315 if (relinWindow > 0) {
1316 vector<typename DCRTPoly::PolyType> decomposedKeyElements =
1317 oldKey.GetElementAtIndex(i).PowersOfBase(relinWindow);
1318
1319 for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1320 // Creates an element with all zeroes
1321 DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1322
1323 filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1324
1325 // Generate a_i vectors
1326 DCRTPoly a(dug, elementParams, Format::EVALUATION);
1327 evalKeyElementsGenerated.push_back(a);
1328
1329 // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1330 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1331 evalKeyElements.push_back(filtered - (a * s + e));
1332 }
1333 } else {
1334 // Creates an element with all zeroes
1335 DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1336
1337 filtered.SetElementAtIndex(i, oldKey.GetElementAtIndex(i));
1338
1339 // Generate a_i vectors
1340 DCRTPoly a(dug, elementParams, Format::EVALUATION);
1341 evalKeyElementsGenerated.push_back(a);
1342
1343 // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1344 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1345 evalKeyElements.push_back(filtered - (a * s + e));
1346 }
1347 }
1348
1349 ek->SetAVector(std::move(evalKeyElements));
1350 ek->SetBVector(std::move(evalKeyElementsGenerated));
1351
1352 return ek;
1353 }
1354
1355 template <>
MultiKeySwitchGen(const LPPrivateKey<DCRTPoly> originalPrivateKey,const LPPrivateKey<DCRTPoly> newPrivateKey,const LPEvalKey<DCRTPoly> ek) const1356 LPEvalKey<DCRTPoly> LPAlgorithmMultipartyBFVrnsB<DCRTPoly>::MultiKeySwitchGen(
1357 const LPPrivateKey<DCRTPoly> originalPrivateKey,
1358 const LPPrivateKey<DCRTPoly> newPrivateKey,
1359 const LPEvalKey<DCRTPoly> ek) const {
1360 LPEvalKeyRelin<DCRTPoly> keySwitchHintRelin(
1361 new LPEvalKeyRelinImpl<DCRTPoly>(newPrivateKey->GetCryptoContext()));
1362
1363 const shared_ptr<LPCryptoParametersRLWE<DCRTPoly>> cryptoParamsLWE =
1364 std::dynamic_pointer_cast<LPCryptoParametersRLWE<DCRTPoly>>(
1365 newPrivateKey->GetCryptoParameters());
1366 const shared_ptr<ParmType> elementParams =
1367 cryptoParamsLWE->GetElementParams();
1368
1369 // Getting a reference to the polynomials of new private key.
1370 const DCRTPoly &sNew = newPrivateKey->GetPrivateElement();
1371
1372 // Getting a reference to the polynomials of original private key.
1373 const DCRTPoly &s = originalPrivateKey->GetPrivateElement();
1374
1375 const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1376 DugType dug;
1377
1378 std::vector<DCRTPoly> evalKeyElements;
1379 std::vector<DCRTPoly> evalKeyElementsGenerated;
1380
1381 uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1382
1383 const std::vector<DCRTPoly> &a = ek->GetBVector();
1384
1385 for (usint i = 0; i < s.GetNumOfElements(); i++) {
1386 if (relinWindow > 0) {
1387 vector<typename DCRTPoly::PolyType> decomposedKeyElements =
1388 s.GetElementAtIndex(i).PowersOfBase(relinWindow);
1389
1390 for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1391 // Creates an element with all zeroes
1392 DCRTPoly filtered(elementParams, EVALUATION, true);
1393
1394 filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1395
1396 // Generate a_i vectors
1397 evalKeyElementsGenerated.push_back(
1398 a[i * decomposedKeyElements.size() + k]);
1399
1400 // Generate a_i * s + e - [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi)
1401 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1402 evalKeyElements.push_back(
1403 filtered - (a[i * decomposedKeyElements.size() + k] * sNew + e));
1404 }
1405 } else {
1406 // Creates an element with all zeroes
1407 DCRTPoly filtered(elementParams, EVALUATION, true);
1408
1409 filtered.SetElementAtIndex(i, s.GetElementAtIndex(i));
1410
1411 // Generate a_i vectors
1412 evalKeyElementsGenerated.push_back(a[i]);
1413
1414 // Generate [oldKey]_qi [(q/qi)^{-1}]_qi (q/qi) - (a_i * s + e)
1415 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1416 evalKeyElements.push_back(filtered - (a[i] * sNew + e));
1417 }
1418 }
1419
1420 keySwitchHintRelin->SetAVector(std::move(evalKeyElements));
1421 keySwitchHintRelin->SetBVector(std::move(evalKeyElementsGenerated));
1422
1423 return keySwitchHintRelin;
1424 }
1425
1426 template <>
KeySwitchInPlace(const LPEvalKey<DCRTPoly> ek,Ciphertext<DCRTPoly> & cipherText) const1427 void LPAlgorithmSHEBFVrnsB<DCRTPoly>::KeySwitchInPlace(
1428 const LPEvalKey<DCRTPoly> ek, Ciphertext<DCRTPoly>& cipherText) const {
1429
1430 const auto cryptoParamsLWE =
1431 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1432 ek->GetCryptoParameters());
1433
1434 LPEvalKeyRelin<DCRTPoly> evalKey =
1435 std::static_pointer_cast<LPEvalKeyRelinImpl<DCRTPoly>>(ek);
1436
1437 std::vector<DCRTPoly> &c = cipherText->GetElements();
1438
1439 const std::vector<DCRTPoly> &b = evalKey->GetAVector();
1440 const std::vector<DCRTPoly> &a = evalKey->GetBVector();
1441
1442 uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1443
1444 std::vector<DCRTPoly> digitsC2;
1445
1446
1447 // in the case of EvalMult, c[0] is initially in Format::COEFFICIENT format
1448 // and needs to be switched to Format::EVALUATION format
1449 if (c.size() > 2) c[0].SetFormat(Format::EVALUATION);
1450
1451 if (c.size() == 2) { // case of automorphism
1452 digitsC2 = c[1].CRTDecompose(relinWindow);
1453 c[1] = digitsC2[0] * a[0];
1454 } else { // case of EvalMult
1455 digitsC2 = c[2].CRTDecompose(relinWindow);
1456 c[1].SetFormat(Format::EVALUATION);
1457 c[1] += digitsC2[0] * a[0];
1458 }
1459
1460 c[0] += digitsC2[0] * b[0];
1461
1462 for (usint i = 1; i < digitsC2.size(); ++i) {
1463 c[0] += digitsC2[i] * b[i];
1464 c[1] += digitsC2[i] * a[i];
1465 }
1466
1467 Ciphertext<DCRTPoly> newCiphertext = cipherText->CloneEmpty();
1468 newCiphertext->SetElements({std::move(c[0]), std::move(c[1])});
1469 cipherText = std::move(newCiphertext);
1470 }
1471
1472 template <>
EvalMultAndRelinearize(ConstCiphertext<DCRTPoly> ciphertext1,ConstCiphertext<DCRTPoly> ciphertext2,const vector<LPEvalKey<DCRTPoly>> & ek) const1473 Ciphertext<DCRTPoly> LPAlgorithmSHEBFVrnsB<DCRTPoly>::EvalMultAndRelinearize(
1474 ConstCiphertext<DCRTPoly> ciphertext1,
1475 ConstCiphertext<DCRTPoly> ciphertext2,
1476 const vector<LPEvalKey<DCRTPoly>> &ek) const {
1477 Ciphertext<DCRTPoly> cipherText = this->EvalMult(ciphertext1, ciphertext2);
1478
1479 const auto cryptoParamsLWE =
1480 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1481 ek[0]->GetCryptoParameters());
1482
1483 Ciphertext<DCRTPoly> newCiphertext = cipherText->CloneEmpty();
1484
1485 std::vector<DCRTPoly> c = cipherText->GetElements();
1486 for (size_t i = 0; i < c.size(); i++) c[i].SetFormat(Format::EVALUATION);
1487
1488 DCRTPoly ct0(c[0]);
1489 DCRTPoly ct1(c[1]);
1490 // Perform a keyswitching operation to result of the multiplication. It does
1491 // it until it reaches to 2 elements.
1492 // TODO: Maybe we can change the number of keyswitching and terminate early.
1493 // For instance; perform keyswitching until 4 elements left.
1494 for (size_t j = 0; j <= cipherText->GetDepth() - 2; j++) {
1495 size_t index = cipherText->GetDepth() - 2 - j;
1496 LPEvalKeyRelin<DCRTPoly> evalKey =
1497 std::static_pointer_cast<LPEvalKeyRelinImpl<DCRTPoly>>(ek[index]);
1498
1499 const std::vector<DCRTPoly> &b = evalKey->GetAVector();
1500 const std::vector<DCRTPoly> &a = evalKey->GetBVector();
1501
1502 std::vector<DCRTPoly> digitsC2 = c[index + 2].CRTDecompose();
1503
1504 for (usint i = 0; i < digitsC2.size(); ++i) {
1505 ct0 += digitsC2[i] * b[i];
1506 ct1 += digitsC2[i] * a[i];
1507 }
1508 }
1509
1510 newCiphertext->SetElements({std::move(ct0), std::move(ct1)});
1511
1512 return newCiphertext;
1513 }
1514
1515 template <>
ReKeyGen(const LPPublicKey<DCRTPoly> newPK,const LPPrivateKey<DCRTPoly> origPrivateKey) const1516 LPEvalKey<DCRTPoly> LPAlgorithmPREBFVrnsB<DCRTPoly>::ReKeyGen(
1517 const LPPublicKey<DCRTPoly> newPK,
1518 const LPPrivateKey<DCRTPoly> origPrivateKey) const {
1519 // Get crypto context of new public key.
1520 auto cc = newPK->GetCryptoContext();
1521
1522 // Create an Format::EVALUATION key that will contain all the re-encryption
1523 // key elements.
1524 LPEvalKeyRelin<DCRTPoly> ek(
1525 std::make_shared<LPEvalKeyRelinImpl<DCRTPoly>>(cc));
1526
1527 const auto cryptoParamsLWE =
1528 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1529 newPK->GetCryptoParameters());
1530 const shared_ptr<DCRTPoly::Params> elementParams =
1531 cryptoParamsLWE->GetElementParams();
1532
1533 const DCRTPoly::DggType &dgg =
1534 cryptoParamsLWE->GetDiscreteGaussianGenerator();
1535 DCRTPoly::DugType dug;
1536 DCRTPoly::TugType tug;
1537
1538 const DCRTPoly &oldKey = origPrivateKey->GetPrivateElement();
1539
1540 std::vector<DCRTPoly> evalKeyElements;
1541 std::vector<DCRTPoly> evalKeyElementsGenerated;
1542
1543 uint32_t relinWindow = cryptoParamsLWE->GetRelinWindow();
1544
1545 const DCRTPoly &p0 = newPK->GetPublicElements().at(0);
1546 const DCRTPoly &p1 = newPK->GetPublicElements().at(1);
1547
1548 for (usint i = 0; i < oldKey.GetNumOfElements(); i++) {
1549 if (relinWindow > 0) {
1550 vector<DCRTPoly::PolyType> decomposedKeyElements =
1551 oldKey.GetElementAtIndex(i).PowersOfBase(relinWindow);
1552
1553 for (size_t k = 0; k < decomposedKeyElements.size(); k++) {
1554 // Creates an element with all zeroes
1555 DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1556
1557 filtered.SetElementAtIndex(i, decomposedKeyElements[k]);
1558
1559 DCRTPoly u;
1560
1561 if (cryptoParamsLWE->GetMode() == RLWE)
1562 u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1563 else
1564 u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1565
1566 DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1567 DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1568
1569 DCRTPoly c0(elementParams);
1570 DCRTPoly c1(elementParams);
1571
1572 c0 = p0 * u + e1 + filtered;
1573
1574 c1 = p1 * u + e2;
1575
1576 DCRTPoly a(dug, elementParams, Format::EVALUATION);
1577 evalKeyElementsGenerated.push_back(c1);
1578
1579 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1580 evalKeyElements.push_back(c0);
1581 }
1582 } else {
1583 // Creates an element with all zeroes
1584 DCRTPoly filtered(elementParams, Format::EVALUATION, true);
1585
1586 filtered.SetElementAtIndex(i, oldKey.GetElementAtIndex(i));
1587
1588 DCRTPoly u;
1589
1590 if (cryptoParamsLWE->GetMode() == RLWE)
1591 u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1592 else
1593 u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1594
1595 DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1596 DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1597
1598 DCRTPoly c0(elementParams);
1599 DCRTPoly c1(elementParams);
1600
1601 c0 = p0 * u + e1 + filtered;
1602
1603 c1 = p1 * u + e2;
1604
1605 DCRTPoly a(dug, elementParams, Format::EVALUATION);
1606 evalKeyElementsGenerated.push_back(c1);
1607
1608 DCRTPoly e(dgg, elementParams, Format::EVALUATION);
1609 evalKeyElements.push_back(c0);
1610 }
1611 }
1612
1613 ek->SetAVector(std::move(evalKeyElements));
1614 ek->SetBVector(std::move(evalKeyElementsGenerated));
1615
1616 return ek;
1617 }
1618
1619 template <>
ReEncrypt(const LPEvalKey<DCRTPoly> ek,ConstCiphertext<DCRTPoly> ciphertext,const LPPublicKey<DCRTPoly> publicKey) const1620 Ciphertext<DCRTPoly> LPAlgorithmPREBFVrnsB<DCRTPoly>::ReEncrypt(
1621 const LPEvalKey<DCRTPoly> ek, ConstCiphertext<DCRTPoly> ciphertext,
1622 const LPPublicKey<DCRTPoly> publicKey) const {
1623 if (publicKey == nullptr) { // Sender PK is not provided - CPA-secure PRE
1624 return ciphertext->GetCryptoContext()->KeySwitch(ek, ciphertext);
1625 }
1626
1627 // Sender PK provided - HRA-secure PRE
1628 const auto cryptoParamsLWE =
1629 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1630 ek->GetCryptoParameters());
1631
1632 // Get crypto and elements parameters
1633 const shared_ptr<ParmType> elementParams =
1634 cryptoParamsLWE->GetElementParams();
1635
1636 const DggType &dgg = cryptoParamsLWE->GetDiscreteGaussianGenerator();
1637 TugType tug;
1638
1639 PlaintextEncodings encType = ciphertext->GetEncodingType();
1640
1641 Ciphertext<DCRTPoly> zeroCiphertext(
1642 std::make_shared<CiphertextImpl<DCRTPoly>>(publicKey));
1643 zeroCiphertext->SetEncodingType(encType);
1644
1645 const DCRTPoly &p0 = publicKey->GetPublicElements().at(0);
1646 const DCRTPoly &p1 = publicKey->GetPublicElements().at(1);
1647
1648 DCRTPoly u;
1649
1650 if (cryptoParamsLWE->GetMode() == RLWE)
1651 u = DCRTPoly(dgg, elementParams, Format::EVALUATION);
1652 else
1653 u = DCRTPoly(tug, elementParams, Format::EVALUATION);
1654
1655 DCRTPoly e1(dgg, elementParams, Format::EVALUATION);
1656 DCRTPoly e2(dgg, elementParams, Format::EVALUATION);
1657
1658 DCRTPoly c0 = p0 * u + e1;
1659 DCRTPoly c1 = p1 * u + e2;
1660
1661 zeroCiphertext->SetElements({std::move(c0), std::move(c1)});
1662
1663 // Add the encryption of zero for re-randomization purposes
1664 auto c = ciphertext->GetCryptoContext()->GetEncryptionAlgorithm()->EvalAdd(
1665 ciphertext, zeroCiphertext);
1666
1667 ciphertext->GetCryptoContext()->KeySwitchInPlace(ek, c);
1668 return c;
1669 }
1670
1671 template <>
MultipartyDecryptFusion(const vector<Ciphertext<DCRTPoly>> & ciphertextVec,NativePoly * plaintext) const1672 DecryptResult LPAlgorithmMultipartyBFVrnsB<DCRTPoly>::MultipartyDecryptFusion(
1673 const vector<Ciphertext<DCRTPoly>> &ciphertextVec,
1674 NativePoly *plaintext) const {
1675 const auto cryptoParamsBFVrnsB =
1676 std::static_pointer_cast<LPCryptoParametersBFVrnsB<DCRTPoly>>(
1677 ciphertextVec[0]->GetCryptoParameters());
1678 const shared_ptr<ParmType> elementParams =
1679 cryptoParamsBFVrnsB->GetElementParams();
1680
1681 const std::vector<DCRTPoly> &cElem = ciphertextVec[0]->GetElements();
1682 DCRTPoly b = cElem[0];
1683
1684 size_t numCipher = ciphertextVec.size();
1685 for (size_t i = 1; i < numCipher; i++) {
1686 const std::vector<DCRTPoly> &c2 = ciphertextVec[i]->GetElements();
1687 b += c2[0];
1688 }
1689
1690 auto &t = cryptoParamsBFVrnsB->GetPlaintextModulus();
1691 auto &tgamma = cryptoParamsBFVrnsB->Gettgamma();
1692
1693 // Invoke BFVrnsB DecRNS
1694
1695 const std::vector<NativeInteger> &moduliQ = cryptoParamsBFVrnsB->GetModuliQ();
1696 const std::vector<NativeInteger> &tgammaQHatInvModq =
1697 cryptoParamsBFVrnsB->GettgammaQHatInvModq();
1698 const std::vector<NativeInteger> &tgammaQHatInvModqPrecon =
1699 cryptoParamsBFVrnsB->GettgammaQHatInvModqPrecon();
1700 const std::vector<NativeInteger> &negInvqModtgamma =
1701 cryptoParamsBFVrnsB->GetNegInvqModtgamma();
1702 const std::vector<NativeInteger> &negInvqModtgammaPrecon =
1703 cryptoParamsBFVrnsB->GetNegInvqModtgammaPrecon();
1704
1705 // this is the resulting vector of coefficients;
1706 *plaintext = b.ScaleAndRound(moduliQ, t, tgamma, tgammaQHatInvModq,
1707 tgammaQHatInvModqPrecon, negInvqModtgamma,
1708 negInvqModtgammaPrecon);
1709
1710 return DecryptResult(plaintext->GetLength());
1711 }
1712
// Explicit instantiation definitions for the BFVrnsB class templates over
// DCRTPoly. They make this translation unit emit the member definitions
// (including the explicit specializations above) so other translation units
// can link against them.
template class LPCryptoParametersBFVrnsB<DCRTPoly>;
template class LPPublicKeyEncryptionSchemeBFVrnsB<DCRTPoly>;
template class LPAlgorithmBFVrnsB<DCRTPoly>;
template class LPAlgorithmSHEBFVrnsB<DCRTPoly>;
template class LPAlgorithmMultipartyBFVrnsB<DCRTPoly>;
template class LPAlgorithmParamsGenBFVrnsB<DCRTPoly>;
1719
1720 } // namespace lbcrypto
1721