// @file trapdoor.cpp Provides utilities for sampling trapdoor lattices as
// described in https://eprint.iacr.org/2017/844.pdf,
// https://eprint.iacr.org/2018/946, and "Implementing Token-Based Obfuscation
// under (Ring) LWE" (https://eprint.iacr.org/2018/1222.pdf).
5 // @author TPOC: contact@palisade-crypto.org
6 //
7 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
8 // All rights reserved.
9 // Redistribution and use in source and binary forms, with or without
10 // modification, are permitted provided that the following conditions are met:
11 // 1. Redistributions of source code must retain the above copyright notice,
12 // this list of conditions and the following disclaimer.
13 // 2. Redistributions in binary form must reproduce the above copyright notice,
14 // this list of conditions and the following disclaimer in the documentation
15 // and/or other materials provided with the distribution. THIS SOFTWARE IS
16 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
17 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
20 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 
27 #ifndef _SRC_LIB_CRYPTO_SIGNATURE_TRAPDOOR_CPP
28 #define _SRC_LIB_CRYPTO_SIGNATURE_TRAPDOOR_CPP
29 
30 #include "lattice/trapdoor.h"
31 
32 namespace lbcrypto {
33 
34 // On-line stage of pre-image sampling (includes only G-sampling)
35 
36 template <class Element>
GaussSampOnline(size_t n,size_t k,const Matrix<Element> & A,const RLWETrapdoorPair<Element> & T,const Element & u,DggType & dgg,const shared_ptr<Matrix<Element>> pHat,int64_t base)37 Matrix<Element> RLWETrapdoorUtility<Element>::GaussSampOnline(
38     size_t n, size_t k, const Matrix<Element>& A,
39     const RLWETrapdoorPair<Element>& T, const Element& u, DggType& dgg,
40     const shared_ptr<Matrix<Element>> pHat, int64_t base) {
41   const shared_ptr<ParmType> params = u.GetParams();
42   auto zero_alloc = Element::Allocator(params, Format::EVALUATION);
43 
44   double c = (base + 1) * SIGMA;
45 
46   const IntType& modulus = A(0, 0).GetModulus();
47 
48   // It is assumed that A has dimension 1 x (k + 2) and pHat has the dimension
49   // of (k + 2) x 1 perturbedSyndrome is in the Format::EVALUATION
50   // representation
51   Element perturbedSyndrome = u - (A.Mult(*pHat))(0, 0);
52 
53   Matrix<int64_t> zHatBBI([]() { return 0; }, k, n);
54 
55   perturbedSyndrome.SetFormat(Format::COEFFICIENT);
56 
57   LatticeGaussSampUtility<Element>::GaussSampGqArbBase(
58       perturbedSyndrome, c, k, modulus, base, dgg, &zHatBBI);
59 
60   // Convert zHat from a matrix of integers to a vector of Element ring elements
61   // zHat is in the coefficient representation
62   Matrix<Element> zHat = SplitInt64AltIntoElements<Element>(zHatBBI, n, params);
63   // Now converting it to the Format::EVALUATION representation before
64   // multiplication
65   zHat.SetFormat(Format::EVALUATION);
66 
67   Matrix<Element> zHatPrime(zero_alloc, k + 2, 1);
68 
69   zHatPrime(0, 0) = (*pHat)(0, 0) + T.m_e.Mult(zHat)(0, 0);
70   zHatPrime(1, 0) = (*pHat)(1, 0) + T.m_r.Mult(zHat)(0, 0);
71 
72   for (size_t row = 2; row < k + 2; ++row)
73     zHatPrime(row, 0) = (*pHat)(row, 0) + zHat(row - 2, 0);
74 
75   return zHatPrime;
76 }
77 
78 // Offline stage of pre-image sampling (perturbation sampling)
79 
80 template <class Element>
GaussSampOffline(size_t n,size_t k,const RLWETrapdoorPair<Element> & T,DggType & dgg,DggType & dggLargeSigma,int64_t base)81 shared_ptr<Matrix<Element>> RLWETrapdoorUtility<Element>::GaussSampOffline(
82     size_t n, size_t k, const RLWETrapdoorPair<Element>& T, DggType& dgg,
83     DggType& dggLargeSigma, int64_t base) {
84   const shared_ptr<ParmType> params = T.m_e(0, 0).GetParams();
85   auto zero_alloc = Element::Allocator(params, Format::EVALUATION);
86 
87   double c = (base + 1) * SIGMA;
88 
89   // spectral bound s
90   double s = SPECTRAL_BOUND(n, k, base);
91 
92   // perturbation vector in evaluation representation
93   auto result = std::make_shared<Matrix<Element>>(zero_alloc, k + 2, 1);
94   ZSampleSigmaP(n, s, c, T, dgg, dggLargeSigma, result);
95 
96   return result;
97 }
98 
99 template <>
ZSampleSigmaP(size_t n,double s,double sigma,const RLWETrapdoorPair<DCRTPoly> & Tprime,const DCRTPoly::DggType & dgg,const DCRTPoly::DggType & dggLargeSigma,shared_ptr<Matrix<DCRTPoly>> perturbationVector)100 inline void RLWETrapdoorUtility<DCRTPoly>::ZSampleSigmaP(
101     size_t n, double s, double sigma, const RLWETrapdoorPair<DCRTPoly>& Tprime,
102     const DCRTPoly::DggType& dgg, const DCRTPoly::DggType& dggLargeSigma,
103     shared_ptr<Matrix<DCRTPoly>> perturbationVector) {
104   DEBUG_FLAG(false);
105   TimeVar t1, t1_tot;
106 
107   TIC(t1);
108   TIC(t1_tot);
109   Matrix<DCRTPoly> Tprime0 = Tprime.m_e;
110   Matrix<DCRTPoly> Tprime1 = Tprime.m_r;
111   // k is the bit length
112   size_t k = Tprime0.GetCols();
113 
114   const shared_ptr<DCRTPoly::Params> params = Tprime0(0, 0).GetParams();
115 
116   DEBUG("z1a: " << TOC(t1));  // 0
117   TIC(t1);
118   // all three Polynomials are initialized with "0" coefficients
119   NativePoly va((*params)[0], Format::EVALUATION, 1);
120   NativePoly vb((*params)[0], Format::EVALUATION, 1);
121   NativePoly vd((*params)[0], Format::EVALUATION, 1);
122 
123   for (size_t i = 0; i < k; i++) {
124     va += (NativePoly)Tprime0(0, i).GetElementAtIndex(0) *
125           Tprime0(0, i).Transpose().GetElementAtIndex(0);
126     vb += (NativePoly)Tprime1(0, i).GetElementAtIndex(0) *
127           Tprime0(0, i).Transpose().GetElementAtIndex(0);
128     vd += (NativePoly)Tprime1(0, i).GetElementAtIndex(0) *
129           Tprime1(0, i).Transpose().GetElementAtIndex(0);
130   }
131   DEBUG("z1b: " << TOC(t1));  // 9
132   TIC(t1);
133 
134   // Switch the ring elements (Polynomials) to coefficient representation
135   va.SetFormat(Format::COEFFICIENT);
136   vb.SetFormat(Format::COEFFICIENT);
137   vd.SetFormat(Format::COEFFICIENT);
138 
139   DEBUG("z1c: " << TOC(t1));  // 5
140   TIC(t1);
141 
142   // Create field elements from ring elements
143   Field2n a(va), b(vb), d(vd);
144 
145   double scalarFactor = -s * s * sigma * sigma / (s * s - sigma * sigma);
146 
147   a = a.ScalarMult(scalarFactor);
148   b = b.ScalarMult(scalarFactor);
149   d = d.ScalarMult(scalarFactor);
150 
151   a = a + s * s;
152   d = d + s * s;
153   DEBUG("z1d: " << TOC(t1));  // 0
154   TIC(t1);
155 
156   // converts the field elements to DFT representation
157   a.SetFormat(Format::EVALUATION);
158   b.SetFormat(Format::EVALUATION);
159   d.SetFormat(Format::EVALUATION);
160   DEBUG("z1e: " << TOC(t1));  // 0
161   TIC(t1);
162 
163   Matrix<int64_t> p2ZVector([]() { return 0; }, n * k, 1);
164 
165   double sigmaLarge = sqrt(s * s - sigma * sigma);
166 
167   // for distribution parameters up to KARNEY_THRESHOLD (experimentally found
168   // threshold) use the Peikert's inversion method otherwise, use Karney's
169   // method
170   if (sigmaLarge > KARNEY_THRESHOLD) {
171     // Karney rejection sampling method
172     for (size_t i = 0; i < n * k; i++) {
173       p2ZVector(i, 0) = dgg.GenerateIntegerKarney(0, sigmaLarge);
174     }
175   } else {
176     // Peikert's inversion sampling method
177     std::shared_ptr<int64_t> dggVector = dggLargeSigma.GenerateIntVector(n * k);
178 
179     for (size_t i = 0; i < n * k; i++) {
180       p2ZVector(i, 0) = (dggVector.get())[i];
181     }
182   }
183   DEBUG("z1f1: " << TOC(t1));
184   TIC(t1);
185 
186   // create k ring elements in coefficient representation
187   Matrix<DCRTPoly> p2 = SplitInt64IntoElements<DCRTPoly>(p2ZVector, n, params);
188   DEBUG("z1f2: " << TOC(t1));
189   TIC(t1);
190 
191   // now converting to Format::EVALUATION representation before multiplication
192   p2.SetFormat(Format::EVALUATION);
193 
194   DEBUG("z1g: " << TOC(t1));  // 17
195 
196   TIC(t1);
197 
198   auto zero_alloc = NativePoly::Allocator((*params)[0], Format::EVALUATION);
199   Matrix<NativePoly> Tp2(zero_alloc, 2, 1);
200   for (unsigned int i = 0; i < k; i++) {
201     Tp2(0, 0) += Tprime0(0, i).GetElementAtIndex(0) *
202                  (NativePoly)p2(i, 0).GetElementAtIndex(0);
203     Tp2(1, 0) += Tprime1(0, i).GetElementAtIndex(0) *
204                  (NativePoly)p2(i, 0).GetElementAtIndex(0);
205   }
206 
207   DEBUG("z1h2: " << TOC(t1));
208   TIC(t1);
209   // change to coefficient representation before converting to field elements
210   Tp2.SetFormat(Format::COEFFICIENT);
211   DEBUG("z1h3: " << TOC(t1));
212   TIC(t1);
213 
214   Matrix<Field2n> c([]() { return Field2n(); }, 2, 1);
215 
216   c(0, 0) =
217       Field2n(Tp2(0, 0)).ScalarMult(-sigma * sigma / (s * s - sigma * sigma));
218   c(1, 0) =
219       Field2n(Tp2(1, 0)).ScalarMult(-sigma * sigma / (s * s - sigma * sigma));
220 
221   auto p1ZVector =
222       std::make_shared<Matrix<int64_t>>([]() { return 0; }, n * 2, 1);
223   DEBUG("z1i: " << TOC(t1));
224   TIC(t1);
225 
226   LatticeGaussSampUtility<DCRTPoly>::ZSampleSigma2x2(a, b, d, c, dgg,
227                                                      p1ZVector);
228   DEBUG("z1j1: " << TOC(t1));  // 14
229   TIC(t1);
230 
231   // create 2 ring elements in coefficient representation
232   Matrix<DCRTPoly> p1 = SplitInt64IntoElements<DCRTPoly>(*p1ZVector, n, params);
233   DEBUG("z1j2: " << TOC(t1));
234   TIC(t1);
235 
236   p1.SetFormat(Format::EVALUATION);
237   DEBUG("z1j3: " << TOC(t1));
238   TIC(t1);
239 
240   *perturbationVector = p1.VStack(p2);
241   DEBUG("z1j4: " << TOC(t1));
242   TIC(t1);
243   DEBUG("z1tot: " << TOC(t1_tot));
244 }
245 
246 }  // namespace lbcrypto
247 #endif
248