1 /* Copyright (C) 2012-2020 IBM Corp.
2  * This program is Licensed under the Apache License, Version 2.0
3  * (the "License"); you may not use this file except in compliance
4  * with the License. You may obtain a copy of the License at
5  *   http://www.apache.org/licenses/LICENSE-2.0
6  * Unless required by applicable law or agreed to in writing, software
7  * distributed under the License is distributed on an "AS IS" BASIS,
8  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9  * See the License for the specific language governing permissions and
10  * limitations under the License. See accompanying LICENSE file.
11  */
12 // debugging.cpp - debugging utilities
#include <cmath>
#include <limits>
#include <stdexcept>

#include <NTL/xdouble.h>
#include <helib/debugging.h>
#include <helib/norms.h>
#include <helib/Context.h>
#include <helib/Ctxt.h>
#include <helib/EncryptedArray.h>
//#include <helib/powerful.h>
20 
21 namespace helib {
22 
23 SecKey* dbgKey = nullptr;
24 std::shared_ptr<const EncryptedArray> dbgEa = nullptr;
25 NTL::ZZX dbg_ptxt;
26 
27 // return the ratio between the real noise <sk,ct> and the estimated one
realToEstimatedNoise(const Ctxt & ctxt,const SecKey & sk)28 double realToEstimatedNoise(const Ctxt& ctxt, const SecKey& sk)
29 {
30   NTL::xdouble noiseEst = ctxt.totalNoiseBound();
31   NTL::xdouble actualNoise = embeddingLargestCoeff(ctxt, sk);
32 
33   return NTL::conv<double>(actualNoise / noiseEst);
34 }
35 
log2_realToEstimatedNoise(const Ctxt & ctxt,const SecKey & sk)36 double log2_realToEstimatedNoise(const Ctxt& ctxt, const SecKey& sk)
37 {
38   NTL::xdouble noiseEst = ctxt.totalNoiseBound();
39   NTL::xdouble actualNoise = embeddingLargestCoeff(ctxt, sk);
40 
41   return NTL::log(actualNoise / noiseEst) / std::log(2.0);
42 }
43 
44 // check that real-to-estimated ratio is not too large, print warning otherwise
checkNoise(const Ctxt & ctxt,const SecKey & sk,const std::string & msg,double thresh)45 void checkNoise(const Ctxt& ctxt,
46                 const SecKey& sk,
47                 const std::string& msg,
48                 double thresh)
49 {
50   double ratio;
51   if ((ratio = realToEstimatedNoise(ctxt, sk)) > thresh) {
52     std::cerr << "\n*** too much noise: " << msg << ": " << ratio << "\n";
53   }
54 }
55 
56 // Decrypt and find the l-infinity norm of the result in canonical embedding
embeddingLargestCoeff(const Ctxt & ctxt,const SecKey & sk)57 NTL::xdouble embeddingLargestCoeff(const Ctxt& ctxt, const SecKey& sk)
58 {
59   const Context& context = ctxt.getContext();
60   NTL::ZZX p, pp;
61   sk.Decrypt(p, ctxt, pp);
62   return embeddingLargestCoeff(pp, context.zMStar);
63 }
64 
decryptAndPrint(std::ostream & s,const Ctxt & ctxt,const SecKey & sk,const EncryptedArray & ea,long flags)65 void decryptAndPrint(std::ostream& s,
66                      const Ctxt& ctxt,
67                      const SecKey& sk,
68                      const EncryptedArray& ea,
69                      long flags)
70 {
71   const Context& context = ctxt.getContext();
72   std::vector<NTL::ZZX> ptxt;
73   NTL::ZZX p, pp;
74   sk.Decrypt(p, ctxt, pp);
75 
76   NTL::xdouble modulus = NTL::xexp(context.logOfProduct(ctxt.getPrimeSet()));
77   NTL::xdouble actualNoise =
78       embeddingLargestCoeff(pp, ctxt.getContext().zMStar);
79   NTL::xdouble noiseEst = ctxt.totalNoiseBound();
80 
81   s << "plaintext space mod " << ctxt.getPtxtSpace()
82     << ", capacity=" << ctxt.capacity() << ", \n           |noise|=q*"
83     << (actualNoise / modulus) << ", |noiseBound|=q*" << (noiseEst / modulus);
84   if (ctxt.isCKKS()) {
85     s << ", \n           ratFactor=" << ctxt.getRatFactor()
86       << ", ptxtMag=" << ctxt.getPtxtMag()
87       << ", scaledErr=" << (actualNoise / ctxt.getRatFactor());
88   }
89   s << std::endl;
90 
91   if (flags & FLAG_PRINT_ZZX) {
92     s << "   before mod-p reduction=";
93     printZZX(s, pp) << std::endl;
94   }
95   if (flags & FLAG_PRINT_POLY) {
96     s << "   after mod-p reduction=";
97     printZZX(s, p) << std::endl;
98   }
99   if (flags & FLAG_PRINT_VEC) { // decode to a vector of ZZX
100     ea.decode(ptxt, p);
101     if (ea.getAlMod().getTag() == PA_zz_p_tag &&
102         ctxt.getPtxtSpace() != ea.getAlMod().getPPowR()) {
103       long g = NTL::GCD(ctxt.getPtxtSpace(), ea.getAlMod().getPPowR());
104       for (long i = 0; i < ea.size(); i++)
105         PolyRed(ptxt[i], g, true);
106     }
107     s << "   decoded to ";
108     if (deg(p) < 40) // just print the whole thing
109       s << ptxt << std::endl;
110     else if (ptxt.size() == 1) // a single slot
111       printZZX(s, ptxt[0]) << std::endl;
112     else { // print first and last slots
113       printZZX(s, ptxt[0], 20) << "--";
114       printZZX(s, ptxt[ptxt.size() - 1], 20) << std::endl;
115     }
116   } else if (flags & FLAG_PRINT_DVEC) { // decode to a vector of doubles
117     const EncryptedArrayCx& eacx = ea.getCx();
118     std::vector<double> v;
119     eacx.rawDecrypt(ctxt, sk, v);
120     printVec(s << "           ", v, 20) << std::endl;
121   } else if (flags & FLAG_PRINT_XVEC) { // decode to a vector of complex
122     const EncryptedArrayCx& eacx = ea.getCx();
123     std::vector<cx_double> v;
124     eacx.rawDecrypt(ctxt, sk, v);
125     printVec(s << "           ", v, 20) << std::endl;
126   }
127 }
128 
decryptAndCompare(const Ctxt & ctxt,const SecKey & sk,const EncryptedArray & ea,const PlaintextArray & pa)129 bool decryptAndCompare(const Ctxt& ctxt,
130                        const SecKey& sk,
131                        const EncryptedArray& ea,
132                        const PlaintextArray& pa)
133 {
134   PlaintextArray ppa(ea);
135   ea.decrypt(ctxt, sk, ppa);
136 
137   return equals(ea, pa, ppa);
138 }
139 
140 // Compute decryption with.without mod-q on a vector of ZZX'es,
141 // useful when debugging bootstrapping (after "raw mod-switch")
rawDecrypt(NTL::ZZX & plaintxt,const std::vector<NTL::ZZX> & zzParts,const DoubleCRT & sKey,long q)142 void rawDecrypt(NTL::ZZX& plaintxt,
143                 const std::vector<NTL::ZZX>& zzParts,
144                 const DoubleCRT& sKey,
145                 long q)
146 {
147   // Set to zzParts[0] + sKey * zzParts[1] "over the integers"
148   DoubleCRT ptxt = sKey;
149   ptxt *= zzParts[1];
150   ptxt += zzParts[0];
151 
152   // convert to coefficient representation
153   ptxt.toPoly(plaintxt);
154 
155   if (q > 1)
156     PolyRed(plaintxt, q, false /*reduce to [-q/2,1/2]*/);
157 }
158 
CheckCtxt(const Ctxt & c,const char * label)159 void CheckCtxt(const Ctxt& c, const char* label)
160 {
161   std::cerr << "  " << label << ", capacity=" << c.capacity();
162 
163   if (!c.isCKKS())
164     std::cerr << ", p^r=" << c.getPtxtSpace();
165 
166   if (dbgKey) {
167     double ratio = log2_realToEstimatedNoise(c, *dbgKey);
168     std::cerr << ", log2(noise/bound)=" << ratio;
169     if (ratio > 0)
170       std::cerr << " BAD-BOUND";
171   }
172 
173 #if 0
174   // This is not really a useful test
175 
176   if (dbgKey && c.getContext().isBootstrappable()) {
177     Ctxt c1(c);
178     //c1.dropSmallAndSpecialPrimes();
179 
180     const Context& context = c1.getContext();
181     const RecryptData& rcData = context.rcData;
182     const PAlgebra& palg = context.zMStar;
183 
184     NTL::ZZX p, pp;
185     dbgKey->Decrypt(p, c1, pp);
186     NTL::Vec<NTL::ZZ> powerful;
187     rcData.p2dConv->ZZXtoPowerful(powerful, pp);
188 
189     NTL::ZZ q;
190     q = context.productOfPrimes(c1.getPrimeSet());
191     vecRed(powerful, powerful, q, false);
192 
193     NTL::ZZX pp_alt;
194     rcData.p2dConv->powerfulToZZX(pp_alt, powerful);
195 
196     NTL::xdouble max_coeff = NTL::conv<NTL::xdouble>(largestCoeff(pp));
197     NTL::xdouble max_pwrfl = NTL::conv<NTL::xdouble>(largestCoeff(powerful));
198     NTL::xdouble max_canon = embeddingLargestCoeff(pp_alt, palg);
199     double ratio = log(max_pwrfl/max_canon)/log(2.0);
200 
201     //cerr << ", max_coeff=" << max_coeff;
202     //cerr << ", max_pwrfl=" << max_pwrfl;
203     //cerr << ", max_canon=" << max_canon;
204     std::cerr << ", log2(max_pwrfl/max_canon)=" << ratio;
205     if (ratio > 0) std::cerr << " BAD-BOUND";
206   }
207 #endif
208 
209   std::cerr << std::endl;
210 }
211 
212 } // namespace helib
213