1 /*
2  * Copyright 2018 Google LLC.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     https://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef RLWE_INT256_H_
17 #define RLWE_INT256_H_
18 
#include <cmath>
#include <iosfwd>
#include <limits>

#include "absl/numeric/int128.h"
#include "integral_types.h"
21 
22 namespace rlwe {
23 
24 struct uint256_pod;
25 
26 // An unsigned 256-bit integer type. Thread-compatible.
27 class uint256 {
28  public:
29   constexpr uint256();
30   constexpr uint256(absl::uint128 top, absl::uint128 bottom);
31 
32   // Implicit type conversion is allowed so these behave like familiar int types
33 #ifndef SWIG
34   constexpr uint256(int bottom);
35   constexpr uint256(Uint32 bottom);
36 #endif
37   constexpr uint256(Uint8 bottom);
38   constexpr uint256(unsigned long bottom);
39   constexpr uint256(unsigned long long bottom);
40   constexpr uint256(absl::uint128 bottom);
41   constexpr uint256(const uint256_pod &val);
42 
43   // Conversion operators to other arithmetic types
44   constexpr explicit operator bool() const;
45   constexpr explicit operator char() const;
46   constexpr explicit operator signed char() const;
47   constexpr explicit operator unsigned char() const;
48   constexpr explicit operator char16_t() const;
49   constexpr explicit operator char32_t() const;
50   constexpr explicit operator short() const;
51 
52   constexpr explicit operator unsigned short() const;
53   constexpr explicit operator int() const;
54   constexpr explicit operator unsigned int() const;
55   constexpr explicit operator long() const;
56 
57   constexpr explicit operator unsigned long() const;
58 
59   constexpr explicit operator long long() const;
60 
61   constexpr explicit operator unsigned long long() const;
62   constexpr explicit operator absl::int128() const;
63   constexpr explicit operator absl::uint128() const;
64   explicit operator float() const;
65   explicit operator double() const;
66   explicit operator long double() const;
67 
68   // Trivial copy constructor, assignment operator and destructor.
69 
70   void Initialize(absl::uint128 top, absl::uint128 bottom);
71 
72   // Arithmetic operators.
73   uint256& operator+=(const uint256& b);
74   uint256& operator-=(const uint256& b);
75   uint256& operator*=(const uint256& b);
76   // Long division/modulo for uint256.
77   uint256& operator/=(const uint256& b);
78   uint256& operator%=(const uint256& b);
79   uint256 operator++(int);
80   uint256 operator--(int);
81   uint256& operator<<=(int);
82   uint256& operator>>=(int);
83   uint256& operator&=(const uint256& b);
84   uint256& operator|=(const uint256& b);
85   uint256& operator^=(const uint256& b);
86   uint256& operator++();
87   uint256& operator--();
88 
89   friend absl::uint128 Uint256Low128(const uint256& v);
90   friend absl::uint128 Uint256High128(const uint256& v);
91 
92   // We add "std::" to avoid including all of port.h.
93   friend std::ostream& operator<<(std::ostream& o, const uint256& b);
94 
95  private:
96   static void DivModImpl(uint256 dividend, uint256 divisor,
97                          uint256* quotient_ret, uint256* remainder_ret);
98 
99   // Little-endian memory order optimizations can benefit from
100   // having lo_ first, hi_ last.
101   // See util/endian/endian.h and Load256/Store256 for storing a uint256.
102   // Adding any new members will cause sizeof(uint256) tests to fail.
103   absl::uint128 lo_;
104   absl::uint128 hi_;
105 
106   // Uint256Max()
107   //
108   // Returns the highest value for a 256-bit unsigned integer.
109   friend constexpr uint256 Uint256Max();
110 
111   // Not implemented, just declared for catching automatic type conversions.
112   uint256(Uint16);
113   uint256(float v);
114   uint256(double v);
115 };
116 
Uint256Max()117 constexpr uint256 Uint256Max() {
118   return uint256((std::numeric_limits<absl::uint128>::max)(),
119                  (std::numeric_limits<absl::uint128>::max)());
120 }
121 
122 // This is a POD form of uint256 which can be used for static variables which
123 // need to be operated on as uint256.
124 struct uint256_pod {
125   // Note: The ordering of fields is different than 'class uint256' but the
126   // same as its 2-arg constructor.  This enables more obvious initialization
127   // of static instances, which is the primary reason for this struct in the
128   // first place.  This does not seem to defeat any optimizations wrt
129   // operations involving this struct.
130   absl::uint128 hi;
131   absl::uint128 lo;
132 };
133 
134 constexpr uint256_pod kuint256max = {absl::Uint128Max(), absl::Uint128Max()};
135 
136 // allow uint256 to be logged
137 extern std::ostream& operator<<(std::ostream& o, const uint256& b);
138 
139 // Methods to access low and high pieces of 256-bit value.
140 // Defined externally from uint256 to facilitate conversion
141 // to native 256-bit types when compilers support them.
Uint256Low128(const uint256 & v)142 inline absl::uint128 Uint256Low128(const uint256& v) { return v.lo_; }
Uint256High128(const uint256 & v)143 inline absl::uint128 Uint256High128(const uint256& v) { return v.hi_; }
144 
145 // --------------------------------------------------------------------------
146 //                      Implementation details follow
147 // --------------------------------------------------------------------------
148 inline bool operator==(const uint256& lhs, const uint256& rhs) {
149   return (Uint256Low128(lhs) == Uint256Low128(rhs) &&
150           Uint256High128(lhs) == Uint256High128(rhs));
151 }
152 inline bool operator!=(const uint256& lhs, const uint256& rhs) {
153   return !(lhs == rhs);
154 }
155 
uint256()156 inline constexpr uint256::uint256() : lo_(0), hi_(0) {}
uint256(absl::uint128 top,absl::uint128 bottom)157 inline constexpr uint256::uint256(absl::uint128 top, absl::uint128 bottom)
158     : lo_(bottom), hi_(top) {}
uint256(const uint256_pod & v)159 inline constexpr uint256::uint256(const uint256_pod& v)
160     : lo_(v.lo), hi_(v.hi) {}
uint256(absl::uint128 bottom)161 inline constexpr uint256::uint256(absl::uint128 bottom) : lo_(bottom), hi_(0) {}
162 #ifndef SWIG
uint256(int bottom)163 inline constexpr uint256::uint256(int bottom)
164       : lo_(bottom), hi_((bottom < 0) ? -1 : 0) {}
uint256(Uint32 bottom)165 inline constexpr uint256::uint256(Uint32 bottom) : lo_(bottom), hi_(0) {}
166 #endif
uint256(Uint8 bottom)167 inline constexpr uint256::uint256(Uint8 bottom) : lo_(bottom), hi_(0) {}
168 
uint256(unsigned long bottom)169 inline constexpr uint256::uint256(unsigned long bottom)
170     : lo_(bottom), hi_(0) {}
171 
uint256(unsigned long long bottom)172 inline constexpr uint256::uint256(unsigned long long bottom)
173     : lo_(bottom), hi_(0) {}
174 
Initialize(absl::uint128 top,absl::uint128 bottom)175 inline void uint256::Initialize(absl::uint128 top, absl::uint128 bottom) {
176   hi_ = top;
177   lo_ = bottom;
178 }
179 
180 // Conversion operators to integer types.
181 
182 constexpr uint256::operator bool() const { return lo_ || hi_; }
183 
184 constexpr uint256::operator char() const { return static_cast<char>(lo_); }
185 
186 constexpr uint256::operator signed char() const {
187   return static_cast<signed char>(lo_);
188 }
189 
190 constexpr uint256::operator unsigned char() const {
191   return static_cast<unsigned char>(lo_);
192 }
193 
char16_t()194 constexpr uint256::operator char16_t() const {
195   return static_cast<char16_t>(lo_);
196 }
197 
char32_t()198 constexpr uint256::operator char32_t() const {
199   return static_cast<char32_t>(lo_);
200 }
201 
202 
203 constexpr uint256::operator short() const { return static_cast<short>(lo_); }
204 
205 constexpr uint256::operator unsigned short() const {
206   return static_cast<unsigned short>(lo_);
207 }
208 
209 constexpr uint256::operator int() const { return static_cast<int>(lo_); }
210 
211 constexpr uint256::operator unsigned int() const {
212   return static_cast<unsigned int>(lo_);
213 }
214 
215 
216 constexpr uint256::operator long() const { return static_cast<long>(lo_); }
217 
218 constexpr uint256::operator unsigned long() const {
219   return static_cast<unsigned long>(lo_);
220 }
221 
222 constexpr uint256::operator long long() const {
223   return static_cast<long long>(lo_);
224 }
225 
226 constexpr uint256::operator unsigned long long() const {
227   return static_cast<unsigned long long>(lo_);
228 }
229 
230 
uint128()231 constexpr uint256::operator absl::uint128() const { return lo_; }
int128()232 constexpr uint256::operator absl::int128() const {
233   return static_cast<absl::int128>(lo_);
234 }
235 
236 // Conversion operators to floating point types.
237 
238 inline uint256::operator float() const {
239   return static_cast<float>(lo_) + std::ldexp(static_cast<float>(hi_), 128);
240 }
241 
242 inline uint256::operator double() const {
243   return static_cast<double>(lo_) + std::ldexp(static_cast<double>(hi_), 128);
244 }
245 
246 inline uint256::operator long double() const {
247   return static_cast<long double>(lo_) +
248          std::ldexp(static_cast<long double>(hi_), 128);
249 }
250 
251 // Comparison operators.
252 
253 #define CMP256(op)                                                  \
254   inline bool operator op(const uint256& lhs, const uint256& rhs) { \
255     return (Uint256High128(lhs) == Uint256High128(rhs))             \
256                ? (Uint256Low128(lhs) op Uint256Low128(rhs))         \
257                : (Uint256High128(lhs) op Uint256High128(rhs));      \
258   }
259 
260 CMP256(<)
261 CMP256(>)
262 CMP256(>=)
263 CMP256(<=)
264 
265 #undef CMP256
266 
267 // Unary operators
268 
269 inline uint256 operator-(const uint256& val) {
270   const absl::uint128 hi_flip = ~Uint256High128(val);
271   const absl::uint128 lo_flip = ~Uint256Low128(val);
272   const absl::uint128 lo_add = lo_flip + 1;
273   if (lo_add < lo_flip) {
274     return uint256(hi_flip + 1, lo_add);
275   }
276   return uint256(hi_flip, lo_add);
277 }
278 
279 inline bool operator!(const uint256& val) {
280   return !Uint256High128(val) && !Uint256Low128(val);
281 }
282 
283 // Logical operators.
284 
285 inline uint256 operator~(const uint256& val) {
286   return uint256(~Uint256High128(val), ~Uint256Low128(val));
287 }
288 
289 #define LOGIC256(op)                                                   \
290   inline uint256 operator op(const uint256& lhs, const uint256& rhs) { \
291     return uint256(Uint256High128(lhs) op Uint256High128(rhs),         \
292                    Uint256Low128(lhs) op Uint256Low128(rhs));          \
293   }
294 
295 LOGIC256(|)
296 LOGIC256(&)
297 LOGIC256(^)
298 
299 #undef LOGIC256
300 
301 #define LOGICASSIGN256(op)                                 \
302   inline uint256& uint256::operator op(const uint256& b) { \
303     hi_ op b.hi_;                                          \
304     lo_ op b.lo_;                                          \
305     return *this;                                          \
306   }
307 
308 LOGICASSIGN256(|=)
309 LOGICASSIGN256(&=)
310 LOGICASSIGN256(^=)
311 
312 #undef LOGICASSIGN256
313 
314 // Shift operators.
315 
316 inline uint256 operator<<(const uint256& val, int amount) {
317   uint256 out(val);
318   out <<= amount;
319   return out;
320 }
321 
322 inline uint256 operator>>(const uint256& val, int amount) {
323   uint256 out(val);
324   out >>= amount;
325   return out;
326 }
327 
328 inline uint256& uint256::operator<<=(int amount) {
329   // uint128 shifts of >= 128 are undefined, so we will need some special-casing
330   if (amount < 128) {
331     if (amount != 0) {
332       hi_ = (hi_ << amount) | (lo_ >> (128 - amount));
333       lo_ = lo_ << amount;
334     }
335   } else if (amount < 256) {
336     hi_ = lo_ << (amount - 128);
337     lo_ = 0;
338   } else {
339     hi_ = 0;
340     lo_ = 0;
341   }
342   return *this;
343 }
344 
345 inline uint256& uint256::operator>>=(int amount) {
346   // uint128 shifts of >= 128 are undefined, so we will need some special-casing
347   if (amount < 128) {
348     if (amount != 0) {
349       lo_ = (lo_ >> amount) | (hi_ << (128 - amount));
350       hi_ = hi_ >> amount;
351     }
352   } else if (amount < 256) {
353     lo_ = hi_ >> (amount - 128);
354     hi_ = 0;
355   } else {
356     lo_ = 0;
357     hi_ = 0;
358   }
359   return *this;
360 }
361 
362 inline uint256 operator+(const uint256& lhs, const uint256& rhs) {
363   return uint256(lhs) += rhs;
364 }
365 
366 inline uint256 operator-(const uint256& lhs, const uint256& rhs) {
367   return uint256(lhs) -= rhs;
368 }
369 
370 inline uint256 operator*(const uint256& lhs, const uint256& rhs) {
371   return uint256(lhs) *= rhs;
372 }
373 
374 inline uint256 operator/(const uint256& lhs, const uint256& rhs) {
375   return uint256(lhs) /= rhs;
376 }
377 
378 inline uint256 operator%(const uint256& lhs, const uint256& rhs) {
379   return uint256(lhs) %= rhs;
380 }
381 
382 inline uint256& uint256::operator+=(const uint256& b) {
383   hi_ += b.hi_;
384   absl::uint128 lolo = lo_ + b.lo_;
385   if (lolo < lo_)
386     ++hi_;
387   lo_ = lolo;
388   return *this;
389 }
390 
391 inline uint256& uint256::operator-=(const uint256& b) {
392   hi_ -= b.hi_;
393   if (b.lo_ > lo_)
394     --hi_;
395   lo_ -= b.lo_;
396   return *this;
397 }
398 
399 inline uint256& uint256::operator*=(const uint256& b) {
400   // Computes the product c = a * b modulo 2^256.
401   //
402   // We have that
403   //   a = [a.hi_ || a.lo_] and b = [b.hi_ || b.lo_]
404   // where hi_, lo_ are 128-bit numbers. Further, we have that
405   //   a.lo_ = [a64 || a00] and b.lo_ = [b64 || b00]
406   // where a64, a00, b64, b00 are 64-bit numbers.
407   //
408   // The product c = (a * b mod 2^256) is equal to
409   //   (a.hi_ * b.lo_ + a64 * b64 + b.hi_ * a.lo_ mod 2^128) * 2^128 +
410   //   (a64 * b00 + a00 * b64) * 2^64 +
411   //   (a00 * b00)
412   //
413   // The first and last lines can be computed without worrying about the
414   // carries, and then we add the two elements from the second line.
415   absl::uint128 a64 = absl::Uint128High64(lo_);
416   absl::uint128 a00 = absl::Uint128Low64(lo_);
417   absl::uint128 b64 = absl::Uint128High64(b.lo_);
418   absl::uint128 b00 = absl::Uint128Low64(b.lo_);
419 
420   // Compute the high order and low order part of c (safe to ignore carry bits).
421   this->hi_ = hi_ * b.lo_ + a64 * b64 + lo_ * b.hi_;
422   this->lo_ = a00 * b00;
423 
424   // add middle term and capture carry
425   uint256 middle_term = uint256(a64 * b00) + uint256(a00 * b64);
426   *this += middle_term << 64;
427   return *this;
428 }
429 
430 inline uint256 uint256::operator++(int) {
431   uint256 tmp(*this);
432   lo_++;
433   if (lo_ == 0) hi_++;  // If there was a wrap around, increase the high word.
434   return tmp;
435 }
436 
437 inline uint256 uint256::operator--(int) {
438   uint256 tmp(*this);
439   if (lo_ == 0) hi_--;  // If it wraps around, decrease the high word.
440   lo_--;
441   return tmp;
442 }
443 
444 inline uint256& uint256::operator++() {
445   lo_++;
446   if (lo_ == 0) hi_++;  // If there was a wrap around, increase the high word.
447   return *this;
448 }
449 
450 inline uint256& uint256::operator--() {
451   if (lo_ == 0) hi_--;  // If it wraps around, decrease the high word.
452   lo_--;
453   return *this;
454 }
455 
456 }  // namespace rlwe
457 
458 // Specialized numeric_limits for uint256.
459 namespace std {
460 template <>
461 class numeric_limits<rlwe::uint256> {
462  public:
463   static constexpr bool is_specialized = true;
464   static constexpr bool is_signed = false;
465   static constexpr bool is_integer = true;
466   static constexpr bool is_exact = true;
467   static constexpr bool has_infinity = false;
468   static constexpr bool has_quiet_NaN = false;
469   static constexpr bool has_signaling_NaN = false;
470   static constexpr float_denorm_style has_denorm = denorm_absent;
471   static constexpr bool has_denorm_loss = false;
472   static constexpr float_round_style round_style = round_toward_zero;
473   static constexpr bool is_iec559 = false;
474   static constexpr bool is_bounded = true;
475   static constexpr bool is_modulo = true;
476   static constexpr int digits = 256;
477   static constexpr int digits10 = 77;
478   static constexpr int max_digits10 = 0;
479   static constexpr int radix = 2;
480   static constexpr int min_exponent = 0;
481   static constexpr int min_exponent10 = 0;
482   static constexpr int max_exponent = 0;
483   static constexpr int max_exponent10 = 0;
484   static constexpr bool traps = numeric_limits<absl::uint128>::traps;
485   static constexpr bool tinyness_before = false;
486 
uint256(min)487   static constexpr rlwe::uint256(min)() { return 0; }
lowest()488   static constexpr rlwe::uint256 lowest() { return 0; }
uint256(max)489   static constexpr rlwe::uint256(max)() { return rlwe::Uint256Max(); }
epsilon()490   static constexpr rlwe::uint256 epsilon() { return 0; }
round_error()491   static constexpr rlwe::uint256 round_error() { return 0; }
infinity()492   static constexpr rlwe::uint256 infinity() { return 0; }
quiet_NaN()493   static constexpr rlwe::uint256 quiet_NaN() { return 0; }
signaling_NaN()494   static constexpr rlwe::uint256 signaling_NaN() { return 0; }
denorm_min()495   static constexpr rlwe::uint256 denorm_min() { return 0; }
496 };
497 }  // namespace std
498 #endif  // RLWE_INT256_H_
499