// Copyright (C) 2020-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
3 
4 #pragma once
5 
6 #include <stdint.h>
7 
8 #include <limits>
9 #include <vector>
10 
11 #include "hexl/util/check.hpp"
12 #include "hexl/util/compiler.hpp"
13 
14 namespace intel {
15 namespace hexl {
16 
17 /// @brief Pre-computes a Barrett factor with which modular multiplication can
18 /// be performed more efficiently
19 class MultiplyFactor {
20  public:
21   MultiplyFactor() = default;
22 
23   /// @brief Computes and stores the Barrett factor floor((operand << bit_shift)
24   /// / modulus). This is useful when modular multiplication of the form
25   /// (x * operand) mod modulus is performed with same modulus and operand
26   /// several times. Note, passing operand=1 can be used to pre-compute a
27   /// Barrett factor for multiplications of the form (x * y) mod modulus, where
28   /// only the modulus is re-used across calls to modular multiplication.
MultiplyFactor(uint64_t operand,uint64_t bit_shift,uint64_t modulus)29   MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
30       : m_operand(operand) {
31     HEXL_CHECK(operand <= modulus, "operand " << operand
32                                               << " must be less than modulus "
33                                               << modulus);
34     HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64,
35                "Unsupported BitShift " << bit_shift);
36     uint64_t op_hi = operand >> (64 - bit_shift);
37     uint64_t op_lo = (bit_shift == 64) ? 0 : (operand << bit_shift);
38 
39     m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus);
40   }
41 
42   /// @brief Returns the pre-computed Barrett factor
BarrettFactor() const43   inline uint64_t BarrettFactor() const { return m_barrett_factor; }
44 
45   /// @brief Returns the operand corresponding to the Barrett factor
Operand() const46   inline uint64_t Operand() const { return m_operand; }
47 
48  private:
49   uint64_t m_operand;
50   uint64_t m_barrett_factor;
51 };
52 
/// @brief Returns whether or not num is a power of two
/// @details Zero is reported as not being a power of two.
inline bool IsPowerOfTwo(uint64_t num) {
  // num & (num - 1) clears the lowest set bit; the result is zero exactly when
  // num has at most one bit set.
  return (num != 0) && ((num & (num - 1)) == 0);
}
55 
56 /// @brief Returns floor(log2(x))
Log2(uint64_t x)57 inline uint64_t Log2(uint64_t x) { return MSB(x); }
58 
/// @brief Returns whether or not num is a power of four
/// @details A power of four has exactly one bit set, located at an even bit
/// position (0, 2, 4, ..., 62). Zero is reported as not being a power of four.
inline bool IsPowerOfFour(uint64_t num) {
  // Mask with a one at every even bit position.
  const uint64_t even_bit_positions = 0x5555555555555555ULL;
  const bool has_single_bit = (num != 0) && ((num & (num - 1)) == 0);
  return has_single_bit && ((num & even_bit_positions) != 0);
}
62 
63 /// @brief Returns the maximum value that can be represented using \p bits bits
MaximumValue(uint64_t bits)64 inline uint64_t MaximumValue(uint64_t bits) {
65   HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits);
66   if (bits == 64) {
67     return (std::numeric_limits<uint64_t>::max)();
68   }
69   return (1ULL << bits) - 1;
70 }
71 
/// @brief Reverses the bits
/// @param[in] x Input to reverse
/// @param[in] bit_width Number of bits in the input; must be >= MSB(x)
/// @return The bit-reversed representation of \p x using \p bit_width bits
uint64_t ReverseBits(uint64_t x, uint64_t bit_width);

/// @brief Returns x^{-1} mod modulus
/// @details Requires x % modulus != 0
uint64_t InverseMod(uint64_t x, uint64_t modulus);

/// @brief Returns (x * y) mod modulus
/// @details Assumes x, y < modulus
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus);

/// @brief Returns (x * y) mod modulus
/// @param[in] x
/// @param[in] y
/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus)
/// @param[in] modulus
// NOTE(review): the lazy multiplication helpers in this header describe their
// pre-computed factor as floor((y << BitShift) / modulus), which depends on y;
// confirm whether y_precon is likewise floor(y * 2**64 / modulus).
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
                     uint64_t modulus);

/// @brief Returns (x + y) mod modulus
/// @details Assumes x, y < modulus
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);

/// @brief Returns (x - y) mod modulus
/// @details Assumes x, y < modulus
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);

/// @brief Returns base^exp mod modulus
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus);

/// @brief Returns whether or not root is a degree-th root of unity mod modulus
/// @param[in] root Root of unity to check
/// @param[in] degree Degree of root of unity; must be a power of two
/// @param[in] modulus Modulus of finite field
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus);

/// @brief Tries to return a primitive degree-th root of unity
/// @details Returns 0 or throws an error if no root is found
uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus);

/// @brief Returns the minimal primitive degree-th root of unity mod modulus
/// @param[in] degree Must be a power of two
/// @param[in] modulus Modulus of finite field
// NOTE(review): the previous brief ("Returns whether or not root is a
// degree-th root of unity") contradicted the uint64_t return type and the
// absence of a root parameter; the wording above follows the function name.
// Confirm against the definition.
uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus);
119 
120 /// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
121 /// modulus]
122 /// @param[in] x
123 /// @param[in] y_operand also denoted y
124 /// @param[in] modulus
125 /// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y
126 /// << BitShift) / modulus)
127 template <int BitShift>
MultiplyModLazy(uint64_t x,uint64_t y_operand,uint64_t y_barrett_factor,uint64_t modulus)128 inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand,
129                                 uint64_t y_barrett_factor, uint64_t modulus) {
130   HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand
131                                                << " must be less than modulus "
132                                                << modulus);
133   HEXL_CHECK(
134       modulus <= MaximumValue(BitShift),
135       "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
136   HEXL_CHECK(x <= MaximumValue(BitShift),
137              "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
138 
139   uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
140   return y_operand * x - Q * modulus;
141 }
142 
143 /// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
144 /// modulus]
145 /// @param[in] x
146 /// @param[in] y
147 /// @param[in] modulus
148 template <int BitShift>
MultiplyModLazy(uint64_t x,uint64_t y,uint64_t modulus)149 inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) {
150   HEXL_CHECK(BitShift == 64 || BitShift == 52,
151              "Unsupported BitShift " << BitShift);
152   HEXL_CHECK(x <= MaximumValue(BitShift),
153              "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
154   HEXL_CHECK(y < modulus,
155              "y " << y << " must be less than modulus " << modulus);
156   HEXL_CHECK(
157       modulus <= MaximumValue(BitShift),
158       "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
159 
160   uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor();
161   return MultiplyModLazy<BitShift>(x, y, y_barrett, modulus);
162 }
163 
/// @brief Adds two unsigned 64-bit integers
/// @param operand1 Number to add
/// @param operand2 Number to add
/// @param result Stores the sum, reduced modulo 2^64; must point to writable
/// storage
/// @return The carry bit: 1 if the addition overflowed 64 bits, 0 otherwise
inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
                               uint64_t* result) {
  const uint64_t sum = operand1 + operand2;  // wraps modulo 2^64
  *result = sum;
  // The wrapped sum is smaller than an addend exactly when a carry occurred.
  return static_cast<unsigned char>(sum < operand1);
}
174 
/// @brief Returns whether or not the input is prime
bool IsPrime(uint64_t n);

/// @brief Generates a list of num_primes primes in the range [2^(bit_size),
/// 2^(bit_size+1)]. Ensures each prime q satisfies
/// q % (2 * ntt_size) == 1
/// @param[in] num_primes Number of primes to generate
/// @param[in] bit_size Bit size of each prime
/// @param[in] prefer_small_primes When true, returns primes starting from
/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1)
/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must
/// be a power of two less than 2^bit_size.
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
                                     bool prefer_small_primes,
                                     size_t ntt_size = 1);
190 
191 /// @brief Returns input mod modulus, computed via 64-bit Barrett reduction
192 /// @param[in] input
193 /// @param[in] modulus
194 /// @param[in] q_barr floor(2^64 / modulus)
195 template <int OutputModFactor = 1>
BarrettReduce64(uint64_t input,uint64_t modulus,uint64_t q_barr)196 uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) {
197   HEXL_CHECK(modulus != 0, "modulus == 0");
198   uint64_t q = MultiplyUInt64Hi<64>(input, q_barr);
199   uint64_t q_times_input = input - q * modulus;
200   if (OutputModFactor == 2) {
201     return q_times_input;
202   } else {
203     return (q_times_input >= modulus) ? q_times_input - modulus : q_times_input;
204   }
205 }
206 
207 /// @brief Returns x mod modulus, assuming x < InputModFactor * modulus
208 /// @param[in] x
209 /// @param[in] modulus also denoted q
210 /// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4
211 /// or 8
212 /// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor
213 /// == 8
214 template <int InputModFactor>
ReduceMod(uint64_t x,uint64_t modulus,const uint64_t * twice_modulus=nullptr,const uint64_t * four_times_modulus=nullptr)215 uint64_t ReduceMod(uint64_t x, uint64_t modulus,
216                    const uint64_t* twice_modulus = nullptr,
217                    const uint64_t* four_times_modulus = nullptr) {
218   HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
219                  InputModFactor == 4 || InputModFactor == 8,
220              "InputModFactor should be 1, 2, 4, or 8");
221   if (InputModFactor == 1) {
222     return x;
223   }
224   if (InputModFactor == 2) {
225     if (x >= modulus) {
226       x -= modulus;
227     }
228     return x;
229   }
230   if (InputModFactor == 4) {
231     HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
232     if (x >= *twice_modulus) {
233       x -= *twice_modulus;
234     }
235     if (x >= modulus) {
236       x -= modulus;
237     }
238     return x;
239   }
240   if (InputModFactor == 8) {
241     HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
242     HEXL_CHECK(four_times_modulus != nullptr,
243                "four_times_modulus should not be nullptr");
244 
245     if (x >= *four_times_modulus) {
246       x -= *four_times_modulus;
247     }
248     if (x >= *twice_modulus) {
249       x -= *twice_modulus;
250     }
251     if (x >= modulus) {
252       x -= modulus;
253     }
254     return x;
255   }
256   HEXL_CHECK(false, "Should be unreachable");
257   return x;
258 }
259 
260 /// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm,
261 /// also known as Montgomery reduction.
262 /// @param[in] r
263 /// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q.
264 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R.
265 /// @param[in] mod_R_msk take r last bits to apply mod R.
266 /// @param[in] T_hi of T = ab in the range [0, Rq − 1].
267 /// @param[in] T_lo of T.
268 /// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q
269 template <int BitShift>
MontgomeryReduce(uint64_t T_hi,uint64_t T_lo,uint64_t q,int r,uint64_t mod_R_msk,uint64_t inv_mod)270 inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q,
271                                  int r, uint64_t mod_R_msk, uint64_t inv_mod) {
272   HEXL_CHECK(BitShift == 64 || BitShift == 52,
273              "Unsupported BitShift " << BitShift);
274   HEXL_CHECK((1ULL << r) > static_cast<uint64_t>(q),
275              "R value should be greater than q = " << static_cast<uint64_t>(q));
276 
277   uint64_t mq_hi;
278   uint64_t mq_lo;
279 
280   uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk;
281   MultiplyUInt64(m, q, &mq_hi, &mq_lo);
282 
283   if (BitShift == 52) {
284     mq_hi = (mq_hi << 12) | (mq_lo >> 52);
285     mq_lo &= (1ULL << 52) - 1;
286   }
287 
288   uint64_t t_hi;
289   uint64_t t_lo;
290 
291   // first 64bit block
292   t_lo = T_lo + mq_lo;
293   unsigned int carry = static_cast<unsigned int>(t_lo < T_lo);
294   t_hi = T_hi + mq_hi + carry;
295 
296   t_hi = t_hi << (BitShift - r);
297   t_lo = t_lo >> r;
298   t_lo = t_hi + t_lo;
299 
300   return (t_lo >= q) ? (t_lo - q) : t_lo;
301 }
302 
/// @brief Hensel's Lemma for 2-adic numbers
/// Find solution for qX + 1 = 0 mod 2^r
/// @details Starts from the solution x = 1 modulo 2 and lifts it one bit at a
/// time. For r < 2 no lifting takes place and 1 is returned.
/// @param[in] r Exponent of the power-of-two modulus
/// @param[in] q such that gcd(2, q) = 1, i.e. q must be odd; the result is
/// meaningless for even q
/// @return Unsigned long int in [0, 2^r - 1] such that q*x ≡ -1 mod 2^r
inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) {
  uint64_t root = 1;      // solves q * root + 1 ≡ 0 mod 2 because q is odd
  uint64_t lift = 2;      // 2^(k-1): the bit that may be added at step k
  uint64_t mod_mask = 3;  // 2^k - 1: selects the residue modulo 2^k

  // Invariant at the top of step k: q * root + 1 ≡ 0 mod 2^(k-1).
  // The solution modulo 2^k is root + t * 2^(k-1), and because q is odd only
  // t = 0 or t = 1 can work, so a single test decides whether to add the bit.
  for (uint64_t k = 2; k <= r; ++k) {
    if ((q * root + 1ULL) & mod_mask) {
      root += lift;
    }
    mod_mask = mod_mask * 2 + 1ULL;
    lift *= 2;
  }

  return root;
}
340 
341 }  // namespace hexl
342 }  // namespace intel
343