1 /*
2 * Copyright 2018 Google LLC.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * https://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef RLWE_INT256_H_
17 #define RLWE_INT256_H_
18
19 #include "absl/numeric/int128.h"
20 #include "integral_types.h"
21
22 namespace rlwe {
23
24 struct uint256_pod;
25
26 // An unsigned 256-bit integer type. Thread-compatible.
27 class uint256 {
28 public:
29 constexpr uint256();
30 constexpr uint256(absl::uint128 top, absl::uint128 bottom);
31
32 // Implicit type conversion is allowed so these behave like familiar int types
33 #ifndef SWIG
34 constexpr uint256(int bottom);
35 constexpr uint256(Uint32 bottom);
36 #endif
37 constexpr uint256(Uint8 bottom);
38 constexpr uint256(unsigned long bottom);
39 constexpr uint256(unsigned long long bottom);
40 constexpr uint256(absl::uint128 bottom);
41 constexpr uint256(const uint256_pod &val);
42
43 // Conversion operators to other arithmetic types
44 constexpr explicit operator bool() const;
45 constexpr explicit operator char() const;
46 constexpr explicit operator signed char() const;
47 constexpr explicit operator unsigned char() const;
48 constexpr explicit operator char16_t() const;
49 constexpr explicit operator char32_t() const;
50 constexpr explicit operator short() const;
51
52 constexpr explicit operator unsigned short() const;
53 constexpr explicit operator int() const;
54 constexpr explicit operator unsigned int() const;
55 constexpr explicit operator long() const;
56
57 constexpr explicit operator unsigned long() const;
58
59 constexpr explicit operator long long() const;
60
61 constexpr explicit operator unsigned long long() const;
62 constexpr explicit operator absl::int128() const;
63 constexpr explicit operator absl::uint128() const;
64 explicit operator float() const;
65 explicit operator double() const;
66 explicit operator long double() const;
67
68 // Trivial copy constructor, assignment operator and destructor.
69
70 void Initialize(absl::uint128 top, absl::uint128 bottom);
71
72 // Arithmetic operators.
73 uint256& operator+=(const uint256& b);
74 uint256& operator-=(const uint256& b);
75 uint256& operator*=(const uint256& b);
76 // Long division/modulo for uint256.
77 uint256& operator/=(const uint256& b);
78 uint256& operator%=(const uint256& b);
79 uint256 operator++(int);
80 uint256 operator--(int);
81 uint256& operator<<=(int);
82 uint256& operator>>=(int);
83 uint256& operator&=(const uint256& b);
84 uint256& operator|=(const uint256& b);
85 uint256& operator^=(const uint256& b);
86 uint256& operator++();
87 uint256& operator--();
88
89 friend absl::uint128 Uint256Low128(const uint256& v);
90 friend absl::uint128 Uint256High128(const uint256& v);
91
92 // We add "std::" to avoid including all of port.h.
93 friend std::ostream& operator<<(std::ostream& o, const uint256& b);
94
95 private:
96 static void DivModImpl(uint256 dividend, uint256 divisor,
97 uint256* quotient_ret, uint256* remainder_ret);
98
99 // Little-endian memory order optimizations can benefit from
100 // having lo_ first, hi_ last.
101 // See util/endian/endian.h and Load256/Store256 for storing a uint256.
102 // Adding any new members will cause sizeof(uint256) tests to fail.
103 absl::uint128 lo_;
104 absl::uint128 hi_;
105
106 // Uint256Max()
107 //
108 // Returns the highest value for a 256-bit unsigned integer.
109 friend constexpr uint256 Uint256Max();
110
111 // Not implemented, just declared for catching automatic type conversions.
112 uint256(Uint16);
113 uint256(float v);
114 uint256(double v);
115 };
116
Uint256Max()117 constexpr uint256 Uint256Max() {
118 return uint256((std::numeric_limits<absl::uint128>::max)(),
119 (std::numeric_limits<absl::uint128>::max)());
120 }
121
// This is a POD form of uint256 which can be used for static variables which
// need to be operated on as uint256. A uint256 can be built from it
// implicitly through the (non-explicit) uint256(const uint256_pod&)
// constructor.
struct uint256_pod {
  // Note: The ordering of fields is different from 'class uint256' but the
  // same as its 2-arg constructor. This enables more obvious initialization
  // of static instances, which is the primary reason for this struct in the
  // first place. This does not seem to defeat any optimizations wrt
  // operations involving this struct.
  absl::uint128 hi;  // Most-significant 128 bits (becomes uint256::hi_).
  absl::uint128 lo;  // Least-significant 128 bits (becomes uint256::lo_).
};

// POD form of the largest uint256 value (both halves at absl::Uint128Max()),
// usable in constant initializers.
constexpr uint256_pod kuint256max = {absl::Uint128Max(), absl::Uint128Max()};
135
136 // allow uint256 to be logged
137 extern std::ostream& operator<<(std::ostream& o, const uint256& b);
138
139 // Methods to access low and high pieces of 256-bit value.
140 // Defined externally from uint256 to facilitate conversion
141 // to native 256-bit types when compilers support them.
Uint256Low128(const uint256 & v)142 inline absl::uint128 Uint256Low128(const uint256& v) { return v.lo_; }
Uint256High128(const uint256 & v)143 inline absl::uint128 Uint256High128(const uint256& v) { return v.hi_; }
144
145 // --------------------------------------------------------------------------
146 // Implementation details follow
147 // --------------------------------------------------------------------------
148 inline bool operator==(const uint256& lhs, const uint256& rhs) {
149 return (Uint256Low128(lhs) == Uint256Low128(rhs) &&
150 Uint256High128(lhs) == Uint256High128(rhs));
151 }
152 inline bool operator!=(const uint256& lhs, const uint256& rhs) {
153 return !(lhs == rhs);
154 }
155
uint256()156 inline constexpr uint256::uint256() : lo_(0), hi_(0) {}
uint256(absl::uint128 top,absl::uint128 bottom)157 inline constexpr uint256::uint256(absl::uint128 top, absl::uint128 bottom)
158 : lo_(bottom), hi_(top) {}
uint256(const uint256_pod & v)159 inline constexpr uint256::uint256(const uint256_pod& v)
160 : lo_(v.lo), hi_(v.hi) {}
uint256(absl::uint128 bottom)161 inline constexpr uint256::uint256(absl::uint128 bottom) : lo_(bottom), hi_(0) {}
162 #ifndef SWIG
uint256(int bottom)163 inline constexpr uint256::uint256(int bottom)
164 : lo_(bottom), hi_((bottom < 0) ? -1 : 0) {}
uint256(Uint32 bottom)165 inline constexpr uint256::uint256(Uint32 bottom) : lo_(bottom), hi_(0) {}
166 #endif
uint256(Uint8 bottom)167 inline constexpr uint256::uint256(Uint8 bottom) : lo_(bottom), hi_(0) {}
168
uint256(unsigned long bottom)169 inline constexpr uint256::uint256(unsigned long bottom)
170 : lo_(bottom), hi_(0) {}
171
uint256(unsigned long long bottom)172 inline constexpr uint256::uint256(unsigned long long bottom)
173 : lo_(bottom), hi_(0) {}
174
Initialize(absl::uint128 top,absl::uint128 bottom)175 inline void uint256::Initialize(absl::uint128 top, absl::uint128 bottom) {
176 hi_ = top;
177 lo_ = bottom;
178 }
179
180 // Conversion operators to integer types.
181
182 constexpr uint256::operator bool() const { return lo_ || hi_; }
183
184 constexpr uint256::operator char() const { return static_cast<char>(lo_); }
185
186 constexpr uint256::operator signed char() const {
187 return static_cast<signed char>(lo_);
188 }
189
190 constexpr uint256::operator unsigned char() const {
191 return static_cast<unsigned char>(lo_);
192 }
193
char16_t()194 constexpr uint256::operator char16_t() const {
195 return static_cast<char16_t>(lo_);
196 }
197
char32_t()198 constexpr uint256::operator char32_t() const {
199 return static_cast<char32_t>(lo_);
200 }
201
202
203 constexpr uint256::operator short() const { return static_cast<short>(lo_); }
204
205 constexpr uint256::operator unsigned short() const {
206 return static_cast<unsigned short>(lo_);
207 }
208
209 constexpr uint256::operator int() const { return static_cast<int>(lo_); }
210
211 constexpr uint256::operator unsigned int() const {
212 return static_cast<unsigned int>(lo_);
213 }
214
215
216 constexpr uint256::operator long() const { return static_cast<long>(lo_); }
217
218 constexpr uint256::operator unsigned long() const {
219 return static_cast<unsigned long>(lo_);
220 }
221
222 constexpr uint256::operator long long() const {
223 return static_cast<long long>(lo_);
224 }
225
226 constexpr uint256::operator unsigned long long() const {
227 return static_cast<unsigned long long>(lo_);
228 }
229
230
uint128()231 constexpr uint256::operator absl::uint128() const { return lo_; }
int128()232 constexpr uint256::operator absl::int128() const {
233 return static_cast<absl::int128>(lo_);
234 }
235
236 // Conversion operators to floating point types.
237
238 inline uint256::operator float() const {
239 return static_cast<float>(lo_) + std::ldexp(static_cast<float>(hi_), 128);
240 }
241
242 inline uint256::operator double() const {
243 return static_cast<double>(lo_) + std::ldexp(static_cast<double>(hi_), 128);
244 }
245
246 inline uint256::operator long double() const {
247 return static_cast<long double>(lo_) +
248 std::ldexp(static_cast<long double>(hi_), 128);
249 }
250
251 // Comparison operators.
252
253 #define CMP256(op) \
254 inline bool operator op(const uint256& lhs, const uint256& rhs) { \
255 return (Uint256High128(lhs) == Uint256High128(rhs)) \
256 ? (Uint256Low128(lhs) op Uint256Low128(rhs)) \
257 : (Uint256High128(lhs) op Uint256High128(rhs)); \
258 }
259
260 CMP256(<)
261 CMP256(>)
262 CMP256(>=)
263 CMP256(<=)
264
265 #undef CMP256
266
267 // Unary operators
268
269 inline uint256 operator-(const uint256& val) {
270 const absl::uint128 hi_flip = ~Uint256High128(val);
271 const absl::uint128 lo_flip = ~Uint256Low128(val);
272 const absl::uint128 lo_add = lo_flip + 1;
273 if (lo_add < lo_flip) {
274 return uint256(hi_flip + 1, lo_add);
275 }
276 return uint256(hi_flip, lo_add);
277 }
278
279 inline bool operator!(const uint256& val) {
280 return !Uint256High128(val) && !Uint256Low128(val);
281 }
282
283 // Logical operators.
284
285 inline uint256 operator~(const uint256& val) {
286 return uint256(~Uint256High128(val), ~Uint256Low128(val));
287 }
288
289 #define LOGIC256(op) \
290 inline uint256 operator op(const uint256& lhs, const uint256& rhs) { \
291 return uint256(Uint256High128(lhs) op Uint256High128(rhs), \
292 Uint256Low128(lhs) op Uint256Low128(rhs)); \
293 }
294
295 LOGIC256(|)
296 LOGIC256(&)
297 LOGIC256(^)
298
299 #undef LOGIC256
300
301 #define LOGICASSIGN256(op) \
302 inline uint256& uint256::operator op(const uint256& b) { \
303 hi_ op b.hi_; \
304 lo_ op b.lo_; \
305 return *this; \
306 }
307
308 LOGICASSIGN256(|=)
309 LOGICASSIGN256(&=)
310 LOGICASSIGN256(^=)
311
312 #undef LOGICASSIGN256
313
314 // Shift operators.
315
316 inline uint256 operator<<(const uint256& val, int amount) {
317 uint256 out(val);
318 out <<= amount;
319 return out;
320 }
321
322 inline uint256 operator>>(const uint256& val, int amount) {
323 uint256 out(val);
324 out >>= amount;
325 return out;
326 }
327
328 inline uint256& uint256::operator<<=(int amount) {
329 // uint128 shifts of >= 128 are undefined, so we will need some special-casing
330 if (amount < 128) {
331 if (amount != 0) {
332 hi_ = (hi_ << amount) | (lo_ >> (128 - amount));
333 lo_ = lo_ << amount;
334 }
335 } else if (amount < 256) {
336 hi_ = lo_ << (amount - 128);
337 lo_ = 0;
338 } else {
339 hi_ = 0;
340 lo_ = 0;
341 }
342 return *this;
343 }
344
345 inline uint256& uint256::operator>>=(int amount) {
346 // uint128 shifts of >= 128 are undefined, so we will need some special-casing
347 if (amount < 128) {
348 if (amount != 0) {
349 lo_ = (lo_ >> amount) | (hi_ << (128 - amount));
350 hi_ = hi_ >> amount;
351 }
352 } else if (amount < 256) {
353 lo_ = hi_ >> (amount - 128);
354 hi_ = 0;
355 } else {
356 lo_ = 0;
357 hi_ = 0;
358 }
359 return *this;
360 }
361
362 inline uint256 operator+(const uint256& lhs, const uint256& rhs) {
363 return uint256(lhs) += rhs;
364 }
365
366 inline uint256 operator-(const uint256& lhs, const uint256& rhs) {
367 return uint256(lhs) -= rhs;
368 }
369
370 inline uint256 operator*(const uint256& lhs, const uint256& rhs) {
371 return uint256(lhs) *= rhs;
372 }
373
374 inline uint256 operator/(const uint256& lhs, const uint256& rhs) {
375 return uint256(lhs) /= rhs;
376 }
377
378 inline uint256 operator%(const uint256& lhs, const uint256& rhs) {
379 return uint256(lhs) %= rhs;
380 }
381
382 inline uint256& uint256::operator+=(const uint256& b) {
383 hi_ += b.hi_;
384 absl::uint128 lolo = lo_ + b.lo_;
385 if (lolo < lo_)
386 ++hi_;
387 lo_ = lolo;
388 return *this;
389 }
390
391 inline uint256& uint256::operator-=(const uint256& b) {
392 hi_ -= b.hi_;
393 if (b.lo_ > lo_)
394 --hi_;
395 lo_ -= b.lo_;
396 return *this;
397 }
398
399 inline uint256& uint256::operator*=(const uint256& b) {
400 // Computes the product c = a * b modulo 2^256.
401 //
402 // We have that
403 // a = [a.hi_ || a.lo_] and b = [b.hi_ || b.lo_]
404 // where hi_, lo_ are 128-bit numbers. Further, we have that
405 // a.lo_ = [a64 || a00] and b.lo_ = [b64 || b00]
406 // where a64, a00, b64, b00 are 64-bit numbers.
407 //
408 // The product c = (a * b mod 2^256) is equal to
409 // (a.hi_ * b.lo_ + a64 * b64 + b.hi_ * a.lo_ mod 2^128) * 2^128 +
410 // (a64 * b00 + a00 * b64) * 2^64 +
411 // (a00 * b00)
412 //
413 // The first and last lines can be computed without worrying about the
414 // carries, and then we add the two elements from the second line.
415 absl::uint128 a64 = absl::Uint128High64(lo_);
416 absl::uint128 a00 = absl::Uint128Low64(lo_);
417 absl::uint128 b64 = absl::Uint128High64(b.lo_);
418 absl::uint128 b00 = absl::Uint128Low64(b.lo_);
419
420 // Compute the high order and low order part of c (safe to ignore carry bits).
421 this->hi_ = hi_ * b.lo_ + a64 * b64 + lo_ * b.hi_;
422 this->lo_ = a00 * b00;
423
424 // add middle term and capture carry
425 uint256 middle_term = uint256(a64 * b00) + uint256(a00 * b64);
426 *this += middle_term << 64;
427 return *this;
428 }
429
430 inline uint256 uint256::operator++(int) {
431 uint256 tmp(*this);
432 lo_++;
433 if (lo_ == 0) hi_++; // If there was a wrap around, increase the high word.
434 return tmp;
435 }
436
437 inline uint256 uint256::operator--(int) {
438 uint256 tmp(*this);
439 if (lo_ == 0) hi_--; // If it wraps around, decrease the high word.
440 lo_--;
441 return tmp;
442 }
443
444 inline uint256& uint256::operator++() {
445 lo_++;
446 if (lo_ == 0) hi_++; // If there was a wrap around, increase the high word.
447 return *this;
448 }
449
450 inline uint256& uint256::operator--() {
451 if (lo_ == 0) hi_--; // If it wraps around, decrease the high word.
452 lo_--;
453 return *this;
454 }
455
456 } // namespace rlwe
457
458 // Specialized numeric_limits for uint256.
459 namespace std {
460 template <>
461 class numeric_limits<rlwe::uint256> {
462 public:
463 static constexpr bool is_specialized = true;
464 static constexpr bool is_signed = false;
465 static constexpr bool is_integer = true;
466 static constexpr bool is_exact = true;
467 static constexpr bool has_infinity = false;
468 static constexpr bool has_quiet_NaN = false;
469 static constexpr bool has_signaling_NaN = false;
470 static constexpr float_denorm_style has_denorm = denorm_absent;
471 static constexpr bool has_denorm_loss = false;
472 static constexpr float_round_style round_style = round_toward_zero;
473 static constexpr bool is_iec559 = false;
474 static constexpr bool is_bounded = true;
475 static constexpr bool is_modulo = true;
476 static constexpr int digits = 256;
477 static constexpr int digits10 = 77;
478 static constexpr int max_digits10 = 0;
479 static constexpr int radix = 2;
480 static constexpr int min_exponent = 0;
481 static constexpr int min_exponent10 = 0;
482 static constexpr int max_exponent = 0;
483 static constexpr int max_exponent10 = 0;
484 static constexpr bool traps = numeric_limits<absl::uint128>::traps;
485 static constexpr bool tinyness_before = false;
486
uint256(min)487 static constexpr rlwe::uint256(min)() { return 0; }
lowest()488 static constexpr rlwe::uint256 lowest() { return 0; }
uint256(max)489 static constexpr rlwe::uint256(max)() { return rlwe::Uint256Max(); }
epsilon()490 static constexpr rlwe::uint256 epsilon() { return 0; }
round_error()491 static constexpr rlwe::uint256 round_error() { return 0; }
infinity()492 static constexpr rlwe::uint256 infinity() { return 0; }
quiet_NaN()493 static constexpr rlwe::uint256 quiet_NaN() { return 0; }
signaling_NaN()494 static constexpr rlwe::uint256 signaling_NaN() { return 0; }
denorm_min()495 static constexpr rlwe::uint256 denorm_min() { return 0; }
496 };
497 } // namespace std
498 #endif // RLWE_INT256_H_
499