/* A class for modular arithmetic with a multi-precision modulus of
 * arbitrary size (stored as a GMP mpz_t); residues are arrays of GMP limbs. */
3 
4 #ifndef MODMPZ_HPP
5 #define MODMPZ_HPP
6 
7 /**********************************************************************/
8 #include <cstdint>
9 #include <gmp.h>          // for mp_limb_t, mpz_size, __gmpn_copyi, __gmpn_s...
10 #include <cstddef>        // for size_t, NULL
11 #include <new>            // for operator new
12 #include "macros.h"
13 #include "gmp_aux.h"      // for mpz_cmp_uint64, mpz_get_uint64, mpz_set_uint64
14 #include "cxx_mpz.hpp"
15 // #include "modint.hpp"
16 #include "mod_stdop.hpp"
17 #include "misc.h"
18 
19 class ModulusMPZ {
20     /* Type definitions */
21 public:
22     typedef cxx_mpz Integer;
23     class Residue {
24         friend class ModulusMPZ;
25     protected:
26         mp_limb_t *r;
27     public:
28         typedef ModulusMPZ Modulus;
29         typedef Modulus::Integer Integer;
30         typedef bool IsResidueType;
31         Residue() = delete;
Residue(const Modulus & m)32         Residue(const Modulus &m) {
33             r = new mp_limb_t[mpz_size(m.m) + 1];
34             mpn_zero(r, mpz_size(m.m) + 1);
35         }
Residue(const Modulus & m,const Residue & s)36         Residue(const Modulus &m, const Residue &s) {
37             r = new mp_limb_t[mpz_size(m.m) + 1];
38             mpn_copyi(r, s.r, mpz_size(m.m));
39         }
~Residue()40         ~Residue() {
41             delete[] r;
42         }
Residue(Residue && s)43         Residue(Residue &&s) : r(s.r) {
44             s.r = NULL;
45         }
46         Residue(Residue const & s) = delete;
operator =(Residue && s)47         Residue& operator=(Residue &&s) {
48             delete[] r;
49             r = s.r;
50             s.r = NULL;
51             return *this;
52         }
53         Residue& operator=(Residue const & s) = delete;
54     };
55 
56     typedef ResidueStdOp<Residue> ResidueOp;
57 
58     /* Data members */
59 protected:
60     mpz_t m;
61     size_t bits;
62     static const size_t limbsPerUint64 = iceildiv(64, GMP_NUMB_BITS);
63 
64     /** Convert a uint64_t to an array of mp_limb_t.
65      * Always writes exactly limbsToWrite limbs. */
uint64ToLimbs(mp_limb_t * r,const uint64_t s,const size_t limbsToWrite)66     static void uint64ToLimbs(mp_limb_t *r, const uint64_t s, const size_t limbsToWrite) {
67         uint64_t t = s;
68         for (size_t i = 0; i < limbsToWrite; i++) {
69             r[i] = t & GMP_NUMB_MASK;
70 #if GMP_NUMB_BITS < 64
71                 t >>= GMP_NUMB_BITS;
72 #else
73                 t = 0;
74 #endif
75         }
76         ASSERT_ALWAYS(t == 0);
77     }
78 
79     /* Methods used internally */
80     /** Set a residue to m.
81      *  This leaves the residue in a non-reduced state. */
setM(mp_limb_t * r) const82     void setM(mp_limb_t *r) const {
83         mpn_copyi(r, m->_mp_d, mpz_size(m));
84     }
85     /** Add M to a residue, return carry */
addM(mp_limb_t * r,const mp_limb_t * s) const86     mp_limb_t addM(mp_limb_t *r, const mp_limb_t *s) const {
87         return mpn_add_n(r, s, m->_mp_d, mpz_size(m));
88     }
89     /** Subtract M from a residue, return borrow */
subM(mp_limb_t * r,const mp_limb_t * s) const90     mp_limb_t subM(mp_limb_t *r, const mp_limb_t *s) const {
91         return mpn_sub_n(r, s, m->_mp_d, mpz_size(m));
92     }
93     /** Returns a negative value if s < m, 0 if s == m, and a positive value if s > m. */
cmpM(const mp_limb_t * s) const94     int cmpM(const mp_limb_t *s) const {
95         return mpn_cmp(s, m->_mp_d, mpz_size(m));
96     }
modM(const uint64_t s) const97     uint64_t modM(const uint64_t s) const {
98         if (bits > 64 || mpz_cmp_uint64(m, s) > 0)
99             return s;
100         /* m <= s, thus m fits in uint64_t */
101         uint64_t m64 = mpz_get_uint64(m);
102         /* We always make sure that the modulus in non-zero in the ctor
103          */
104         ASSERT_FOR_STATIC_ANALYZER(m64 != 0);
105         return s % m64;
106     }
set_residue_u64(Residue & r,const uint64_t s) const107     void set_residue_u64(Residue &r, const uint64_t s) const {
108         ASSERT(mpz_cmp_uint64(m, s) > 0);
109         const size_t limbsToWrite = MIN(limbsPerUint64, mpz_size(m));
110         uint64ToLimbs(r.r, s, limbsToWrite);
111         for (size_t i = limbsToWrite; i < mpz_size(m); i++)
112             r.r[i] = 0;
113     }
set_residue_mpz(Residue & r,const mpz_t s) const114     void set_residue_mpz(Residue &r, const mpz_t s) const {
115         ASSERT_ALWAYS(mpz_cmp(s, m) < 0);
116         size_t written;
117         mpz_export(r.r, &written, -1, sizeof(mp_limb_t), 0, GMP_NAIL_BITS, s);
118         ASSERT_ALWAYS(written <= mpz_size(m));
119         for (size_t i = written; i < mpz_size(m); i++)
120             r.r[i] = 0;
121     }
set_mpz_residue(mpz_t r,const Residue & s) const122     void set_mpz_residue(mpz_t r, const Residue &s) const {
123         mpz_import(r, mpz_size(m), -1, sizeof(mp_limb_t), 0, GMP_NAIL_BITS, s.r);
124     }
assertValid(const Residue & a MAYBE_UNUSED) const125     void assertValid(const Residue &a MAYBE_UNUSED) const {
126         ASSERT_EXPENSIVE (cmpM(a.r) < 0);
127     }
assertValid(const uint64_t a MAYBE_UNUSED) const128     void assertValid(const uint64_t a MAYBE_UNUSED) const {
129         ASSERT_EXPENSIVE (mpz_cmp_uint64(m, a) > 0);
130     }
assertValid(const mpz_t s MAYBE_UNUSED) const131     void assertValid(const mpz_t s MAYBE_UNUSED) const {
132         ASSERT_EXPENSIVE (mpz_cmp(s, m) < 0);
133     }
assertValid(const cxx_mpz & s MAYBE_UNUSED) const134     void assertValid(const cxx_mpz &s MAYBE_UNUSED) const {
135         assertValid((mpz_srcptr) s);
136     }
137 
138     /* Methods of the API */
139 public:
getminmod(Integer & r)140     static void getminmod(Integer &r) {
141         mpz_set_ui(r, 0);
142     }
getmaxmod(Integer & r)143     static void getmaxmod(Integer &r) {
144         mpz_set_ui(r, 0);
145     }
valid(const Integer & m MAYBE_UNUSED)146     static bool valid(const Integer &m MAYBE_UNUSED) {
147         return true;
148     }
149 
ModulusMPZ(const uint64_t s)150     ModulusMPZ(const uint64_t s) {
151         ASSERT_ALWAYS(s > 0);
152         mpz_init(m);
153         mpz_set_uint64(m, s);
154         bits = mpz_sizeinbase(m, 2);
155     }
ModulusMPZ(const ModulusMPZ & s)156     ModulusMPZ(const ModulusMPZ &s) {
157         mpz_init_set(m, s.m);
158         bits = s.bits;
159     }
ModulusMPZ(const Integer & s)160     ModulusMPZ(const Integer &s) {
161         ASSERT_ALWAYS(mpz_sgn(s) > 0);
162         mpz_init(m);
163         mpz_set(m, s);
164         bits = mpz_sizeinbase(m, 2);
165     }
~ModulusMPZ()166     ~ModulusMPZ() {
167         mpz_clear(m);
168     }
getmod(Integer & r) const169     void getmod (Integer &r) const {
170         mpz_set(r, m);
171     }
172 
173     /* Methods for residues */
174 
175     /** Allocate an array of len residues.
176      *
177      * Must be freed with deleteArray(), not with delete[].
178      */
newArray(const size_t len) const179     Residue *newArray(const size_t len) const {
180         void *t = operator new[](len * sizeof(Residue));
181         if (t == NULL)
182             return NULL;
183         Residue *ptr = static_cast<Residue *>(t);
184         for(size_t i = 0; i < len; i++) {
185             new(&ptr[i]) Residue(*this);
186         }
187         return ptr;
188     }
189 
deleteArray(Residue * ptr,const size_t len) const190     void deleteArray(Residue *ptr, const size_t len) const {
191         for(size_t i = len; i > 0; i++) {
192             ptr[i - 1].~Residue();
193         }
194         operator delete[](ptr);
195     }
196 
197 
set(Residue & r,const Residue & s) const198     void set (Residue &r, const Residue &s) const {
199         assertValid(s);
200         mpn_copyi(r.r, s.r, mpz_size(m));
201     }
set(Residue & r,const uint64_t s) const202     void set (Residue &r, const uint64_t s) const {
203         const uint64_t sm = modM(s);
204         set_residue_u64(r, sm);
205     }
set(Residue & r,const Integer & s) const206     void set (Residue &r, const Integer &s) const {
207         cxx_mpz t;
208         mpz_mod(t, s, m);
209         set_residue_mpz(r, t);
210     }
211     /* Sets the Residue to the class represented by the integer s. Assumes that
212        s is reduced (mod m), i.e. 0 <= s < m */
set_reduced(Residue & r,const uint64_t s) const213     void set_reduced (Residue &r, const uint64_t s) const {
214         assertValid(s);
215         set_residue_u64(r, s);
216     }
set_reduced(Residue & r,const Integer & s) const217     void set_reduced (Residue &r, const Integer &s) const {
218         assertValid(r);
219         set_residue_mpz(r, s);
220     }
set_int64(Residue & r,const int64_t s) const221     void set_int64 (Residue &r, const int64_t s) const {
222         const uint64_t u = modM(safe_abs64(s));
223         set_residue_u64(r, u);
224         if (s < 0)
225             neg(r, r);
226     }
set0(Residue & r) const227     void set0 (Residue &r) const {
228         mpn_zero(r.r, mpz_size(m));
229     }
set1(Residue & r) const230     void set1 (Residue &r) const {
231         set0(r);
232         if (mpz_cmp_ui(m, 1) > 0) {
233             r.r[0] = 1;
234         }
235     }
236     /* Exchanges the values of the two arguments */
swap(Residue & a,Residue & b) const237     void swap (Residue &a, Residue &b) const {
238         mp_limb_t *t = a.r;
239         a.r = b.r;
240         b.r = t;
241     }
get(Integer & r,const Residue & s) const242     void get (Integer &r, const Residue &s) const {
243         assertValid(s);
244         set_mpz_residue(r, s);
245     }
equal(const Residue & a,const Residue & b) const246     bool equal (const Residue &a, const Residue &b) const {
247         assertValid(a);
248         assertValid(b);
249         return mpn_cmp(a.r, b.r, mpz_size(m)) == 0;
250     }
is0(const Residue & a) const251     bool is0 (const Residue &a) const {
252         assertValid(a);
253         return mpn_zero_p(a.r, mpz_size(m));
254     }
is1(const Residue & a) const255     bool is1 (const Residue &a) const {
256         assertValid(a);
257         if (mpz_cmp_ui(m, 1) == 0)
258             return 1;
259         return a.r[0] == 1 && (mpz_size(m) == 1 || mpn_zero_p(a.r + 1, mpz_size(m) - 1));
260     }
neg(Residue & r,const Residue & a) const261     void neg (Residue &r, const Residue &a) const {
262         assertValid(a);
263         if (is0(a) != 0) {
264             if (&r != &a)
265                 set0(r);
266         } else {
267             mp_limb_t bw = mpn_sub_n(r.r, m->_mp_d, r.r, mpz_size(m));
268             ASSERT_ALWAYS(bw == 0);
269         }
270         assertValid(r);
271     }
add(Residue & r,const Residue & a,const Residue & b) const272     void add (Residue &r, const Residue &a, const Residue &b) const {
273         assertValid(a);
274         assertValid(b);
275         const mp_limb_t cy = mpn_add_n(r.r, a.r, b.r, mpz_size(m));
276         if (cy || cmpM(r.r) >= 0) {
277             const mp_limb_t bw = subM(r.r, r.r);
278             ASSERT_ALWAYS(bw == cy);
279         }
280         assertValid(r);
281     }
add1(Residue & r,const Residue & a) const282     void add1 (Residue &r, const Residue &a) const {
283         assertValid(a);
284         const mp_limb_t cy = mpn_add_1(r.r, a.r, 1, mpz_size(m));
285         if (cy || cmpM(r.r) >= 0) {
286             const mp_limb_t bw = subM(r.r, r.r);
287             ASSERT_ALWAYS(bw == cy);
288         }
289         assertValid(r);
290     }
add(Residue & r,const Residue & a,const uint64_t b) const291     void add (Residue &r, const Residue &a, const uint64_t b) const {
292         assertValid(a);
293         const uint64_t bm = modM(b);
294         mp_limb_t cy;
295 
296         if (limbsPerUint64 == 1 || bm < GMP_NUMB_MAX) {
297             cy = mpn_add_1(r.r, a.r, (mp_limb_t) bm, mpz_size(m));
298         } else {
299             mp_limb_t t[limbsPerUint64];
300             const size_t toWrite = MIN(limbsPerUint64, mpz_size(m));
301 
302             uint64ToLimbs(t, bm, limbsPerUint64);
303             cy = mpn_add(r.r, a.r, mpz_size(m), t, toWrite);
304         }
305         if (cy || cmpM(r.r) >= 0) {
306             mp_limb_t bw = subM(r.r, r.r);
307             ASSERT_ALWAYS(bw == cy);
308         }
309         assertValid(r);
310     }
sub(Residue & r,const Residue & a,const Residue & b) const311     void sub (Residue &r, const Residue &a, const Residue &b) const {
312         assertValid(a);
313         assertValid(b);
314         const mp_limb_t bw = mpn_sub_n(r.r, a.r, b.r, mpz_size(m));
315         if (bw) {
316             const mp_limb_t cy = addM(r.r, r.r);
317             ASSERT_ALWAYS(cy == bw);
318         }
319         assertValid(r);
320     }
sub1(Residue & r,const Residue & a) const321     void sub1 (Residue &r, const Residue &a) const {
322         assertValid(a);
323         const mp_limb_t bw = mpn_sub_1(r.r, a.r, 1, mpz_size(m));
324         if (bw) {
325             const mp_limb_t cy = addM(r.r, r.r);
326             ASSERT_ALWAYS(cy == bw);
327         }
328         assertValid(r);
329     }
sub(Residue & r,const Residue & a,const uint64_t b) const330     void sub (Residue &r, const Residue &a, const uint64_t b) const {
331         assertValid(a);
332         const uint64_t bm = modM(b);
333         mp_limb_t bw;
334 
335         if (limbsPerUint64 == 1 || bm < GMP_NUMB_MAX) {
336             bw = mpn_sub_1(r.r, a.r, (mp_limb_t) bm, mpz_size(m));
337         } else {
338             mp_limb_t t[limbsPerUint64];
339             const size_t toWrite = MIN(limbsPerUint64, mpz_size(m));
340 
341             uint64ToLimbs(t, bm, limbsPerUint64);
342             bw = mpn_sub(r.r, a.r, mpz_size(m), t, toWrite);
343         }
344         if (bw) {
345             mp_limb_t cy = addM(r.r, r.r);
346             ASSERT_ALWAYS(cy == bw);
347         }
348         assertValid(r);
349     }
mul(Residue & r,const Residue & a,const Residue & b) const350     void mul (Residue &r, const Residue &a, const Residue &b) const {
351         const size_t nrWords = mpz_size(m);
352         mp_limb_t Q[2];
353         mp_limb_t *t;
354         if (r.r == a.r || r.r == b.r) {
355             t = new mp_limb_t[nrWords + 1];
356         } else {
357             t = r.r;
358         }
359         t[nrWords] = mpn_mul_1 (t, a.r, nrWords, b.r[nrWords - 1]);
360         if (t[nrWords] != 0)
361             mpn_tdiv_qr (Q, t, 0, t, nrWords + 1, m->_mp_d, nrWords);
362         /* t <= (m-1) * beta */
363         for (size_t iWord = nrWords - 1; iWord > 0; iWord--) {
364             mpn_copyd(t + 1, t, nrWords);
365             t[0] = 0;
366             mp_limb_t msw = mpn_addmul_1 (t, a.r, nrWords, b.r[iWord - 1]);
367             t[nrWords] += msw;
368             mp_limb_t cy = t[nrWords] < msw;
369             if (cy) {
370                 mp_limb_t bw = subM(t + 1, t + 1);
371                 ASSERT_ALWAYS(bw == cy);
372             }
373             if (t[nrWords] != 0)
374                 mpn_tdiv_qr (Q, t, 0, t, nrWords + 1, m->_mp_d, nrWords);
375         }
376         if (cmpM(t) >= 0) {
377             mpn_tdiv_qr (Q, t, 0, t, nrWords, m->_mp_d, nrWords);
378         }
379         if (r.r == a.r || r.r == b.r) {
380             mpn_copyi(r.r, t, nrWords);
381             delete[] t;
382         }
383     }
sqr(Residue & r,const Residue & a) const384     void sqr (Residue &r, const Residue &a) const {
385         mul(r, a, a);
386     }
next(Residue & r) const387     bool next (Residue &r) const {
388         add1(r, r);
389         return finished(r);
390     }
finished(const Residue & r) const391     bool finished (const Residue &r) const {
392         return is0(r);
393     }
div2(Residue & r,const Residue & a) const394     bool div2 (Residue &r, const Residue &a) const {
395         assertValid(a);
396         if (mpz_even_p(m)) {
397             return false;
398         } else {
399             if (a.r[0] % 2 == 0) {
400                 mp_limb_t lsb = mpn_rshift(r.r, a.r, mpz_size(m), 1);
401                 ASSERT_ALWAYS(lsb == 0);
402             } else {
403                 mp_limb_t cy = addM(r.r, a.r);
404                 mp_limb_t lsb = mpn_rshift(r.r, r.r, mpz_size(m), 1);
405                 ASSERT_ALWAYS(lsb == 0);
406                 r.r[mpz_size(m) - 1] |= cy << (GMP_NUMB_BITS - 1);
407             }
408             assertValid(r);
409             return true;
410         }
411     }
412 
413     /* Given a = V_n (x), b = V_m (x) and d = V_{n-m} (x), compute V_{m+n} (x).
414      * r can be the same variable as a or b but must not be the same variable as d.
415      */
V_dadd(Residue & r,const Residue & a,const Residue & b,const Residue & d) const416     void V_dadd (Residue &r, const Residue &a, const Residue &b,
417                      const Residue &d) const {
418         ASSERT (&r != &d);
419         mul (r, a, b);
420         sub (r, r, d);
421     }
422 
423     /* Given a = V_n (x) and two = 2, compute V_{2n} (x).
424      * r can be the same variable as a but must not be the same variable as two.
425      */
V_dbl(Residue & r,const Residue & a,const Residue & two) const426     void V_dbl (Residue &r, const Residue &a, const Residue &two) const {
427         ASSERT (&r != &two);
428         sqr (r, a);
429         sub (r, r, two);
430     }
431 
432 
433     /* prototypes of non-inline functions */
434     bool div3 (Residue &, const Residue &) const;
435     bool div5 (Residue &, const Residue &) const;
436     bool div7 (Residue &, const Residue &) const;
437     bool div11 (Residue &, const Residue &) const;
438     bool div13 (Residue &, const Residue &) const;
439     bool divn (Residue &, const Residue &, unsigned long) const;
440     void gcd (Integer &, const Residue &) const;
441     void pow (Residue &, const Residue &, const uint64_t) const;
442     void pow (Residue &, const Residue &, const uint64_t *, const size_t) const;
443     void pow (Residue &r, const Residue &b, const Integer &e) const;
444     void pow2 (Residue &, const uint64_t) const;
445     void pow2 (Residue &, const uint64_t *, const size_t) const;
446     void pow2 (Residue &r, const Integer &e) const;
447     void pow3 (Residue &, uint64_t) const;
448     void V (Residue &, const Residue &, const uint64_t) const;
449     void V (Residue &, const Residue &, const uint64_t *, const int) const;
450     void V (Residue &r, const Residue &b, const Integer &e) const;
451     void V (Residue &r, Residue *rp1, const Residue &b,
452             const uint64_t k) const;
453     bool sprp (const Residue &) const;
454     bool sprp2 () const;
455     bool isprime () const;
456     bool inv (Residue &, const Residue &) const;
457     bool inv_odd (Residue &, const Residue &) const;
458     bool inv_powerof2 (Residue &, const Residue &) const;
459     bool batchinv (Residue *, const Residue *, size_t, const Residue *) const;
460     int jacobi (const Residue &) const;
461 protected:
462     bool find_minus1 (Residue &r1, const Residue &minusone, const int po2) const;
463     bool divn (Residue &, const Residue &, unsigned long, const mp_limb_t *, mp_limb_t) const;
464 };
465 
#endif  /* MODMPZ_HPP */
467