// @file trapdoor.h Provides the utility for sampling trapdoor lattices as
// described in https://eprint.iacr.org/2017/844.pdf,
// https://eprint.iacr.org/2018/946, and "Implementing Token-Based Obfuscation
// under (Ring) LWE" (https://eprint.iacr.org/2018/1222.pdf).
5 // @author TPOC: contact@palisade-crypto.org
6 //
7 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
8 // All rights reserved.
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 // 1. Redistributions of source code must retain the above copyright notice,
12 // this list of conditions and the following disclaimer.
13 // 2. Redistributions in binary form must reproduce the above copyright notice,
14 // this list of conditions and the following disclaimer in the documentation
15 // and/or other materials provided with the distribution. THIS SOFTWARE IS
16 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
17 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
20 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27 #ifndef LBCRYPTO_LATTICE_TRAPDOOR_H
28 #define LBCRYPTO_LATTICE_TRAPDOOR_H
29 
30 #include <memory>
31 #include <utility>
32 
33 #include "math/matrix.h"
34 
35 #include "lattice/dgsampling.h"
36 
37 namespace lbcrypto {
38 
39 /**
40  * @brief Class to store a lattice trapdoor pair generated using construction 1
41  * in section 3.2 of https://eprint.iacr.org/2013/297.pdf This construction is
42  * based on the hardness of Ring-LWE problem
43  */
44 template <class Element>
45 class RLWETrapdoorPair {
46  public:
Field2n()47   // matrix of noise polynomials
48   Matrix<Element> m_r;
Field2n(Format f)49   // matrix
50   Matrix<Element> m_e;
51   // CTOR with empty trapdoor pair for deserialization
52   RLWETrapdoorPair()
53       : m_r(Matrix<Element>([]() { return Element(); }, 0, 0)),
54         m_e(Matrix<Element>([]() { return Element(); }, 0, 0)) {}
55 
56   RLWETrapdoorPair(const Matrix<Element> &r, const Matrix<Element> &e)
57       : m_r(r), m_e(e) {}
58 
59   template <class Archive>
60   void save(Archive &ar, std::uint32_t const version) const {
61     ar(CEREAL_NVP(m_r));
62     ar(CEREAL_NVP(m_e));
63   }
64 
65   template <class Archive>
66   void load(Archive &ar, std::uint32_t const version) {
67     ar(CEREAL_NVP(m_r));
68     ar(CEREAL_NVP(m_e));
69   }
70 };
71 
72 /**
73  * @brief Static class implementing lattice trapdoor construction in Algorithm 1
74  * of https://eprint.iacr.org/2017/844.pdf
75  */
76 template <class Element>
77 class RLWETrapdoorUtility {
78   using ParmType = typename Element::Params;
79   using DggType = typename Element::DggType;
80   using IntType = typename Element::Integer;
81 
82  public:
83   /**
84    * Trapdoor generation method as described in Algorithm 1 of
85    * https://eprint.iacr.org/2017/844.pdf
86    *
87    * @param params ring element parameters
88    * @param sttdev distribution parameter used in sampling noise polynomials
89    * of the trapdoor
90    * @param base base of gadget matrix
91    * @param bal flag for balanced (true) versus not-balanced (false) digit
92    * representation
93    * @return the trapdoor pair including the public key (matrix of rings)
94    * and trapdoor itself
95    */
GetFormat()96   static std::pair<Matrix<Element>, RLWETrapdoorPair<Element>> TrapdoorGen(
97       shared_ptr<ParmType> params, double stddev, int64_t base = 2,
98       bool bal = false);
99 
100   /**
101    * Generalized trapdoor generation method (described in "Implementing
102    * Token-Based Obfuscation under (Ring) LWE")
103    *
104    * @param params ring element parameters
105    * @param sttdev distribution parameter used in sampling noise polynomials of
106    * the trapdoor
107    * @param dimension of square matrix
108    * @param base base of gadget matrix
109    * @param bal flag for balanced (true) versus not-balanced (false) digit
110    * representation
111    * @return the trapdoor pair including the public key (matrix of rings) and
112    * trapdoor itself
113    */
114   static std::pair<Matrix<Element>, RLWETrapdoorPair<Element>>
115   TrapdoorGenSquareMat(shared_ptr<ParmType> params, double stddev,
116                        size_t dimension, int64_t base = 2, bool bal = false);
117 
118   /**
119    * Gaussian sampling as described in Alogorithm 2 of
120    * https://eprint.iacr.org/2017/844.pdf
121    *
122    * @param n ring dimension
123    * @param k matrix sample dimension; k = log2(q)/log2(base) + 2
124    * @param &A public key of the trapdoor pair
125    * @param &T trapdoor itself
126    * @param &u syndrome vector where gaussian that Gaussian sampling is centered
127    * around
128    * @param &dgg discrete Gaussian generator for integers
129    * @param &dggLargeSigma discrete Gaussian generator for perturbation vector
130    * sampling (only used in Peikert's method)
131    * @param base base of gadget matrix
132    * @return the sampled vector (matrix)
133    */
134   static Matrix<Element> GaussSamp(size_t n, size_t k, const Matrix<Element> &A,
135                                    const RLWETrapdoorPair<Element> &T,
136                                    const Element &u, DggType &dgg,
137                                    DggType &dggLargeSigma, int64_t base = 2);
138 
139   /**
140    * Gaussian sampling (described in "Implementing Token-Based Obfuscation under
141    * (Ring) LWE")
142    *
143    * @param n ring dimension
144    * @param k matrix sample dimension; k = log2(q)/log2(base) + 2
145    * @param &A public key of the trapdoor pair
146    * @param &T trapdoor itself
147    * @param &U syndrome matrix that Gaussian sampling is centered around
148    * @param &dgg discrete Gaussian generator for integers
149    * @param &dggLargeSigma discrete Gaussian generator for perturbation vector
150    * sampling (only used in Peikert's method)
151    * @param base base of gadget matrix
152    * @return the sampled vector (matrix)
153    */
154   static Matrix<Element> GaussSampSquareMat(
155       size_t n, size_t k, const Matrix<Element> &A,
156       const RLWETrapdoorPair<Element> &T, const Matrix<Element> &U,
157       DggType &dgg, DggType &dggLargeSigma, int64_t base = 2);
158 
159   /**
160    * On-line stage of pre-image sampling (includes only G-sampling)
161    *
162    * @param n ring dimension
163    * @param k matrix sample dimension; k = log2(q)/log2(base) + 2
164    * @param &A public key of the trapdoor pair
165    * @param &T trapdoor itself
166    * @param &u syndrome vector where gaussian that Gaussian sampling is centered
167    * around
168    * @param &dgg discrete Gaussian generator for integers
169    * @param &perturbationVector perturbation vector generated during the offline
170    * stage
171    * @param &base base for G-lattice
172    * @return the sampled vector (matrix)
173    */
174   static Matrix<Element> GaussSampOnline(
175       size_t n, size_t k, const Matrix<Element> &A,
176       const RLWETrapdoorPair<Element> &T, const Element &u, DggType &dgg,
177       const shared_ptr<Matrix<Element>> perturbationVector, int64_t base = 2);
178 
179   /**
180    * Offline stage of pre-image sampling (perturbation sampling)
181    *
182    * @param n ring dimension
183    * @param k matrix sample dimension; k = logq + 2
184    * @param &T trapdoor itself
185    * @param &dgg discrete Gaussian generator for integers
186    * @param &dggLargeSigma discrete Gaussian generator for perturbation vector
187    * sampling
188    * @param &base base for G-lattice
189    * @return the sampled vector (matrix)
190    */
191   static shared_ptr<Matrix<Element>> GaussSampOffline(
192       size_t n, size_t k, const RLWETrapdoorPair<Element> &T, DggType &dgg,
193       DggType &dggLargeSigma, int64_t base = 2);
194 
195   /**
196    * Method for perturbation generation as described in Algorithm 4 of
197    *https://eprint.iacr.org/2017/844.pdf
198    *
199    *@param n ring dimension
200    *@param s parameter Gaussian distribution
201    *@param sigma standard deviation
202    *@param &Tprime compact trapdoor matrix
203    *@param &dgg discrete Gaussian generator for error sampling
204    *@param &dggLargeSigma discrete Gaussian generator for perturbation vector
205    *sampling
206    *@param *perturbationVector perturbation vector;output of the function
207    */
208   static void ZSampleSigmaP(size_t n, double s, double sigma,
SetFormat(Format format)209                             const RLWETrapdoorPair<Element> &Tprime,
210                             const DggType &dgg, const DggType &dggLargeSigma,
211                             shared_ptr<Matrix<Element>> perturbationVector) {
212     DEBUG_FLAG(false);
213     TimeVar t1, t1_tot;
214 
215     TIC(t1);
216     TIC(t1_tot);
217     Matrix<Element> Tprime0 = Tprime.m_e;
218     Matrix<Element> Tprime1 = Tprime.m_r;
219 
220     // k is the bit length
221     size_t k = Tprime0.GetCols();
222 
223     const shared_ptr<ParmType> params = Tprime0(0, 0).GetParams();
224     DEBUG("z1a: " << TOC(t1));  // 0
225     TIC(t1);
226     // all three Polynomials are initialized with "0" coefficients
227     Element va(params, Format::EVALUATION, 1);
228     Element vb(params, Format::EVALUATION, 1);
229     Element vd(params, Format::EVALUATION, 1);
230 
231     for (size_t i = 0; i < k; i++) {
232       va += Tprime0(0, i) * Tprime0(0, i).Transpose();
233       vb += Tprime1(0, i) * Tprime0(0, i).Transpose();
234       vd += Tprime1(0, i) * Tprime1(0, i).Transpose();
235     }
236     DEBUG("z1b: " << TOC(t1));  // 9
237     TIC(t1);
238 
239     // Switch the ring elements (Polynomials) to coefficient representation
240     va.SetFormat(Format::COEFFICIENT);
241     vb.SetFormat(Format::COEFFICIENT);
242     vd.SetFormat(Format::COEFFICIENT);
243 
244     DEBUG("z1c: " << TOC(t1));  // 5
245     TIC(t1);
246 
247     // Create field elements from ring elements
248     Field2n a(va), b(vb), d(vd);
249 
250     double scalarFactor = -s * s * sigma * sigma / (s * s - sigma * sigma);
251 
252     a = a.ScalarMult(scalarFactor);
253     b = b.ScalarMult(scalarFactor);
254     d = d.ScalarMult(scalarFactor);
255 
256     a = a + s * s;
257     d = d + s * s;
258     DEBUG("z1d: " << TOC(t1));  // 0
259     TIC(t1);
260 
261     // converts the field elements to DFT representation
262     a.SetFormat(Format::EVALUATION);
263     b.SetFormat(Format::EVALUATION);
264     d.SetFormat(Format::EVALUATION);
265     DEBUG("z1e: " << TOC(t1));  // 0
266     TIC(t1);
267 
268     Matrix<int64_t> p2ZVector([]() { return 0; }, n * k, 1);
269 
270     double sigmaLarge = sqrt(s * s - sigma * sigma);
271 
272     // for distribution parameters up to 3e5 (experimentally found threshold)
273     // use the Peikert's inversion method otherwise, use Karney's method
274 
275     if (sigmaLarge > KARNEY_THRESHOLD) {
276       // Karney rejection sampling method
277       for (size_t i = 0; i < n * k; i++) {
278         p2ZVector(i, 0) = dgg.GenerateIntegerKarney(0, sigmaLarge);
279       }
280     } else {
281       // Peikert's inversion sampling method
282       std::shared_ptr<int64_t> dggVector =
283           dggLargeSigma.GenerateIntVector(n * k);
284       for (size_t i = 0; i < n * k; i++) {
285         p2ZVector(i, 0) = (dggVector.get())[i];
286       }
287     }
288     DEBUG("z1f1: " << TOC(t1));
289     TIC(t1);
290 
291     // create k ring elements in coefficient representation
292     Matrix<Element> p2 =
293         SplitInt64IntoElements<Element>(p2ZVector, n, va.GetParams());
294     DEBUG("z1f2: " << TOC(t1));
295     TIC(t1);
296 
297     // now converting to Format::EVALUATION representation before multiplication
298     p2.SetFormat(Format::EVALUATION);
299 
300     DEBUG("z1g: " << TOC(t1));  // 17
301 
302     TIC(t1);
303 
304     // the dimension is 2x1 - a vector of 2 ring elements
305     auto zero_alloc = Element::Allocator(params, Format::EVALUATION);
306     Matrix<Element> Tp2(zero_alloc, 2, 1);
307     Tp2(0, 0) = (Tprime0 * p2)(0, 0);
308     Tp2(1, 0) = (Tprime1 * p2)(0, 0);
309 
310     DEBUG("z1h2: " << TOC(t1));
311     TIC(t1);
312     // change to coefficient representation before converting to field elements
313     Tp2.SetFormat(Format::COEFFICIENT);
314     DEBUG("z1h3: " << TOC(t1));
315     TIC(t1);
316 
317     Matrix<Field2n> c([]() { return Field2n(); }, 2, 1);
318 
319     c(0, 0) =
320         Field2n(Tp2(0, 0)).ScalarMult(-sigma * sigma / (s * s - sigma * sigma));
321     c(1, 0) =
322         Field2n(Tp2(1, 0)).ScalarMult(-sigma * sigma / (s * s - sigma * sigma));
323 
324     auto p1ZVector =
325         std::make_shared<Matrix<int64_t>>([]() { return 0; }, n * 2, 1);
326     DEBUG("z1i: " << TOC(t1));
327     TIC(t1);
328     LatticeGaussSampUtility<Element>::ZSampleSigma2x2(a, b, d, c, dgg,
329                                                       p1ZVector);
330     DEBUG("z1j1: " << TOC(t1));  // 14
331     TIC(t1);
332 
333     // create 2 ring elements in coefficient representation
334     Matrix<Element> p1 =
335         SplitInt64IntoElements<Element>(*p1ZVector, n, va.GetParams());
336     DEBUG("z1j2: " << TOC(t1));
337     TIC(t1);
338 
339     p1.SetFormat(Format::EVALUATION);
340     DEBUG("z1j3: " << TOC(t1));
341     TIC(t1);
342 
343     *perturbationVector = p1.VStack(p2);
344     DEBUG("z1j4: " << TOC(t1));
345     TIC(t1);
346     DEBUG("z1tot: " << TOC(t1_tot));
347   }
348 
349   /**
350    * Method for perturbation generation as described in "Implementing
351    *Token-Based Obfuscation under (Ring) LWE"
352    *
353    *@param n ring dimension
354    *@param s spectral norm
355    *@param sigma standard deviation
356    *@param &Tprime compact trapdoor matrix
357    *@param &dgg discrete Gaussian generator for error sampling
358    *@param &dggLargeSigma discrete Gaussian generator for perturbation vector
359    *sampling
360    *@param *perturbationVector perturbation vector;output of the function
361    */
362   static void SamplePertSquareMat(
363       size_t n, double s, double sigma, const RLWETrapdoorPair<Element> &Tprime,
364       const DggType &dgg, const DggType &dggLargeSigma,
365       shared_ptr<Matrix<Element>> perturbationVector) {
366     Matrix<Element> R = Tprime.m_r;
367     Matrix<Element> E = Tprime.m_e;
368 
369     const shared_ptr<ParmType> params = R(0, 0).GetParams();
370 
371     // k is the bit length
372     size_t k = R.GetCols();
373     size_t d = R.GetRows();
374 
375     Matrix<int64_t> p2ZVector([]() { return 0; }, n * k, d);
376 
377     double sigmaLarge = sqrt(s * s - sigma * sigma);
378 
379     // for distribution parameters up to the experimentally found threshold, use
380     // the Peikert's inversion method otherwise, use Karney's method
381     if (sigmaLarge > KARNEY_THRESHOLD) {
382       // Karney rejection sampling method
383       for (size_t i = 0; i < n * k; i++) {
384         for (size_t j = 0; j < d; j++) {
385           p2ZVector(i, j) = dgg.GenerateIntegerKarney(0, sigmaLarge);
386         }
387       }
388     } else {
389       // Peikert's inversion sampling method
390       std::shared_ptr<int64_t> dggVector =
391           dggLargeSigma.GenerateIntVector(n * k * d);
392 
393       for (size_t i = 0; i < n * k; i++) {
394         for (size_t j = 0; j < d; j++) {
395           p2ZVector(i, j) = (dggVector.get())[i * d + j];
396         }
397       }
398     }
399 
400     // create a matrix of d*k x d ring elements in coefficient representation
401     Matrix<Element> p2 =
402         SplitInt64IntoElements<Element>(p2ZVector.ExtractCol(0), n, params);
403     for (size_t i = 1; i < d; i++) {
404       p2.HStack(
405           SplitInt64IntoElements<Element>(p2ZVector.ExtractCol(i), n, params));
406     }
407 
408     // now converting to Format::EVALUATION representation before multiplication
409     p2.SetFormat(Format::EVALUATION);
410 
411     auto zero_alloc = Element::Allocator(params, Format::EVALUATION);
412 
413     Matrix<Element> A = R * (R.Transpose());  // d x d
414     Matrix<Element> B = R * (E.Transpose());  // d x d
415     Matrix<Element> D = E * (E.Transpose());  // d x d
416 
417     // Switch the ring elements (Polynomials) to coefficient representation
418     A.SetFormat(Format::COEFFICIENT);
419     B.SetFormat(Format::COEFFICIENT);
420     D.SetFormat(Format::COEFFICIENT);
421 
422     Matrix<Field2n> AF([&]() { return Field2n(n, Format::EVALUATION, true); },
423                        d, d);
424     Matrix<Field2n> BF([&]() { return Field2n(n, Format::EVALUATION, true); },
425                        d, d);
426     Matrix<Field2n> DF([&]() { return Field2n(n, Format::EVALUATION, true); },
427                        d, d);
428 
429     double scalarFactor = -sigma * sigma;
430 
431     for (size_t i = 0; i < d; i++) {
432       for (size_t j = 0; j < d; j++) {
433         AF(i, j) = Field2n(A(i, j));
434         AF(i, j) = AF(i, j).ScalarMult(scalarFactor);
435         BF(i, j) = Field2n(B(i, j));
436         BF(i, j) = BF(i, j).ScalarMult(scalarFactor);
437         DF(i, j) = Field2n(D(i, j));
438         DF(i, j) = DF(i, j).ScalarMult(scalarFactor);
439         if (i == j) {
440           AF(i, j) = AF(i, j) + s * s;
441           DF(i, j) = DF(i, j) + s * s;
442         }
443       }
444     }
445 
446     // converts the field elements to DFT representation
447     AF.SetFormat(Format::EVALUATION);
448     BF.SetFormat(Format::EVALUATION);
449     DF.SetFormat(Format::EVALUATION);
450 
451     // the dimension is 2d x d
452     Matrix<Element> Tp2 = (R.VStack(E)) * p2;
453 
454     // change to coefficient representation before converting to field elements
455     Tp2.SetFormat(Format::COEFFICIENT);
456 
457     Matrix<Element> p1(zero_alloc, 1, 1);
458 
459     for (size_t j = 0; j < d; j++) {
460       Matrix<Field2n> c([&]() { return Field2n(n, Format::COEFFICIENT); },
461                         2 * d, 1);
462 
463       for (size_t i = 0; i < d; i++) {
464         c(i, 0) = Field2n(Tp2(i, j)).ScalarMult(-sigma * sigma /
465                                                 (s * s - sigma * sigma));
466         c(i + d, 0) = Field2n(Tp2(i + d, j))
467                           .ScalarMult(-sigma * sigma / (s * s - sigma * sigma));
468       }
469 
470       auto p1ZVector =
471           std::make_shared<Matrix<int64_t>>([]() { return 0; }, n * 2 * d, 1);
472 
473       LatticeGaussSampUtility<Element>::SampleMat(AF, BF, DF, c, dgg,
474                                                   p1ZVector);
475 
476       if (j == 0)
477         p1 = SplitInt64IntoElements<Element>(*p1ZVector, n, params);
478       else
479         p1.HStack(SplitInt64IntoElements<Element>(*p1ZVector, n, params));
480     }
481 
482     p1.SetFormat(Format::EVALUATION);
483 
484     *perturbationVector = p1.VStack(p2);
485 
486     p1.SetFormat(Format::COEFFICIENT);
487   }
488 };
489 
490 }  // namespace lbcrypto
491 
492 #endif
493