1 /* A class for modular arithmetic with residues and modulus of up to 64 2 * bits. */ 3 4 #ifndef MODMPZ_HPP 5 #define MODMPZ_HPP 6 7 /**********************************************************************/ 8 #include <cstdint> 9 #include <gmp.h> // for mp_limb_t, mpz_size, __gmpn_copyi, __gmpn_s... 10 #include <cstddef> // for size_t, NULL 11 #include <new> // for operator new 12 #include "macros.h" 13 #include "gmp_aux.h" // for mpz_cmp_uint64, mpz_get_uint64, mpz_set_uint64 14 #include "cxx_mpz.hpp" 15 // #include "modint.hpp" 16 #include "mod_stdop.hpp" 17 #include "misc.h" 18 19 class ModulusMPZ { 20 /* Type definitions */ 21 public: 22 typedef cxx_mpz Integer; 23 class Residue { 24 friend class ModulusMPZ; 25 protected: 26 mp_limb_t *r; 27 public: 28 typedef ModulusMPZ Modulus; 29 typedef Modulus::Integer Integer; 30 typedef bool IsResidueType; 31 Residue() = delete; Residue(const Modulus & m)32 Residue(const Modulus &m) { 33 r = new mp_limb_t[mpz_size(m.m) + 1]; 34 mpn_zero(r, mpz_size(m.m) + 1); 35 } Residue(const Modulus & m,const Residue & s)36 Residue(const Modulus &m, const Residue &s) { 37 r = new mp_limb_t[mpz_size(m.m) + 1]; 38 mpn_copyi(r, s.r, mpz_size(m.m)); 39 } ~Residue()40 ~Residue() { 41 delete[] r; 42 } Residue(Residue && s)43 Residue(Residue &&s) : r(s.r) { 44 s.r = NULL; 45 } 46 Residue(Residue const & s) = delete; operator =(Residue && s)47 Residue& operator=(Residue &&s) { 48 delete[] r; 49 r = s.r; 50 s.r = NULL; 51 return *this; 52 } 53 Residue& operator=(Residue const & s) = delete; 54 }; 55 56 typedef ResidueStdOp<Residue> ResidueOp; 57 58 /* Data members */ 59 protected: 60 mpz_t m; 61 size_t bits; 62 static const size_t limbsPerUint64 = iceildiv(64, GMP_NUMB_BITS); 63 64 /** Convert a uint64_t to an array of mp_limb_t. 65 * Always writes exactly limbsToWrite limbs. */ uint64ToLimbs(mp_limb_t * r,const uint64_t s,const size_t limbsToWrite)66 static void uint64ToLimbs(mp_limb_t *r, const uint64_t s, const size_t limbsToWrite) { 67 uint64_t t = s; 68 for (size_t i = 0; i < limbsToWrite; i++) { 69 r[i] = t & GMP_NUMB_MASK; 70 #if GMP_NUMB_BITS < 64 71 t >>= GMP_NUMB_BITS; 72 #else 73 t = 0; 74 #endif 75 } 76 ASSERT_ALWAYS(t == 0); 77 } 78 79 /* Methods used internally */ 80 /** Set a residue to m. 81 * This leaves the residue in a non-reduced state. */ setM(mp_limb_t * r) const82 void setM(mp_limb_t *r) const { 83 mpn_copyi(r, m->_mp_d, mpz_size(m)); 84 } 85 /** Add M to a residue, return carry */ addM(mp_limb_t * r,const mp_limb_t * s) const86 mp_limb_t addM(mp_limb_t *r, const mp_limb_t *s) const { 87 return mpn_add_n(r, s, m->_mp_d, mpz_size(m)); 88 } 89 /** Subtract M from a residue, return borrow */ subM(mp_limb_t * r,const mp_limb_t * s) const90 mp_limb_t subM(mp_limb_t *r, const mp_limb_t *s) const { 91 return mpn_sub_n(r, s, m->_mp_d, mpz_size(m)); 92 } 93 /** Returns a negative value if s < m, 0 if s == m, and a positive value if s > m. */ cmpM(const mp_limb_t * s) const94 int cmpM(const mp_limb_t *s) const { 95 return mpn_cmp(s, m->_mp_d, mpz_size(m)); 96 } modM(const uint64_t s) const97 uint64_t modM(const uint64_t s) const { 98 if (bits > 64 || mpz_cmp_uint64(m, s) > 0) 99 return s; 100 /* m <= s, thus m fits in uint64_t */ 101 uint64_t m64 = mpz_get_uint64(m); 102 /* We always make sure that the modulus in non-zero in the ctor 103 */ 104 ASSERT_FOR_STATIC_ANALYZER(m64 != 0); 105 return s % m64; 106 } set_residue_u64(Residue & r,const uint64_t s) const107 void set_residue_u64(Residue &r, const uint64_t s) const { 108 ASSERT(mpz_cmp_uint64(m, s) > 0); 109 const size_t limbsToWrite = MIN(limbsPerUint64, mpz_size(m)); 110 uint64ToLimbs(r.r, s, limbsToWrite); 111 for (size_t i = limbsToWrite; i < mpz_size(m); i++) 112 r.r[i] = 0; 113 } set_residue_mpz(Residue & r,const mpz_t s) const114 void set_residue_mpz(Residue &r, const mpz_t s) const { 115 ASSERT_ALWAYS(mpz_cmp(s, m) < 0); 116 size_t written; 117 mpz_export(r.r, &written, -1, sizeof(mp_limb_t), 0, GMP_NAIL_BITS, s); 118 ASSERT_ALWAYS(written <= mpz_size(m)); 119 for (size_t i = written; i < mpz_size(m); i++) 120 r.r[i] = 0; 121 } set_mpz_residue(mpz_t r,const Residue & s) const122 void set_mpz_residue(mpz_t r, const Residue &s) const { 123 mpz_import(r, mpz_size(m), -1, sizeof(mp_limb_t), 0, GMP_NAIL_BITS, s.r); 124 } assertValid(const Residue & a MAYBE_UNUSED) const125 void assertValid(const Residue &a MAYBE_UNUSED) const { 126 ASSERT_EXPENSIVE (cmpM(a.r) < 0); 127 } assertValid(const uint64_t a MAYBE_UNUSED) const128 void assertValid(const uint64_t a MAYBE_UNUSED) const { 129 ASSERT_EXPENSIVE (mpz_cmp_uint64(m, a) > 0); 130 } assertValid(const mpz_t s MAYBE_UNUSED) const131 void assertValid(const mpz_t s MAYBE_UNUSED) const { 132 ASSERT_EXPENSIVE (mpz_cmp(s, m) < 0); 133 } assertValid(const cxx_mpz & s MAYBE_UNUSED) const134 void assertValid(const cxx_mpz &s MAYBE_UNUSED) const { 135 assertValid((mpz_srcptr) s); 136 } 137 138 /* Methods of the API */ 139 public: getminmod(Integer & r)140 static void getminmod(Integer &r) { 141 mpz_set_ui(r, 0); 142 } getmaxmod(Integer & r)143 static void getmaxmod(Integer &r) { 144 mpz_set_ui(r, 0); 145 } valid(const Integer & m MAYBE_UNUSED)146 static bool valid(const Integer &m MAYBE_UNUSED) { 147 return true; 148 } 149 ModulusMPZ(const uint64_t s)150 ModulusMPZ(const uint64_t s) { 151 ASSERT_ALWAYS(s > 0); 152 mpz_init(m); 153 mpz_set_uint64(m, s); 154 bits = mpz_sizeinbase(m, 2); 155 } ModulusMPZ(const ModulusMPZ & s)156 ModulusMPZ(const ModulusMPZ &s) { 157 mpz_init_set(m, s.m); 158 bits = s.bits; 159 } ModulusMPZ(const Integer & s)160 ModulusMPZ(const Integer &s) { 161 ASSERT_ALWAYS(mpz_sgn(s) > 0); 162 mpz_init(m); 163 mpz_set(m, s); 164 bits = mpz_sizeinbase(m, 2); 165 } ~ModulusMPZ()166 ~ModulusMPZ() { 167 mpz_clear(m); 168 } getmod(Integer & r) const169 void getmod (Integer &r) const { 170 mpz_set(r, m); 171 } 172 173 /* Methods for residues */ 174 175 /** Allocate an array of len residues. 176 * 177 * Must be freed with deleteArray(), not with delete[]. 178 */ newArray(const size_t len) const179 Residue *newArray(const size_t len) const { 180 void *t = operator new[](len * sizeof(Residue)); 181 if (t == NULL) 182 return NULL; 183 Residue *ptr = static_cast<Residue *>(t); 184 for(size_t i = 0; i < len; i++) { 185 new(&ptr[i]) Residue(*this); 186 } 187 return ptr; 188 } 189 deleteArray(Residue * ptr,const size_t len) const190 void deleteArray(Residue *ptr, const size_t len) const { 191 for(size_t i = len; i > 0; i++) { 192 ptr[i - 1].~Residue(); 193 } 194 operator delete[](ptr); 195 } 196 197 set(Residue & r,const Residue & s) const198 void set (Residue &r, const Residue &s) const { 199 assertValid(s); 200 mpn_copyi(r.r, s.r, mpz_size(m)); 201 } set(Residue & r,const uint64_t s) const202 void set (Residue &r, const uint64_t s) const { 203 const uint64_t sm = modM(s); 204 set_residue_u64(r, sm); 205 } set(Residue & r,const Integer & s) const206 void set (Residue &r, const Integer &s) const { 207 cxx_mpz t; 208 mpz_mod(t, s, m); 209 set_residue_mpz(r, t); 210 } 211 /* Sets the Residue to the class represented by the integer s. Assumes that 212 s is reduced (mod m), i.e. 0 <= s < m */ set_reduced(Residue & r,const uint64_t s) const213 void set_reduced (Residue &r, const uint64_t s) const { 214 assertValid(s); 215 set_residue_u64(r, s); 216 } set_reduced(Residue & r,const Integer & s) const217 void set_reduced (Residue &r, const Integer &s) const { 218 assertValid(r); 219 set_residue_mpz(r, s); 220 } set_int64(Residue & r,const int64_t s) const221 void set_int64 (Residue &r, const int64_t s) const { 222 const uint64_t u = modM(safe_abs64(s)); 223 set_residue_u64(r, u); 224 if (s < 0) 225 neg(r, r); 226 } set0(Residue & r) const227 void set0 (Residue &r) const { 228 mpn_zero(r.r, mpz_size(m)); 229 } set1(Residue & r) const230 void set1 (Residue &r) const { 231 set0(r); 232 if (mpz_cmp_ui(m, 1) > 0) { 233 r.r[0] = 1; 234 } 235 } 236 /* Exchanges the values of the two arguments */ swap(Residue & a,Residue & b) const237 void swap (Residue &a, Residue &b) const { 238 mp_limb_t *t = a.r; 239 a.r = b.r; 240 b.r = t; 241 } get(Integer & r,const Residue & s) const242 void get (Integer &r, const Residue &s) const { 243 assertValid(s); 244 set_mpz_residue(r, s); 245 } equal(const Residue & a,const Residue & b) const246 bool equal (const Residue &a, const Residue &b) const { 247 assertValid(a); 248 assertValid(b); 249 return mpn_cmp(a.r, b.r, mpz_size(m)) == 0; 250 } is0(const Residue & a) const251 bool is0 (const Residue &a) const { 252 assertValid(a); 253 return mpn_zero_p(a.r, mpz_size(m)); 254 } is1(const Residue & a) const255 bool is1 (const Residue &a) const { 256 assertValid(a); 257 if (mpz_cmp_ui(m, 1) == 0) 258 return 1; 259 return a.r[0] == 1 && (mpz_size(m) == 1 || mpn_zero_p(a.r + 1, mpz_size(m) - 1)); 260 } neg(Residue & r,const Residue & a) const261 void neg (Residue &r, const Residue &a) const { 262 assertValid(a); 263 if (is0(a) != 0) { 264 if (&r != &a) 265 set0(r); 266 } else { 267 mp_limb_t bw = mpn_sub_n(r.r, m->_mp_d, r.r, mpz_size(m)); 268 ASSERT_ALWAYS(bw == 0); 269 } 270 assertValid(r); 271 } add(Residue & r,const Residue & a,const Residue & b) const272 void add (Residue &r, const Residue &a, const Residue &b) const { 273 assertValid(a); 274 assertValid(b); 275 const mp_limb_t cy = mpn_add_n(r.r, a.r, b.r, mpz_size(m)); 276 if (cy || cmpM(r.r) >= 0) { 277 const mp_limb_t bw = subM(r.r, r.r); 278 ASSERT_ALWAYS(bw == cy); 279 } 280 assertValid(r); 281 } add1(Residue & r,const Residue & a) const282 void add1 (Residue &r, const Residue &a) const { 283 assertValid(a); 284 const mp_limb_t cy = mpn_add_1(r.r, a.r, 1, mpz_size(m)); 285 if (cy || cmpM(r.r) >= 0) { 286 const mp_limb_t bw = subM(r.r, r.r); 287 ASSERT_ALWAYS(bw == cy); 288 } 289 assertValid(r); 290 } add(Residue & r,const Residue & a,const uint64_t b) const291 void add (Residue &r, const Residue &a, const uint64_t b) const { 292 assertValid(a); 293 const uint64_t bm = modM(b); 294 mp_limb_t cy; 295 296 if (limbsPerUint64 == 1 || bm < GMP_NUMB_MAX) { 297 cy = mpn_add_1(r.r, a.r, (mp_limb_t) bm, mpz_size(m)); 298 } else { 299 mp_limb_t t[limbsPerUint64]; 300 const size_t toWrite = MIN(limbsPerUint64, mpz_size(m)); 301 302 uint64ToLimbs(t, bm, limbsPerUint64); 303 cy = mpn_add(r.r, a.r, mpz_size(m), t, toWrite); 304 } 305 if (cy || cmpM(r.r) >= 0) { 306 mp_limb_t bw = subM(r.r, r.r); 307 ASSERT_ALWAYS(bw == cy); 308 } 309 assertValid(r); 310 } sub(Residue & r,const Residue & a,const Residue & b) const311 void sub (Residue &r, const Residue &a, const Residue &b) const { 312 assertValid(a); 313 assertValid(b); 314 const mp_limb_t bw = mpn_sub_n(r.r, a.r, b.r, mpz_size(m)); 315 if (bw) { 316 const mp_limb_t cy = addM(r.r, r.r); 317 ASSERT_ALWAYS(cy == bw); 318 } 319 assertValid(r); 320 } sub1(Residue & r,const Residue & a) const321 void sub1 (Residue &r, const Residue &a) const { 322 assertValid(a); 323 const mp_limb_t bw = mpn_sub_1(r.r, a.r, 1, mpz_size(m)); 324 if (bw) { 325 const mp_limb_t cy = addM(r.r, r.r); 326 ASSERT_ALWAYS(cy == bw); 327 } 328 assertValid(r); 329 } sub(Residue & r,const Residue & a,const uint64_t b) const330 void sub (Residue &r, const Residue &a, const uint64_t b) const { 331 assertValid(a); 332 const uint64_t bm = modM(b); 333 mp_limb_t bw; 334 335 if (limbsPerUint64 == 1 || bm < GMP_NUMB_MAX) { 336 bw = mpn_sub_1(r.r, a.r, (mp_limb_t) bm, mpz_size(m)); 337 } else { 338 mp_limb_t t[limbsPerUint64]; 339 const size_t toWrite = MIN(limbsPerUint64, mpz_size(m)); 340 341 uint64ToLimbs(t, bm, limbsPerUint64); 342 bw = mpn_sub(r.r, a.r, mpz_size(m), t, toWrite); 343 } 344 if (bw) { 345 mp_limb_t cy = addM(r.r, r.r); 346 ASSERT_ALWAYS(cy == bw); 347 } 348 assertValid(r); 349 } mul(Residue & r,const Residue & a,const Residue & b) const350 void mul (Residue &r, const Residue &a, const Residue &b) const { 351 const size_t nrWords = mpz_size(m); 352 mp_limb_t Q[2]; 353 mp_limb_t *t; 354 if (r.r == a.r || r.r == b.r) { 355 t = new mp_limb_t[nrWords + 1]; 356 } else { 357 t = r.r; 358 } 359 t[nrWords] = mpn_mul_1 (t, a.r, nrWords, b.r[nrWords - 1]); 360 if (t[nrWords] != 0) 361 mpn_tdiv_qr (Q, t, 0, t, nrWords + 1, m->_mp_d, nrWords); 362 /* t <= (m-1) * beta */ 363 for (size_t iWord = nrWords - 1; iWord > 0; iWord--) { 364 mpn_copyd(t + 1, t, nrWords); 365 t[0] = 0; 366 mp_limb_t msw = mpn_addmul_1 (t, a.r, nrWords, b.r[iWord - 1]); 367 t[nrWords] += msw; 368 mp_limb_t cy = t[nrWords] < msw; 369 if (cy) { 370 mp_limb_t bw = subM(t + 1, t + 1); 371 ASSERT_ALWAYS(bw == cy); 372 } 373 if (t[nrWords] != 0) 374 mpn_tdiv_qr (Q, t, 0, t, nrWords + 1, m->_mp_d, nrWords); 375 } 376 if (cmpM(t) >= 0) { 377 mpn_tdiv_qr (Q, t, 0, t, nrWords, m->_mp_d, nrWords); 378 } 379 if (r.r == a.r || r.r == b.r) { 380 mpn_copyi(r.r, t, nrWords); 381 delete[] t; 382 } 383 } sqr(Residue & r,const Residue & a) const384 void sqr (Residue &r, const Residue &a) const { 385 mul(r, a, a); 386 } next(Residue & r) const387 bool next (Residue &r) const { 388 add1(r, r); 389 return finished(r); 390 } finished(const Residue & r) const391 bool finished (const Residue &r) const { 392 return is0(r); 393 } div2(Residue & r,const Residue & a) const394 bool div2 (Residue &r, const Residue &a) const { 395 assertValid(a); 396 if (mpz_even_p(m)) { 397 return false; 398 } else { 399 if (a.r[0] % 2 == 0) { 400 mp_limb_t lsb = mpn_rshift(r.r, a.r, mpz_size(m), 1); 401 ASSERT_ALWAYS(lsb == 0); 402 } else { 403 mp_limb_t cy = addM(r.r, a.r); 404 mp_limb_t lsb = mpn_rshift(r.r, r.r, mpz_size(m), 1); 405 ASSERT_ALWAYS(lsb == 0); 406 r.r[mpz_size(m) - 1] |= cy << (GMP_NUMB_BITS - 1); 407 } 408 assertValid(r); 409 return true; 410 } 411 } 412 413 /* Given a = V_n (x), b = V_m (x) and d = V_{n-m} (x), compute V_{m+n} (x). 414 * r can be the same variable as a or b but must not be the same variable as d. 415 */ V_dadd(Residue & r,const Residue & a,const Residue & b,const Residue & d) const416 void V_dadd (Residue &r, const Residue &a, const Residue &b, 417 const Residue &d) const { 418 ASSERT (&r != &d); 419 mul (r, a, b); 420 sub (r, r, d); 421 } 422 423 /* Given a = V_n (x) and two = 2, compute V_{2n} (x). 424 * r can be the same variable as a but must not be the same variable as two. 425 */ V_dbl(Residue & r,const Residue & a,const Residue & two) const426 void V_dbl (Residue &r, const Residue &a, const Residue &two) const { 427 ASSERT (&r != &two); 428 sqr (r, a); 429 sub (r, r, two); 430 } 431 432 433 /* prototypes of non-inline functions */ 434 bool div3 (Residue &, const Residue &) const; 435 bool div5 (Residue &, const Residue &) const; 436 bool div7 (Residue &, const Residue &) const; 437 bool div11 (Residue &, const Residue &) const; 438 bool div13 (Residue &, const Residue &) const; 439 bool divn (Residue &, const Residue &, unsigned long) const; 440 void gcd (Integer &, const Residue &) const; 441 void pow (Residue &, const Residue &, const uint64_t) const; 442 void pow (Residue &, const Residue &, const uint64_t *, const size_t) const; 443 void pow (Residue &r, const Residue &b, const Integer &e) const; 444 void pow2 (Residue &, const uint64_t) const; 445 void pow2 (Residue &, const uint64_t *, const size_t) const; 446 void pow2 (Residue &r, const Integer &e) const; 447 void pow3 (Residue &, uint64_t) const; 448 void V (Residue &, const Residue &, const uint64_t) const; 449 void V (Residue &, const Residue &, const uint64_t *, const int) const; 450 void V (Residue &r, const Residue &b, const Integer &e) const; 451 void V (Residue &r, Residue *rp1, const Residue &b, 452 const uint64_t k) const; 453 bool sprp (const Residue &) const; 454 bool sprp2 () const; 455 bool isprime () const; 456 bool inv (Residue &, const Residue &) const; 457 bool inv_odd (Residue &, const Residue &) const; 458 bool inv_powerof2 (Residue &, const Residue &) const; 459 bool batchinv (Residue *, const Residue *, size_t, const Residue *) const; 460 int jacobi (const Residue &) const; 461 protected: 462 bool find_minus1 (Residue &r1, const Residue &minusone, const int po2) const; 463 bool divn (Residue &, const Residue &, unsigned long, const mp_limb_t *, mp_limb_t) const; 464 }; 465 466 #endif /* MOD64_HPP */ 467