1 /* A class for modular arithmetic with residues and modulus of up to 64
2  * bits. */
3 
4 #ifndef MOD64_HPP
5 #define MOD64_HPP
6 
7 /**********************************************************************/
8 #include <cstdint>
9 #include <cstdlib>
10 #include <new>
11 #include "macros.h"
12 #include "u64arith.h"
13 #include "modint.hpp"
14 #include "mod_stdop.hpp"
15 
16 class Modulus64 {
17     /* Type definitions */
18 public:
19     typedef Integer64 Integer;
20     class Residue {
21         friend class Modulus64;
22     protected:
23         uint64_t r;
24     public:
25         typedef Modulus64 Modulus;
26         typedef Modulus::Integer Integer;
27         typedef bool IsResidueType;
28         Residue() = delete;
Residue(const Modulus & m MAYBE_UNUSED)29         Residue(const Modulus &m MAYBE_UNUSED) : r(0) {}
Residue(const Modulus & m MAYBE_UNUSED,const Residue & s)30         Residue(const Modulus &m MAYBE_UNUSED, const Residue &s) : r(s.r) {}
Residue(const Residue && s)31         Residue(const Residue &&s) : r(s.r) {}
32     protected:
operator =(const Residue & s)33         Residue &operator=(const Residue &s) {r = s.r; return *this;}
operator =(const Integer & s)34         Residue &operator=(const Integer &s) {r = 0; s.get(&r, 1); return *this;}
operator =(const uint64_t s)35         Residue &operator=(const uint64_t s) {r = s; return *this;}
36     };
37 
38     typedef ResidueStdOp<Residue> ResidueOp;
39 
40     /* Data members */
41 protected:
42     uint64_t m;
43     /* Methods used internally */
assertValid(const Residue & a MAYBE_UNUSED) const44     void assertValid(const Residue &a MAYBE_UNUSED) const {ASSERT_EXPENSIVE (a.r < m);}
assertValid(const uint64_t a MAYBE_UNUSED) const45     void assertValid(const uint64_t a MAYBE_UNUSED) const {ASSERT_EXPENSIVE (a < m);}
get_u64(const Residue & s) const46     uint64_t get_u64 (const Residue &s) const {assertValid(s); return s.r;}
47 
48     /* Methods of the API */
49 public:
getminmod()50     static Integer getminmod() {return Integer(1);}
getmaxmod()51     static Integer getmaxmod() {return Integer(UINT64_MAX);}
getminmod(Integer & r)52     static void getminmod(Integer &r) {r = getminmod();}
getmaxmod(Integer & r)53     static void getmaxmod(Integer &r) {r = getmaxmod();}
valid(const Integer & m)54     static bool valid(const Integer &m) {return getminmod() <= m && m <= getmaxmod();}
55 
Modulus64(const uint64_t s)56     Modulus64(const uint64_t s) : m(s){}
Modulus64(const Modulus64 & s)57     Modulus64(const Modulus64 &s) : m(s.m){}
Modulus64(const Integer & s)58     Modulus64(const Integer &s) {s.get(&m, 1);}
~Modulus64()59     ~Modulus64() {}
getmod_u64() const60     uint64_t getmod_u64 () const {return m;}
getmod(Integer & r) const61     void getmod (Integer &r) const {r = m;}
62 
63     /* Methods for residues */
64 
65     /** Allocate an array of len residues.
66      *
67      * Must be freed with deleteArray(), not with delete[].
68      */
newArray(const size_t len) const69     Residue *newArray(const size_t len) const {
70         void *t = operator new[](len * sizeof(Residue));
71         if (t == NULL)
72             return NULL;
73         Residue *ptr = static_cast<Residue *>(t);
74         for(size_t i = 0; i < len; i++) {
75             new(&ptr[i]) Residue(*this);
76         }
77         return ptr;
78     }
79 
deleteArray(Residue * ptr,const size_t len) const80     void deleteArray(Residue *ptr, const size_t len) const {
81         for(size_t i = len; i > 0; i++) {
82             ptr[i - 1].~Residue();
83         }
84         operator delete[](ptr);
85     }
86 
set(Residue & r,const Residue & s) const87     void set (Residue &r, const Residue &s) const {assertValid(s); r = s;}
set(Residue & r,const uint64_t s) const88     void set (Residue &r, const uint64_t s) const {r.r = s % m;}
set(Residue & r,const Integer & s) const89     void set (Residue &r, const Integer &s) const {s.get(&r.r, 1); r.r %= m;}
90     /* Sets the Residue to the class represented by the integer s. Assumes that
91        s is reduced (mod m), i.e. 0 <= s < m */
set_reduced(Residue & r,const uint64_t s) const92     void set_reduced (Residue &r, const uint64_t s) const {assertValid(s); r.r = s;}
set_reduced(Residue & r,const Integer & s) const93     void set_reduced (Residue &r, const Integer &s) const {s.get(&r.r, 1); assertValid(r);}
set_int64(Residue & r,const int64_t s) const94     void set_int64 (Residue &r, const int64_t s) const {r.r = llabs(s) % m; if (s < 0) neg(r, r);}
set0(Residue & r) const95     void set0 (Residue &r) const {r.r = 0;}
set1(Residue & r) const96     void set1 (Residue &r) const {r.r = (m != 1);}
97     /* Exchanges the values of the two arguments */
swap(Residue & a,Residue & b) const98     void swap (Residue &a, Residue &b) const {uint64_t t = a.r; a.r = b.r; b.r = t;}
get(Integer & r,const Residue & s) const99     void get (Integer &r, const Residue &s) const {assertValid(s); r = Integer(s.r);}
equal(const Residue & a,const Residue & b) const100     bool equal (const Residue &a, const Residue &b) const {assertValid(a); assertValid(b); return (a.r == b.r);}
is0(const Residue & a) const101     bool is0 (const Residue &a) const {assertValid(a); return (a.r == 0);}
is1(const Residue & a) const102     bool is1 (const Residue &a) const {assertValid(a); return (a.r == 1);}
neg(Residue & r,const Residue & a) const103     void neg (Residue &r, const Residue &a) const {
104         assertValid(a);
105         if (a.r == 0)
106             r.r = a.r;
107         else
108             r.r = m - a.r;
109     }
add(Residue & r,const Residue & a,const Residue & b) const110   void add (Residue &r, const Residue &a, const Residue &b) const {u64arith_addmod_1_1(&r.r, a.r, b.r, m);}
add1(Residue & r,const Residue & a) const111   void add1 (Residue &r, const Residue &a) const {
112     assertValid(a);
113     r.r = a.r + 1;
114     if (r.r == m)
115       r.r = 0;
116   }
add(Residue & r,const Residue & a,const uint64_t b) const117   void add (Residue &r, const Residue &a, const uint64_t b) const {
118     u64arith_addmod_1_1(&r.r, a.r, b % m, m);
119   }
sub(Residue & r,const Residue & a,const Residue & b) const120   void sub (Residue &r, const Residue &a, const Residue &b) const {
121     u64arith_submod_1_1(&r.r, a.r, b.r, m);
122   }
sub1(Residue & r,const Residue & a) const123   void sub1 (Residue &r, const Residue &a) const {
124     u64arith_submod_1_1(&r.r, a.r, 1, m);
125   }
sub(Residue & r,const Residue & a,const uint64_t b) const126   void sub (Residue &r, const Residue &a, const uint64_t b) const {
127     u64arith_submod_1_1(&r.r, a.r, b % m, m);
128   }
mul(Residue & r,const Residue & a,const Residue & b) const129   void mul (Residue &r, const Residue &a, const Residue &b) const {
130     uint64_t t1, t2;
131     assertValid(a);
132     assertValid(b);
133     u64arith_mul_1_1_2 (&t1, &t2, a.r, b.r);
134     u64arith_divr_2_1_1 (&r.r, t1, t2, m);
135   }
sqr(Residue & r,const Residue & a) const136   void sqr (Residue &r, const Residue &a) const {
137     uint64_t t1, t2;
138     assertValid(a);
139     u64arith_mul_1_1_2 (&t1, &t2, a.r, a.r);
140     u64arith_divr_2_1_1 (&r.r, t1, t2, m);
141   }
142   /* Computes (a * 2^wordsize) % m */
tomontgomery(Residue & r,const Residue & a) const143   void tomontgomery (Residue &r, const Residue &a) const {
144     assertValid(a);
145     u64arith_divr_2_1_1 (&r.r, 0, a.r, m);
146   }
147   /* Computes (a / 2^wordsize) % m */
frommontgomery(Residue & r,const Residue & a,const uint64_t invm) const148   void frommontgomery (Residue &r, const Residue &a, const uint64_t invm) const {
149     uint64_t tlow, thigh;
150     assertValid(a);
151     tlow = a.r * invm;
152     u64arith_mul_1_1_2 (&tlow, &thigh, tlow, m);
153     r.r = thigh + (a.r != 0 ? 1 : 0);
154   }
155   /* Computes (a / 2^wordsize) % m, but result can be r = m.
156      Input a must not be equal 0 */
redcsemi_u64_not0(Residue & r,const uint64_t a,const uint64_t invm) const157   void redcsemi_u64_not0 (Residue &r, const uint64_t a, const uint64_t invm) const {
158     uint64_t tlow, thigh;
159     ASSERT (a != 0);
160     tlow = a * invm; /* tlow <= 2^w-1 */
161     u64arith_mul_1_1_2 (&tlow, &thigh, tlow, m);
162     /* thigh:tlow <= (2^w-1) * m */
163     r.r = thigh + 1;
164     /* (thigh+1):tlow <= 2^w + (2^w-1) * m  <= 2^w + 2^w*m - m
165                       <= 2^w * (m + 1) - m */
166     /* r <= floor ((2^w * (m + 1) - m) / 2^w) <= floor((m + 1) - m/2^w)
167          <= m */
168   }
next(Residue & r) const169   bool next (Residue &r) const {return (++r.r == m);}
finished(const Residue & r) const170   bool finished (const Residue &r) const {return (r.r == m);}
div2(Residue & r,const Residue & a) const171   bool div2 (Residue &r, const Residue &a) const {
172       if (m % 2 == 0)
173           return false;
174       else {
175           r.r = u64arith_div2mod(a.r, m);
176           return true;
177       }
178     }
179 
180   /* Given a = V_n (x), b = V_m (x) and d = V_{n-m} (x), compute V_{m+n} (x).
181    * r can be the same variable as a or b but must not be the same variable as d.
182    */
V_dadd(Residue & r,const Residue & a,const Residue & b,const Residue & d) const183   void V_dadd (Residue &r, const Residue &a, const Residue &b,
184                const Residue &d) const {
185     ASSERT (&r != &d);
186     mul (r, a, b);
187     sub (r, r, d);
188   }
189 
190   /* Given a = V_n (x) and two = 2, compute V_{2n} (x).
191    * r can be the same variable as a but must not be the same variable as two.
192    */
V_dbl(Residue & r,const Residue & a,const Residue & two) const193   void V_dbl (Residue &r, const Residue &a, const Residue &two) const {
194     ASSERT (&r != &two);
195     sqr (r, a);
196     sub (r, r, two);
197   }
198 
199   /* prototypes of non-inline functions */
200   bool div3 (Residue &, const Residue &) const;
201   bool div5 (Residue &, const Residue &) const;
202   bool div7 (Residue &, const Residue &) const;
203   bool div11 (Residue &, const Residue &) const;
204   bool div13 (Residue &, const Residue &) const;
205   void gcd (Integer &, const Residue &) const;
206   void pow (Residue &, const Residue &, const uint64_t) const;
207   void pow (Residue &, const Residue &, const uint64_t *, const size_t) const;
208   void pow (Residue &, const Residue &, const Integer &) const;
209   void pow2 (Residue &, const uint64_t) const;
210   void pow2 (Residue &, const uint64_t *, const size_t) const;
211   void pow2 (Residue &, const Integer &) const;
212   void pow3 (Residue &, uint64_t) const;
213   void V (Residue &, const Residue &, const uint64_t) const;
214   void V (Residue &, const Residue &, const uint64_t *, const size_t) const;
215   void V (Residue &, const Residue &, const Integer &) const;
216   void V (Residue &r, Residue *rp1, const Residue &b,
217           const uint64_t k) const;
218   bool sprp (const Residue &) const;
219   bool sprp2 () const;
220   bool isprime () const;
221   bool inv (Residue &, const Residue &) const;
222   bool inv_odd (Residue &, const Residue &) const;
223   bool inv_powerof2 (Residue &, const Residue &) const;
224   bool batchinv (Residue *, const Residue *, size_t, const Residue *) const;
225   int jacobi (const Residue &) const;
226 protected:
227   bool find_minus1 (Residue &r1, const Residue &minusone, const int po2) const;
228 };
229 
230 #endif  /* MOD64_HPP */
231