1 // @file plaintext.h Represents and defines plaintext objects in Palisade.
2 // @author TPOC: contact@palisade-crypto.org
3 //
4 // @copyright Copyright (c) 2019, New Jersey Institute of Technology (NJIT)
5 // All rights reserved.
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are met:
8 // 1. Redistributions of source code must retain the above copyright notice,
9 // this list of conditions and the following disclaimer.
10 // 2. Redistributions in binary form must reproduce the above copyright notice,
11 // this list of conditions and the following disclaimer in the documentation
12 // and/or other materials provided with the distribution. THIS SOFTWARE IS
13 // PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
14 // IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
15 // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
16 // EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
17 // INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
18 // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
19 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
20 // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
22 // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 
24 #ifndef LBCRYPTO_UTILS_PLAINTEXT_H
25 #define LBCRYPTO_UTILS_PLAINTEXT_H
26 
#include <complex>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "encoding/encodingparams.h"
#include "lattice/backend.h"
#include "math/backend.h"
#include "utils/inttypes.h"
37 
38 using std::shared_ptr;
39 
40 namespace lbcrypto {
41 
/**
 * Identifies which concrete encoding a plaintext uses.
 * The enumerator values are spelled out; they are the same 0..4 numbering an
 * implicit enum would assign.
 */
enum PlaintextEncodings {
  Unknown = 0,
  CoefPacked = 1,
  Packed = 2,
  String = 3,
  CKKSPacked = 4,
};

/**
 * Writes the enumerator's name (for example "Packed") to the stream.
 * A value that matches none of the enumerators writes nothing.
 *
 * @param out stream to write to
 * @param p   encoding to name
 * @return out, so calls can be chained
 */
inline std::ostream& operator<<(std::ostream& out, const PlaintextEncodings p) {
  switch (p) {
    case Unknown:
      return out << "Unknown";
    case CoefPacked:
      return out << "CoefPacked";
    case Packed:
      return out << "Packed";
    case String:
      return out << "String";
    case CKKSPacked:
      return out << "CKKSPacked";
  }
  // Reached only for values without a named enumerator.
  return out;
}
70 
class PlaintextImpl;

/// Shared-ownership handle to a mutable plaintext.
using Plaintext = shared_ptr<PlaintextImpl>;

/// Shared-ownership handle to a read-only plaintext.
using ConstPlaintext = shared_ptr<const PlaintextImpl>;
74 
/**
 * Tags which polynomial type a PlaintextImpl keeps its encoding in.
 * Each PlaintextImpl element constructor sets the tag matching the parameter
 * type it was given.
 */
enum PtxtPolyType {
  IsPoly = 0,       ///< encoding is held in a Poly
  IsDCRTPoly = 1,   ///< encoding is held in a DCRTPoly
  IsNativePoly = 2  ///< encoding is held in a NativePoly
};

/**
 * @class PlaintextImpl
 * @brief This class represents plaintext in the Palisade library.
 *
 * PlaintextImpl is primarily intended to be used as a container, together
 * with the specific encodings that inherit from it; which encoding is used
 * depends on the application the plaintext serves. It provides virtual
 * methods for encoding and decoding of data.
 */
86 
87 class PlaintextImpl {
88  protected:
89   bool isEncoded;
90   PtxtPolyType typeFlag;
91   EncodingParams encodingParams;
92 
93   mutable Poly encodedVector;
94   mutable NativePoly encodedNativeVector;
95   mutable DCRTPoly encodedVectorDCRT;
96 
97   static const int intCTOR = 0x01;
98   static const int vecintCTOR = 0x02;
99   static const int fracCTOR = 0x04;
100   static const int vecuintCTOR = 0x08;
101 
102   double scalingFactor;
103   size_t level;
104   size_t depth;
105 
106  public:
107   PlaintextImpl(shared_ptr<Poly::Params> vp, EncodingParams ep,
108                 bool isEncoded = false)
isEncoded(isEncoded)109       : isEncoded(isEncoded),
110         typeFlag(IsPoly),
111         encodingParams(ep),
112         encodedVector(vp, Format::COEFFICIENT),
113         scalingFactor(1),
114         level(0),
115         depth(1) {}
116 
117   PlaintextImpl(shared_ptr<NativePoly::Params> vp, EncodingParams ep,
118                 bool isEncoded = false)
isEncoded(isEncoded)119       : isEncoded(isEncoded),
120         typeFlag(IsNativePoly),
121         encodingParams(ep),
122         encodedNativeVector(vp, Format::COEFFICIENT),
123         scalingFactor(1),
124         level(0),
125         depth(1) {}
126 
127   PlaintextImpl(shared_ptr<DCRTPoly::Params> vp, EncodingParams ep,
128                 bool isEncoded = false)
isEncoded(isEncoded)129       : isEncoded(isEncoded),
130         typeFlag(IsDCRTPoly),
131         encodingParams(ep),
132         encodedVector(vp, Format::COEFFICIENT),
133         encodedVectorDCRT(vp, Format::COEFFICIENT),
134         scalingFactor(1),
135         level(0),
136         depth(1) {}
137 
PlaintextImpl(const PlaintextImpl & rhs)138   PlaintextImpl(const PlaintextImpl& rhs)
139       : isEncoded(rhs.isEncoded),
140         typeFlag(rhs.typeFlag),
141         encodingParams(rhs.encodingParams),
142         encodedVector(rhs.encodedVector),
143         encodedVectorDCRT(rhs.encodedVectorDCRT),
144         scalingFactor(rhs.scalingFactor),
145         level(rhs.level),
146         depth(rhs.depth) {}
147 
PlaintextImpl(const PlaintextImpl && rhs)148   PlaintextImpl(const PlaintextImpl&& rhs)
149       : isEncoded(rhs.isEncoded),
150         typeFlag(rhs.typeFlag),
151         encodingParams(std::move(rhs.encodingParams)),
152         encodedVector(std::move(rhs.encodedVector)),
153         encodedVectorDCRT(std::move(rhs.encodedVectorDCRT)),
154         scalingFactor(rhs.scalingFactor),
155         level(rhs.level),
156         depth(rhs.depth) {}
157 
~PlaintextImpl()158   virtual ~PlaintextImpl() {}
159 
160   /**
161    * GetEncodingType
162    * @return Encoding type used by this plaintext
163    */
164   virtual PlaintextEncodings GetEncodingType() const = 0;
165 
166   /**
167    * Get the scaling factor of the plaintext.
168    */
GetScalingFactor()169   double GetScalingFactor() const { return scalingFactor; }
170 
171   /**
172    * Set the scaling factor of the plaintext.
173    */
SetScalingFactor(double sf)174   void SetScalingFactor(double sf) { scalingFactor = sf; }
175 
176   /**
177    * IsEncoded
178    * @return true when encoding is done
179    */
IsEncoded()180   bool IsEncoded() const { return isEncoded; }
181 
182   /**
183    * GetEncodingParams
184    * @return Encoding params used with this plaintext
185    */
GetEncodingParams()186   const EncodingParams GetEncodingParams() const { return encodingParams; }
187 
188   /**
189    * Encode the plaintext into a polynomial
190    * @return true on success
191    */
192   virtual bool Encode() = 0;
193 
194   /**
195    * Decode the polynomial into the plaintext
196    * @return
197    */
198   virtual bool Decode() = 0;
199 
200   /**
201    * Calculate and return lower bound that can be encoded with the plaintext
202    * modulus the number to encode MUST be greater than this value
203    * @return floor(-p/2)
204    */
LowBound()205   int64_t LowBound() const {
206     uint64_t half = GetEncodingParams()->GetPlaintextModulus() >> 1;
207     bool odd = (GetEncodingParams()->GetPlaintextModulus() & 0x1) == 1;
208     int64_t bound = -1 * half;
209     if (odd) bound--;
210     return bound;
211   }
212 
213   /**
214    * Calculate and return upper bound that can be encoded with the plaintext
215    * modulus the number to encode MUST be less than or equal to this value
216    * @return floor(p/2)
217    */
HighBound()218   int64_t HighBound() const {
219     return GetEncodingParams()->GetPlaintextModulus() >> 1;
220   }
221 
222   /**
223    * SetFormat - allows format to be changed for PlaintextImpl evaluations
224    *
225    * @param fmt
226    */
SetFormat(Format fmt)227   void SetFormat(Format fmt) const {
228     if (typeFlag == IsPoly)
229       encodedVector.SetFormat(fmt);
230     else if (typeFlag == IsNativePoly)
231       encodedNativeVector.SetFormat(fmt);
232     else
233       encodedVectorDCRT.SetFormat(fmt);
234   }
235 
236   /**
237    * GetElement
238    * @return the Polynomial that the element was encoded into
239    */
240   template <typename Element>
241   Element& GetElement();
242 
243   template <typename Element>
244   const Element& GetElement() const;
245 
246   /**
247    * GetElementRingDimension
248    * @return ring dimension on the underlying element
249    */
GetElementRingDimension()250   usint GetElementRingDimension() const {
251     return typeFlag == IsPoly ? encodedVector.GetRingDimension()
252                               : (typeFlag == IsNativePoly
253                                      ? encodedNativeVector.GetRingDimension()
254                                      : encodedVectorDCRT.GetRingDimension());
255   }
256 
257   /**
258    * GetElementModulus
259    * @return modulus on the underlying elemenbt
260    */
GetElementModulus()261   const BigInteger GetElementModulus() const {
262     return typeFlag == IsPoly
263                ? encodedVector.GetModulus()
264                : (typeFlag == IsNativePoly
265                       ? BigInteger(encodedNativeVector.GetModulus())
266                       : encodedVectorDCRT.GetModulus());
267   }
268 
269   /**
270    * Get method to return the length of plaintext
271    *
272    * @return the length of the plaintext in terms of the number of bits.
273    */
274   virtual size_t GetLength() const = 0;
275 
276   /**
277    * resize the plaintext; only works for plaintexts that support a resizable
278    * vector (coefpacked)
279    * @param newSize
280    */
SetLength(size_t newSize)281   virtual void SetLength(size_t newSize) {
282     PALISADE_THROW(not_implemented_error, "resize not supported");
283   }
284 
285   /*
286    * Method to get the depth of a plaintext.
287    *
288    * @return the depth of the plaintext
289    */
GetDepth()290   size_t GetDepth() const { return depth; }
291 
292   /*
293    * Method to set the depth of a plaintext.
294    */
SetDepth(size_t d)295   void SetDepth(size_t d) { depth = d; }
296 
297   /*
298    * Method to get the level of a plaintext.
299    *
300    * @return the level of the plaintext
301    */
GetLevel()302   size_t GetLevel() const { return level; }
303 
304   /*
305    * Method to set the level of a plaintext.
306    */
SetLevel(size_t l)307   void SetLevel(size_t l) { level = l; }
308 
GetLogError()309   virtual double GetLogError() const {
310     PALISADE_THROW(not_available_error,
311                    "no estimate of noise available for the current scheme");
312   }
313 
GetLogPrecision()314   virtual double GetLogPrecision() const {
315     PALISADE_THROW(not_available_error,
316                    "no estimate of precision available for the current scheme");
317   }
318 
GetStringValue()319   virtual const std::string& GetStringValue() const {
320     PALISADE_THROW(type_error, "not a string");
321   }
GetCoefPackedValue()322   virtual const vector<int64_t>& GetCoefPackedValue() const {
323     PALISADE_THROW(type_error, "not a packed coefficient vector");
324   }
GetPackedValue()325   virtual const vector<int64_t>& GetPackedValue() const {
326     PALISADE_THROW(type_error, "not a packed coefficient vector");
327   }
GetCKKSPackedValue()328   virtual const std::vector<std::complex<double>>& GetCKKSPackedValue() const {
329     PALISADE_THROW(type_error, "not a packed vector of complex numbers");
330   }
GetRealPackedValue()331   virtual const std::vector<double> GetRealPackedValue() const {
332     PALISADE_THROW(type_error, "not a packed vector of real numbers");
333   }
SetStringValue(const std::string &)334   virtual void SetStringValue(const std::string&) {
335     PALISADE_THROW(type_error, "does not support a string");
336   }
SetIntVectorValue(const vector<int64_t> &)337   virtual void SetIntVectorValue(const vector<int64_t>&) {
338     PALISADE_THROW(type_error, "does not support an int vector");
339   }
340 
341   /**
342    * Method to compare two plaintext to test for equivalence.
343    * This method is called by operator==
344    *
345    * @param other - the other plaintext to compare to.
346    * @return whether the two plaintext are equivalent.
347    */
348   virtual bool CompareTo(const PlaintextImpl& other) const = 0;
349 
350   /**
351    * operator== for plaintexts.  This method makes sure the plaintexts are of
352    * the same type.
353    *
354    * @param other - the other plaintext to compare to.
355    * @return whether the two plaintext are the same.
356    */
357   bool operator==(const PlaintextImpl& other) const { return CompareTo(other); }
358 
359   bool operator!=(const PlaintextImpl& other) const {
360     return !(*this == other);
361   }
362 
363   /**
364    * operator<< for ostream integration - calls PrintValue
365    * @param out
366    * @param item
367    * @return
368    */
369   friend std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item);
370 
371   /**
372    * PrintValue is called by operator<<
373    * @param out
374    */
375   virtual void PrintValue(std::ostream& out) const = 0;
376 };
377 
378 inline std::ostream& operator<<(std::ostream& out, const PlaintextImpl& item) {
379   item.PrintValue(out);
380   return out;
381 }
382 
383 inline std::ostream& operator<<(std::ostream& out, const Plaintext item) {
384   item->PrintValue(out);
385   return out;
386 }
387 
388 inline bool operator==(const Plaintext p1, const Plaintext p2) {
389   return *p1 == *p2;
390 }
391 
392 inline bool operator!=(const Plaintext p1, const Plaintext p2) {
393   return *p1 != *p2;
394 }
395 
396 /**
397  * GetElement
398  * @return the Polynomial that the element was encoded into
399  */
400 template <>
401 inline const Poly& PlaintextImpl::GetElement<Poly>() const {
402   return encodedVector;
403 }
404 
405 template <>
406 inline Poly& PlaintextImpl::GetElement<Poly>() {
407   return encodedVector;
408 }
409 
410 /**
411  * GetElement
412  * @return the NativePolynomial that the element was encoded into
413  */
414 template <>
415 inline const NativePoly& PlaintextImpl::GetElement<NativePoly>() const {
416   return encodedNativeVector;
417 }
418 
419 template <>
420 inline NativePoly& PlaintextImpl::GetElement<NativePoly>() {
421   return encodedNativeVector;
422 }
423 
424 /**
425  * GetElement
426  * @return the DCRTPolynomial that the element was encoded into
427  */
428 template <>
429 inline const DCRTPoly& PlaintextImpl::GetElement<DCRTPoly>() const {
430   return encodedVectorDCRT;
431 }
432 
433 template <>
434 inline DCRTPoly& PlaintextImpl::GetElement<DCRTPoly>() {
435   return encodedVectorDCRT;
436 }
437 
438 }  // namespace lbcrypto
439 
440 #endif
441