1 // Copyright (C) 2020-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3
4 #pragma once
5
6 #include <stdint.h>
7
8 #include <limits>
9 #include <vector>
10
11 #include "hexl/util/check.hpp"
12 #include "hexl/util/compiler.hpp"
13
14 namespace intel {
15 namespace hexl {
16
17 /// @brief Pre-computes a Barrett factor with which modular multiplication can
18 /// be performed more efficiently
19 class MultiplyFactor {
20 public:
21 MultiplyFactor() = default;
22
23 /// @brief Computes and stores the Barrett factor floor((operand << bit_shift)
24 /// / modulus). This is useful when modular multiplication of the form
25 /// (x * operand) mod modulus is performed with same modulus and operand
26 /// several times. Note, passing operand=1 can be used to pre-compute a
27 /// Barrett factor for multiplications of the form (x * y) mod modulus, where
28 /// only the modulus is re-used across calls to modular multiplication.
MultiplyFactor(uint64_t operand,uint64_t bit_shift,uint64_t modulus)29 MultiplyFactor(uint64_t operand, uint64_t bit_shift, uint64_t modulus)
30 : m_operand(operand) {
31 HEXL_CHECK(operand <= modulus, "operand " << operand
32 << " must be less than modulus "
33 << modulus);
34 HEXL_CHECK(bit_shift == 32 || bit_shift == 52 || bit_shift == 64,
35 "Unsupported BitShift " << bit_shift);
36 uint64_t op_hi = operand >> (64 - bit_shift);
37 uint64_t op_lo = (bit_shift == 64) ? 0 : (operand << bit_shift);
38
39 m_barrett_factor = DivideUInt128UInt64Lo(op_hi, op_lo, modulus);
40 }
41
42 /// @brief Returns the pre-computed Barrett factor
BarrettFactor() const43 inline uint64_t BarrettFactor() const { return m_barrett_factor; }
44
45 /// @brief Returns the operand corresponding to the Barrett factor
Operand() const46 inline uint64_t Operand() const { return m_operand; }
47
48 private:
49 uint64_t m_operand;
50 uint64_t m_barrett_factor;
51 };
52
/// @brief Returns whether or not num is a power of two
/// @param[in] num Value to test
/// @return true if exactly one bit of \p num is set; false otherwise,
/// including for 0
inline bool IsPowerOfTwo(uint64_t num) {
  if (num == 0) {
    return false;
  }
  // Clearing the lowest set bit leaves zero only when a single bit was set
  const uint64_t without_lowest_bit = num & (num - 1);
  return without_lowest_bit == 0;
}
55
56 /// @brief Returns floor(log2(x))
Log2(uint64_t x)57 inline uint64_t Log2(uint64_t x) { return MSB(x); }
58
IsPowerOfFour(uint64_t num)59 inline bool IsPowerOfFour(uint64_t num) {
60 return IsPowerOfTwo(num) && (Log2(num) % 2 == 0);
61 }
62
63 /// @brief Returns the maximum value that can be represented using \p bits bits
MaximumValue(uint64_t bits)64 inline uint64_t MaximumValue(uint64_t bits) {
65 HEXL_CHECK(bits <= 64, "MaximumValue requires bits <= 64; got " << bits);
66 if (bits == 64) {
67 return (std::numeric_limits<uint64_t>::max)();
68 }
69 return (1ULL << bits) - 1;
70 }
71
/// @brief Reverses the order of the lowest \p bit_width bits of \p x
/// @param[in] x Input to reverse
/// @param[in] bit_width Number of bits in the input; must be >= MSB(x)
/// @return The bit-reversed representation of \p x using \p bit_width bits
uint64_t ReverseBits(uint64_t x, uint64_t bit_width);
77
/// @brief Returns x^{-1} mod modulus
/// @details Requires x % modulus != 0
/// @param[in] x Value to invert
/// @param[in] modulus Modulus with respect to which the inverse is taken
/// @return The modular inverse of \p x
/// @note NOTE(review): the definition is not in this header; how inputs
/// without an inverse are handled should be confirmed against it
uint64_t InverseMod(uint64_t x, uint64_t modulus);
81
/// @brief Returns (x * y) mod modulus
/// @details Assumes x, y < modulus
/// @param[in] x First factor
/// @param[in] y Second factor
/// @param[in] modulus Modulus applied to the product
/// @return (x * y) mod modulus
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t modulus);
85
/// @brief Returns (x * y) mod modulus, using a pre-computed factor
/// @param[in] x First factor
/// @param[in] y Second factor
/// @param[in] y_precon 64-bit precondition factor floor(2**64 / modulus)
/// @param[in] modulus Modulus applied to the product
/// @return (x * y) mod modulus
/// @note NOTE(review): the lazy multiplication helpers in this header take a
/// factor of the form floor((y << 64) / modulus); confirm whether \p y_precon
/// is likewise expected to include \p y
uint64_t MultiplyMod(uint64_t x, uint64_t y, uint64_t y_precon,
                     uint64_t modulus);
93
/// @brief Returns (x + y) mod modulus
/// @details Assumes x, y < modulus
/// @param[in] x First summand
/// @param[in] y Second summand
/// @param[in] modulus Modulus applied to the sum
/// @return (x + y) mod modulus
uint64_t AddUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
97
/// @brief Returns (x - y) mod modulus
/// @details Assumes x, y < modulus
/// @param[in] x Minuend
/// @param[in] y Subtrahend
/// @param[in] modulus Modulus applied to the difference
/// @return (x - y) mod modulus
uint64_t SubUIntMod(uint64_t x, uint64_t y, uint64_t modulus);
101
/// @brief Returns base^exp mod modulus
/// @param[in] base Value to exponentiate
/// @param[in] exp Exponent
/// @param[in] modulus Modulus applied to the power
/// @return base^exp mod modulus
uint64_t PowMod(uint64_t base, uint64_t exp, uint64_t modulus);
104
/// @brief Returns whether or not root is a degree-th root of unity mod modulus
/// @param[in] root Root of unity to check
/// @param[in] degree Degree of root of unity; must be a power of two
/// @param[in] modulus Modulus of finite field
/// @return true if \p root passes the check, false otherwise
/// @note NOTE(review): the function name indicates that primitivity is what is
/// tested; confirm against the implementation
bool IsPrimitiveRoot(uint64_t root, uint64_t degree, uint64_t modulus);
110
/// @brief Tries to return a primitive degree-th root of unity
/// @details Returns 0 or throws an error if no root is found
/// @param[in] degree Degree of the root of unity
/// @param[in] modulus Modulus with respect to which the root is sought
/// @return A primitive degree-th root of unity, or 0 if none is found
uint64_t GeneratePrimitiveRoot(uint64_t degree, uint64_t modulus);
114
/// @brief Returns a primitive degree-th root of unity mod modulus
/// @details As the name indicates, presumably the smallest such root -- TODO
/// confirm against the implementation, which is not part of this header
/// @param[in] degree Must be a power of two
/// @param[in] modulus Modulus of finite field
/// @return The selected primitive degree-th root of unity
uint64_t MinimalPrimitiveRoot(uint64_t degree, uint64_t modulus);
119
120 /// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
121 /// modulus]
122 /// @param[in] x
123 /// @param[in] y_operand also denoted y
124 /// @param[in] modulus
125 /// @param[in] y_barrett_factor Pre-computed Barrett reduction factor floor((y
126 /// << BitShift) / modulus)
127 template <int BitShift>
MultiplyModLazy(uint64_t x,uint64_t y_operand,uint64_t y_barrett_factor,uint64_t modulus)128 inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y_operand,
129 uint64_t y_barrett_factor, uint64_t modulus) {
130 HEXL_CHECK(y_operand < modulus, "y_operand " << y_operand
131 << " must be less than modulus "
132 << modulus);
133 HEXL_CHECK(
134 modulus <= MaximumValue(BitShift),
135 "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
136 HEXL_CHECK(x <= MaximumValue(BitShift),
137 "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
138
139 uint64_t Q = MultiplyUInt64Hi<BitShift>(x, y_barrett_factor);
140 return y_operand * x - Q * modulus;
141 }
142
143 /// @brief Computes (x * y) mod modulus, except that the output is in [0, 2 *
144 /// modulus]
145 /// @param[in] x
146 /// @param[in] y
147 /// @param[in] modulus
148 template <int BitShift>
MultiplyModLazy(uint64_t x,uint64_t y,uint64_t modulus)149 inline uint64_t MultiplyModLazy(uint64_t x, uint64_t y, uint64_t modulus) {
150 HEXL_CHECK(BitShift == 64 || BitShift == 52,
151 "Unsupported BitShift " << BitShift);
152 HEXL_CHECK(x <= MaximumValue(BitShift),
153 "Operand " << x << " exceeds bound " << MaximumValue(BitShift));
154 HEXL_CHECK(y < modulus,
155 "y " << y << " must be less than modulus " << modulus);
156 HEXL_CHECK(
157 modulus <= MaximumValue(BitShift),
158 "Modulus " << modulus << " exceeds bound " << MaximumValue(BitShift));
159
160 uint64_t y_barrett = MultiplyFactor(y, BitShift, modulus).BarrettFactor();
161 return MultiplyModLazy<BitShift>(x, y, y_barrett, modulus);
162 }
163
/// @brief Adds two unsigned 64-bit integers
/// @param operand1 Number to add
/// @param operand2 Number to add
/// @param result Receives the sum, reduced modulo 2^64
/// @return The carry bit: 1 if the addition wrapped around, 0 otherwise
inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
                               uint64_t* result) {
  const uint64_t sum = operand1 + operand2;
  *result = sum;
  // An unsigned sum is smaller than an addend exactly when it wrapped around
  const bool wrapped = sum < operand1;
  return static_cast<unsigned char>(wrapped);
}
174
/// @brief Returns whether or not the input is prime
/// @param[in] n Value to test
/// @return true if \p n is reported prime, false otherwise
bool IsPrime(uint64_t n);
177
/// @brief Generates a list of num_primes primes in the range [2^(bit_size),
/// 2^(bit_size+1)]. Ensures each prime q satisfies
/// q % (2 * ntt_size) == 1
/// @param[in] num_primes Number of primes to generate
/// @param[in] bit_size Bit size of each prime
/// @param[in] prefer_small_primes When true, returns primes starting from
/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1)
/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must
/// be a power of two less than 2^bit_size.
/// @return The generated primes
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
                                     bool prefer_small_primes,
                                     size_t ntt_size = 1);
190
191 /// @brief Returns input mod modulus, computed via 64-bit Barrett reduction
192 /// @param[in] input
193 /// @param[in] modulus
194 /// @param[in] q_barr floor(2^64 / modulus)
195 template <int OutputModFactor = 1>
BarrettReduce64(uint64_t input,uint64_t modulus,uint64_t q_barr)196 uint64_t BarrettReduce64(uint64_t input, uint64_t modulus, uint64_t q_barr) {
197 HEXL_CHECK(modulus != 0, "modulus == 0");
198 uint64_t q = MultiplyUInt64Hi<64>(input, q_barr);
199 uint64_t q_times_input = input - q * modulus;
200 if (OutputModFactor == 2) {
201 return q_times_input;
202 } else {
203 return (q_times_input >= modulus) ? q_times_input - modulus : q_times_input;
204 }
205 }
206
207 /// @brief Returns x mod modulus, assuming x < InputModFactor * modulus
208 /// @param[in] x
209 /// @param[in] modulus also denoted q
210 /// @param[in] twice_modulus 2 * q; must not be nullptr if InputModFactor == 4
211 /// or 8
212 /// @param[in] four_times_modulus 4 * q; must not be nullptr if InputModFactor
213 /// == 8
214 template <int InputModFactor>
ReduceMod(uint64_t x,uint64_t modulus,const uint64_t * twice_modulus=nullptr,const uint64_t * four_times_modulus=nullptr)215 uint64_t ReduceMod(uint64_t x, uint64_t modulus,
216 const uint64_t* twice_modulus = nullptr,
217 const uint64_t* four_times_modulus = nullptr) {
218 HEXL_CHECK(InputModFactor == 1 || InputModFactor == 2 ||
219 InputModFactor == 4 || InputModFactor == 8,
220 "InputModFactor should be 1, 2, 4, or 8");
221 if (InputModFactor == 1) {
222 return x;
223 }
224 if (InputModFactor == 2) {
225 if (x >= modulus) {
226 x -= modulus;
227 }
228 return x;
229 }
230 if (InputModFactor == 4) {
231 HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
232 if (x >= *twice_modulus) {
233 x -= *twice_modulus;
234 }
235 if (x >= modulus) {
236 x -= modulus;
237 }
238 return x;
239 }
240 if (InputModFactor == 8) {
241 HEXL_CHECK(twice_modulus != nullptr, "twice_modulus should not be nullptr");
242 HEXL_CHECK(four_times_modulus != nullptr,
243 "four_times_modulus should not be nullptr");
244
245 if (x >= *four_times_modulus) {
246 x -= *four_times_modulus;
247 }
248 if (x >= *twice_modulus) {
249 x -= *twice_modulus;
250 }
251 if (x >= modulus) {
252 x -= modulus;
253 }
254 return x;
255 }
256 HEXL_CHECK(false, "Should be unreachable");
257 return x;
258 }
259
260 /// @brief Returns Montgomery form of ab mod q, computed via the REDC algorithm,
261 /// also known as Montgomery reduction.
262 /// @param[in] r
263 /// @param[in] q with R = 2^r such that gcd(R, q) = 1. R > q.
264 /// @param[in] inv_mod in [0, R − 1] such that q*v_inv_mod ≡ −1 mod R.
265 /// @param[in] mod_R_msk take r last bits to apply mod R.
266 /// @param[in] T_hi of T = ab in the range [0, Rq − 1].
267 /// @param[in] T_lo of T.
268 /// @return Unsigned long int in the range [0, q − 1] such that S ≡ TR^−1 mod q
269 template <int BitShift>
MontgomeryReduce(uint64_t T_hi,uint64_t T_lo,uint64_t q,int r,uint64_t mod_R_msk,uint64_t inv_mod)270 inline uint64_t MontgomeryReduce(uint64_t T_hi, uint64_t T_lo, uint64_t q,
271 int r, uint64_t mod_R_msk, uint64_t inv_mod) {
272 HEXL_CHECK(BitShift == 64 || BitShift == 52,
273 "Unsupported BitShift " << BitShift);
274 HEXL_CHECK((1ULL << r) > static_cast<uint64_t>(q),
275 "R value should be greater than q = " << static_cast<uint64_t>(q));
276
277 uint64_t mq_hi;
278 uint64_t mq_lo;
279
280 uint64_t m = ((T_lo & mod_R_msk) * inv_mod) & mod_R_msk;
281 MultiplyUInt64(m, q, &mq_hi, &mq_lo);
282
283 if (BitShift == 52) {
284 mq_hi = (mq_hi << 12) | (mq_lo >> 52);
285 mq_lo &= (1ULL << 52) - 1;
286 }
287
288 uint64_t t_hi;
289 uint64_t t_lo;
290
291 // first 64bit block
292 t_lo = T_lo + mq_lo;
293 unsigned int carry = static_cast<unsigned int>(t_lo < T_lo);
294 t_hi = T_hi + mq_hi + carry;
295
296 t_hi = t_hi << (BitShift - r);
297 t_lo = t_lo >> r;
298 t_lo = t_hi + t_lo;
299
300 return (t_lo >= q) ? (t_lo - q) : t_lo;
301 }
302
/// @brief Hensel's Lemma for 2-adic numbers
/// Find solution for qX + 1 = 0 mod 2^r
/// @details Starts from the root 1 of f(X) = qX + 1 modulo 2 and lifts it one
/// bit at a time. A root a modulo 2^(k-1) lifts to either a or a + 2^(k-1)
/// modulo 2^k, so each step needs a single test. If \p q is even no solution
/// exists; the function still terminates, but the returned value is then
/// meaningless.
/// @param[in] r Exponent of the modulus 2^r
/// @param[in] q such that gcd(2, q) = 1
/// @return Unsigned long int in [0, 2^r − 1] such that q*x ≡ −1 mod 2^r
inline uint64_t HenselLemma2adicRoot(uint32_t r, uint64_t q) {
  uint64_t root = 1;      // root modulo 2^(k-1); f(1) ≡ 0 mod 2 for odd q
  uint64_t step = 2;      // 2^(k-1), the only candidate increment
  uint64_t mod_mask = 3;  // 2^k - 1, selects the residue modulo 2^k

  for (uint64_t k = 2; k <= r; k++) {
    // Keep the current root if it already solves the congruence modulo 2^k,
    // otherwise move to the other lift root + 2^(k-1)
    const uint64_t f = q * root + 1ULL;
    if (f & mod_mask) {
      root += step;
    }

    // Advance to the next power of two
    mod_mask = mod_mask * 2 + 1ULL;
    step *= 2;
  }

  return root;
}
340
341 } // namespace hexl
342 } // namespace intel
343