1 /* A class for modular arithmetic with residues and modulus of up to 64 2 * bits. */ 3 4 #ifndef MOD64_HPP 5 #define MOD64_HPP 6 7 /**********************************************************************/ 8 #include <cstdint> 9 #include <cstdlib> 10 #include <new> 11 #include "macros.h" 12 #include "u64arith.h" 13 #include "modint.hpp" 14 #include "mod_stdop.hpp" 15 16 class Modulus64 { 17 /* Type definitions */ 18 public: 19 typedef Integer64 Integer; 20 class Residue { 21 friend class Modulus64; 22 protected: 23 uint64_t r; 24 public: 25 typedef Modulus64 Modulus; 26 typedef Modulus::Integer Integer; 27 typedef bool IsResidueType; 28 Residue() = delete; Residue(const Modulus & m MAYBE_UNUSED)29 Residue(const Modulus &m MAYBE_UNUSED) : r(0) {} Residue(const Modulus & m MAYBE_UNUSED,const Residue & s)30 Residue(const Modulus &m MAYBE_UNUSED, const Residue &s) : r(s.r) {} Residue(const Residue && s)31 Residue(const Residue &&s) : r(s.r) {} 32 protected: operator =(const Residue & s)33 Residue &operator=(const Residue &s) {r = s.r; return *this;} operator =(const Integer & s)34 Residue &operator=(const Integer &s) {r = 0; s.get(&r, 1); return *this;} operator =(const uint64_t s)35 Residue &operator=(const uint64_t s) {r = s; return *this;} 36 }; 37 38 typedef ResidueStdOp<Residue> ResidueOp; 39 40 /* Data members */ 41 protected: 42 uint64_t m; 43 /* Methods used internally */ assertValid(const Residue & a MAYBE_UNUSED) const44 void assertValid(const Residue &a MAYBE_UNUSED) const {ASSERT_EXPENSIVE (a.r < m);} assertValid(const uint64_t a MAYBE_UNUSED) const45 void assertValid(const uint64_t a MAYBE_UNUSED) const {ASSERT_EXPENSIVE (a < m);} get_u64(const Residue & s) const46 uint64_t get_u64 (const Residue &s) const {assertValid(s); return s.r;} 47 48 /* Methods of the API */ 49 public: getminmod()50 static Integer getminmod() {return Integer(1);} getmaxmod()51 static Integer getmaxmod() {return Integer(UINT64_MAX);} getminmod(Integer & r)52 static void getminmod(Integer &r) {r = getminmod();} getmaxmod(Integer & r)53 static void getmaxmod(Integer &r) {r = getmaxmod();} valid(const Integer & m)54 static bool valid(const Integer &m) {return getminmod() <= m && m <= getmaxmod();} 55 Modulus64(const uint64_t s)56 Modulus64(const uint64_t s) : m(s){} Modulus64(const Modulus64 & s)57 Modulus64(const Modulus64 &s) : m(s.m){} Modulus64(const Integer & s)58 Modulus64(const Integer &s) {s.get(&m, 1);} ~Modulus64()59 ~Modulus64() {} getmod_u64() const60 uint64_t getmod_u64 () const {return m;} getmod(Integer & r) const61 void getmod (Integer &r) const {r = m;} 62 63 /* Methods for residues */ 64 65 /** Allocate an array of len residues. 66 * 67 * Must be freed with deleteArray(), not with delete[]. 68 */ newArray(const size_t len) const69 Residue *newArray(const size_t len) const { 70 void *t = operator new[](len * sizeof(Residue)); 71 if (t == NULL) 72 return NULL; 73 Residue *ptr = static_cast<Residue *>(t); 74 for(size_t i = 0; i < len; i++) { 75 new(&ptr[i]) Residue(*this); 76 } 77 return ptr; 78 } 79 deleteArray(Residue * ptr,const size_t len) const80 void deleteArray(Residue *ptr, const size_t len) const { 81 for(size_t i = len; i > 0; i++) { 82 ptr[i - 1].~Residue(); 83 } 84 operator delete[](ptr); 85 } 86 set(Residue & r,const Residue & s) const87 void set (Residue &r, const Residue &s) const {assertValid(s); r = s;} set(Residue & r,const uint64_t s) const88 void set (Residue &r, const uint64_t s) const {r.r = s % m;} set(Residue & r,const Integer & s) const89 void set (Residue &r, const Integer &s) const {s.get(&r.r, 1); r.r %= m;} 90 /* Sets the Residue to the class represented by the integer s. Assumes that 91 s is reduced (mod m), i.e. 0 <= s < m */ set_reduced(Residue & r,const uint64_t s) const92 void set_reduced (Residue &r, const uint64_t s) const {assertValid(s); r.r = s;} set_reduced(Residue & r,const Integer & s) const93 void set_reduced (Residue &r, const Integer &s) const {s.get(&r.r, 1); assertValid(r);} set_int64(Residue & r,const int64_t s) const94 void set_int64 (Residue &r, const int64_t s) const {r.r = llabs(s) % m; if (s < 0) neg(r, r);} set0(Residue & r) const95 void set0 (Residue &r) const {r.r = 0;} set1(Residue & r) const96 void set1 (Residue &r) const {r.r = (m != 1);} 97 /* Exchanges the values of the two arguments */ swap(Residue & a,Residue & b) const98 void swap (Residue &a, Residue &b) const {uint64_t t = a.r; a.r = b.r; b.r = t;} get(Integer & r,const Residue & s) const99 void get (Integer &r, const Residue &s) const {assertValid(s); r = Integer(s.r);} equal(const Residue & a,const Residue & b) const100 bool equal (const Residue &a, const Residue &b) const {assertValid(a); assertValid(b); return (a.r == b.r);} is0(const Residue & a) const101 bool is0 (const Residue &a) const {assertValid(a); return (a.r == 0);} is1(const Residue & a) const102 bool is1 (const Residue &a) const {assertValid(a); return (a.r == 1);} neg(Residue & r,const Residue & a) const103 void neg (Residue &r, const Residue &a) const { 104 assertValid(a); 105 if (a.r == 0) 106 r.r = a.r; 107 else 108 r.r = m - a.r; 109 } add(Residue & r,const Residue & a,const Residue & b) const110 void add (Residue &r, const Residue &a, const Residue &b) const {u64arith_addmod_1_1(&r.r, a.r, b.r, m);} add1(Residue & r,const Residue & a) const111 void add1 (Residue &r, const Residue &a) const { 112 assertValid(a); 113 r.r = a.r + 1; 114 if (r.r == m) 115 r.r = 0; 116 } add(Residue & r,const Residue & a,const uint64_t b) const117 void add (Residue &r, const Residue &a, const uint64_t b) const { 118 u64arith_addmod_1_1(&r.r, a.r, b % m, m); 119 } sub(Residue & r,const Residue & a,const Residue & b) const120 void sub (Residue &r, const Residue &a, const Residue &b) const { 121 u64arith_submod_1_1(&r.r, a.r, b.r, m); 122 } sub1(Residue & r,const Residue & a) const123 void sub1 (Residue &r, const Residue &a) const { 124 u64arith_submod_1_1(&r.r, a.r, 1, m); 125 } sub(Residue & r,const Residue & a,const uint64_t b) const126 void sub (Residue &r, const Residue &a, const uint64_t b) const { 127 u64arith_submod_1_1(&r.r, a.r, b % m, m); 128 } mul(Residue & r,const Residue & a,const Residue & b) const129 void mul (Residue &r, const Residue &a, const Residue &b) const { 130 uint64_t t1, t2; 131 assertValid(a); 132 assertValid(b); 133 u64arith_mul_1_1_2 (&t1, &t2, a.r, b.r); 134 u64arith_divr_2_1_1 (&r.r, t1, t2, m); 135 } sqr(Residue & r,const Residue & a) const136 void sqr (Residue &r, const Residue &a) const { 137 uint64_t t1, t2; 138 assertValid(a); 139 u64arith_mul_1_1_2 (&t1, &t2, a.r, a.r); 140 u64arith_divr_2_1_1 (&r.r, t1, t2, m); 141 } 142 /* Computes (a * 2^wordsize) % m */ tomontgomery(Residue & r,const Residue & a) const143 void tomontgomery (Residue &r, const Residue &a) const { 144 assertValid(a); 145 u64arith_divr_2_1_1 (&r.r, 0, a.r, m); 146 } 147 /* Computes (a / 2^wordsize) % m */ frommontgomery(Residue & r,const Residue & a,const uint64_t invm) const148 void frommontgomery (Residue &r, const Residue &a, const uint64_t invm) const { 149 uint64_t tlow, thigh; 150 assertValid(a); 151 tlow = a.r * invm; 152 u64arith_mul_1_1_2 (&tlow, &thigh, tlow, m); 153 r.r = thigh + (a.r != 0 ? 1 : 0); 154 } 155 /* Computes (a / 2^wordsize) % m, but result can be r = m. 156 Input a must not be equal 0 */ redcsemi_u64_not0(Residue & r,const uint64_t a,const uint64_t invm) const157 void redcsemi_u64_not0 (Residue &r, const uint64_t a, const uint64_t invm) const { 158 uint64_t tlow, thigh; 159 ASSERT (a != 0); 160 tlow = a * invm; /* tlow <= 2^w-1 */ 161 u64arith_mul_1_1_2 (&tlow, &thigh, tlow, m); 162 /* thigh:tlow <= (2^w-1) * m */ 163 r.r = thigh + 1; 164 /* (thigh+1):tlow <= 2^w + (2^w-1) * m <= 2^w + 2^w*m - m 165 <= 2^w * (m + 1) - m */ 166 /* r <= floor ((2^w * (m + 1) - m) / 2^w) <= floor((m + 1) - m/2^w) 167 <= m */ 168 } next(Residue & r) const169 bool next (Residue &r) const {return (++r.r == m);} finished(const Residue & r) const170 bool finished (const Residue &r) const {return (r.r == m);} div2(Residue & r,const Residue & a) const171 bool div2 (Residue &r, const Residue &a) const { 172 if (m % 2 == 0) 173 return false; 174 else { 175 r.r = u64arith_div2mod(a.r, m); 176 return true; 177 } 178 } 179 180 /* Given a = V_n (x), b = V_m (x) and d = V_{n-m} (x), compute V_{m+n} (x). 181 * r can be the same variable as a or b but must not be the same variable as d. 182 */ V_dadd(Residue & r,const Residue & a,const Residue & b,const Residue & d) const183 void V_dadd (Residue &r, const Residue &a, const Residue &b, 184 const Residue &d) const { 185 ASSERT (&r != &d); 186 mul (r, a, b); 187 sub (r, r, d); 188 } 189 190 /* Given a = V_n (x) and two = 2, compute V_{2n} (x). 191 * r can be the same variable as a but must not be the same variable as two. 192 */ V_dbl(Residue & r,const Residue & a,const Residue & two) const193 void V_dbl (Residue &r, const Residue &a, const Residue &two) const { 194 ASSERT (&r != &two); 195 sqr (r, a); 196 sub (r, r, two); 197 } 198 199 /* prototypes of non-inline functions */ 200 bool div3 (Residue &, const Residue &) const; 201 bool div5 (Residue &, const Residue &) const; 202 bool div7 (Residue &, const Residue &) const; 203 bool div11 (Residue &, const Residue &) const; 204 bool div13 (Residue &, const Residue &) const; 205 void gcd (Integer &, const Residue &) const; 206 void pow (Residue &, const Residue &, const uint64_t) const; 207 void pow (Residue &, const Residue &, const uint64_t *, const size_t) const; 208 void pow (Residue &, const Residue &, const Integer &) const; 209 void pow2 (Residue &, const uint64_t) const; 210 void pow2 (Residue &, const uint64_t *, const size_t) const; 211 void pow2 (Residue &, const Integer &) const; 212 void pow3 (Residue &, uint64_t) const; 213 void V (Residue &, const Residue &, const uint64_t) const; 214 void V (Residue &, const Residue &, const uint64_t *, const size_t) const; 215 void V (Residue &, const Residue &, const Integer &) const; 216 void V (Residue &r, Residue *rp1, const Residue &b, 217 const uint64_t k) const; 218 bool sprp (const Residue &) const; 219 bool sprp2 () const; 220 bool isprime () const; 221 bool inv (Residue &, const Residue &) const; 222 bool inv_odd (Residue &, const Residue &) const; 223 bool inv_powerof2 (Residue &, const Residue &) const; 224 bool batchinv (Residue *, const Residue *, size_t, const Residue *) const; 225 int jacobi (const Residue &) const; 226 protected: 227 bool find_minus1 (Residue &r1, const Residue &minusone, const int po2) const; 228 }; 229 230 #endif /* MOD64_HPP */ 231