1 /*
2  * This file is part of the MicroPython project, http://micropython.org/
3  *
4  * The MIT License (MIT)
5  *
6  * Copyright (c) 2013, 2014 Damien P. George
7  *
8  * Permission is hereby granted, free of charge, to any person obtaining a copy
9  * of this software and associated documentation files (the "Software"), to deal
10  * in the Software without restriction, including without limitation the rights
11  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12  * copies of the Software, and to permit persons to whom the Software is
13  * furnished to do so, subject to the following conditions:
14  *
15  * The above copyright notice and this permission notice shall be included in
16  * all copies or substantial portions of the Software.
17  *
18  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24  * THE SOFTWARE.
25  */
26 
27 #include <string.h>
28 #include <assert.h>
29 
30 #include "py/mpz.h"
31 
32 #if MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
33 
// Width of one mpz digit in bits, followed by the constants derived from it.
#define DIG_SIZE (MPZ_DIG_SIZE)
// Radix of the digit representation: one more than the largest digit value.
#define DIG_BASE (MPZ_LONG_1 << DIG_SIZE)
// Mask selecting the DIG_SIZE value bits of a digit.
#define DIG_MASK (DIG_BASE - 1)
// The most significant bit of a digit.
#define DIG_MSB  (MPZ_LONG_1 << (DIG_SIZE - 1))
38 
39 /*
40  mpz is an arbitrary precision integer type with a public API.
41 
42  mpn functions act on non-negative integers represented by an array of generalised
43  digits (eg a word per digit).  You also need to specify separately the length of the
44  array.  There is no public API for mpn.  Rather, the functions are used by mpz to
45  implement its features.
46 
47  Integer values are stored little endian (first digit is first in memory).
48 
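 Definition of normalise: a number is normalised when its most significant
 (last stored) digit is non-zero; the value zero therefore has length 0.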
50 */
51 
mpn_remove_trailing_zeros(mpz_dig_t * oidig,mpz_dig_t * idig)52 STATIC size_t mpn_remove_trailing_zeros(mpz_dig_t *oidig, mpz_dig_t *idig) {
53     for (--idig; idig >= oidig && *idig == 0; --idig) {
54     }
55     return idig + 1 - oidig;
56 }
57 
58 /* compares i with j
59    returns sign(i - j)
60    assumes i, j are normalised
61 */
mpn_cmp(const mpz_dig_t * idig,size_t ilen,const mpz_dig_t * jdig,size_t jlen)62 STATIC int mpn_cmp(const mpz_dig_t *idig, size_t ilen, const mpz_dig_t *jdig, size_t jlen) {
63     if (ilen < jlen) {
64         return -1;
65     }
66     if (ilen > jlen) {
67         return 1;
68     }
69 
70     for (idig += ilen, jdig += ilen; ilen > 0; --ilen) {
71         mpz_dbl_dig_signed_t cmp = (mpz_dbl_dig_t)*(--idig) - (mpz_dbl_dig_t)*(--jdig);
72         if (cmp < 0) {
73             return -1;
74         }
75         if (cmp > 0) {
76             return 1;
77         }
78     }
79 
80     return 0;
81 }
82 
83 /* computes i = j << n
84    returns number of digits in i
85    assumes enough memory in i; assumes normalised j; assumes n > 0
86    can have i, j pointing to same memory
87 */
mpn_shl(mpz_dig_t * idig,mpz_dig_t * jdig,size_t jlen,mp_uint_t n)88 STATIC size_t mpn_shl(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
89     mp_uint_t n_whole = (n + DIG_SIZE - 1) / DIG_SIZE;
90     mp_uint_t n_part = n % DIG_SIZE;
91     if (n_part == 0) {
92         n_part = DIG_SIZE;
93     }
94 
95     // start from the high end of the digit arrays
96     idig += jlen + n_whole - 1;
97     jdig += jlen - 1;
98 
99     // shift the digits
100     mpz_dbl_dig_t d = 0;
101     for (size_t i = jlen; i > 0; i--, idig--, jdig--) {
102         d |= *jdig;
103         *idig = (d >> (DIG_SIZE - n_part)) & DIG_MASK;
104         d <<= DIG_SIZE;
105     }
106 
107     // store remaining bits
108     *idig = (d >> (DIG_SIZE - n_part)) & DIG_MASK;
109     idig -= n_whole - 1;
110     memset(idig, 0, (n_whole - 1) * sizeof(mpz_dig_t));
111 
112     // work out length of result
113     jlen += n_whole;
114     while (jlen != 0 && idig[jlen - 1] == 0) {
115         jlen--;
116     }
117 
118     // return length of result
119     return jlen;
120 }
121 
122 /* computes i = j >> n
123    returns number of digits in i
124    assumes enough memory in i; assumes normalised j; assumes n > 0
125    can have i, j pointing to same memory
126 */
mpn_shr(mpz_dig_t * idig,mpz_dig_t * jdig,size_t jlen,mp_uint_t n)127 STATIC size_t mpn_shr(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mp_uint_t n) {
128     mp_uint_t n_whole = n / DIG_SIZE;
129     mp_uint_t n_part = n % DIG_SIZE;
130 
131     if (n_whole >= jlen) {
132         return 0;
133     }
134 
135     jdig += n_whole;
136     jlen -= n_whole;
137 
138     for (size_t i = jlen; i > 0; i--, idig++, jdig++) {
139         mpz_dbl_dig_t d = *jdig;
140         if (i > 1) {
141             d |= (mpz_dbl_dig_t)jdig[1] << DIG_SIZE;
142         }
143         d >>= n_part;
144         *idig = d & DIG_MASK;
145     }
146 
147     if (idig[-1] == 0) {
148         jlen--;
149     }
150 
151     return jlen;
152 }
153 
154 /* computes i = j + k
155    returns number of digits in i
156    assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
157    can have i, j, k pointing to same memory
158 */
mpn_add(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen)159 STATIC size_t mpn_add(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
160     mpz_dig_t *oidig = idig;
161     mpz_dbl_dig_t carry = 0;
162 
163     jlen -= klen;
164 
165     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
166         carry += (mpz_dbl_dig_t)*jdig + (mpz_dbl_dig_t)*kdig;
167         *idig = carry & DIG_MASK;
168         carry >>= DIG_SIZE;
169     }
170 
171     for (; jlen > 0; --jlen, ++idig, ++jdig) {
172         carry += *jdig;
173         *idig = carry & DIG_MASK;
174         carry >>= DIG_SIZE;
175     }
176 
177     if (carry != 0) {
178         *idig++ = carry;
179     }
180 
181     return idig - oidig;
182 }
183 
184 /* computes i = j - k
185    returns number of digits in i
186    assumes enough memory in i; assumes normalised j, k; assumes j >= k
187    can have i, j, k pointing to same memory
188 */
mpn_sub(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen)189 STATIC size_t mpn_sub(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
190     mpz_dig_t *oidig = idig;
191     mpz_dbl_dig_signed_t borrow = 0;
192 
193     jlen -= klen;
194 
195     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
196         borrow += (mpz_dbl_dig_t)*jdig - (mpz_dbl_dig_t)*kdig;
197         *idig = borrow & DIG_MASK;
198         borrow >>= DIG_SIZE;
199     }
200 
201     for (; jlen > 0; --jlen, ++idig, ++jdig) {
202         borrow += *jdig;
203         *idig = borrow & DIG_MASK;
204         borrow >>= DIG_SIZE;
205     }
206 
207     return mpn_remove_trailing_zeros(oidig, idig);
208 }
209 
210 #if MICROPY_OPT_MPZ_BITWISE
211 
212 /* computes i = j & k
213    returns number of digits in i
214    assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen (jlen argument not needed)
215    can have i, j, k pointing to same memory
216 */
mpn_and(mpz_dig_t * idig,const mpz_dig_t * jdig,const mpz_dig_t * kdig,size_t klen)217 STATIC size_t mpn_and(mpz_dig_t *idig, const mpz_dig_t *jdig, const mpz_dig_t *kdig, size_t klen) {
218     mpz_dig_t *oidig = idig;
219 
220     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
221         *idig = *jdig & *kdig;
222     }
223 
224     return mpn_remove_trailing_zeros(oidig, idig);
225 }
226 
227 #endif
228 
229 /*  i = -((-j) & (-k))                = ~((~j + 1) & (~k + 1)) + 1
230     i =  (j & (-k)) =  (j & (~k + 1)) =  (  j      & (~k + 1))
231     i =  ((-j) & k) =  ((~j + 1) & k) =  ((~j + 1) &   k     )
232    computes general form:
233    i = (im ^ (((j ^ jm) + jc) & ((k ^ km) + kc))) + ic  where Xm = Xc == 0 ? 0 : DIG_MASK
234    returns number of digits in i
235    assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
236    can have i, j, k pointing to same memory
237 */
mpn_and_neg(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen,mpz_dbl_dig_t carryi,mpz_dbl_dig_t carryj,mpz_dbl_dig_t carryk)238 STATIC size_t mpn_and_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
239     mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
240     mpz_dig_t *oidig = idig;
241     mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
242     mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
243     mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
244 
245     for (; jlen > 0; ++idig, ++jdig) {
246         carryj += *jdig ^ jmask;
247         carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
248         carryi += ((carryj & carryk) ^ imask) & DIG_MASK;
249         *idig = carryi & DIG_MASK;
250         carryk >>= DIG_SIZE;
251         carryj >>= DIG_SIZE;
252         carryi >>= DIG_SIZE;
253     }
254 
255     if (0 != carryi) {
256         *idig++ = carryi;
257     }
258 
259     return mpn_remove_trailing_zeros(oidig, idig);
260 }
261 
262 #if MICROPY_OPT_MPZ_BITWISE
263 
264 /* computes i = j | k
265    returns number of digits in i
266    assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
267    can have i, j, k pointing to same memory
268 */
mpn_or(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen)269 STATIC size_t mpn_or(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
270     mpz_dig_t *oidig = idig;
271 
272     jlen -= klen;
273 
274     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
275         *idig = *jdig | *kdig;
276     }
277 
278     for (; jlen > 0; --jlen, ++idig, ++jdig) {
279         *idig = *jdig;
280     }
281 
282     return idig - oidig;
283 }
284 
285 #endif
286 
287 /*  i = -((-j) | (-k))                = ~((~j + 1) | (~k + 1)) + 1
288     i = -(j | (-k)) = -(j | (~k + 1)) = ~(  j      | (~k + 1)) + 1
289     i = -((-j) | k) = -((~j + 1) | k) = ~((~j + 1) |   k     ) + 1
290    computes general form:
291    i = ~(((j ^ jm) + jc) | ((k ^ km) + kc)) + 1  where Xm = Xc == 0 ? 0 : DIG_MASK
292    returns number of digits in i
293    assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
294    can have i, j, k pointing to same memory
295 */
296 
297 #if MICROPY_OPT_MPZ_BITWISE
298 
mpn_or_neg(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen,mpz_dbl_dig_t carryj,mpz_dbl_dig_t carryk)299 STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
300     mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
301     mpz_dig_t *oidig = idig;
302     mpz_dbl_dig_t carryi = 1;
303     mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
304     mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
305 
306     for (; jlen > 0; ++idig, ++jdig) {
307         carryj += *jdig ^ jmask;
308         carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
309         carryi += ((carryj | carryk) ^ DIG_MASK) & DIG_MASK;
310         *idig = carryi & DIG_MASK;
311         carryk >>= DIG_SIZE;
312         carryj >>= DIG_SIZE;
313         carryi >>= DIG_SIZE;
314     }
315 
316     // At least one of j,k must be negative so the above for-loop runs at least
317     // once.  For carryi to be non-zero here it must be equal to 1 at the end of
318     // each iteration of the loop.  So the accumulation of carryi must overflow
319     // each time, ie carryi += 0xff..ff.  So carryj|carryk must be 0 in the
320     // DIG_MASK bits on each iteration.  But considering all cases of signs of
321     // j,k one sees that this is not possible.
322     assert(carryi == 0);
323 
324     return mpn_remove_trailing_zeros(oidig, idig);
325 }
326 
327 #else
328 
mpn_or_neg(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen,mpz_dbl_dig_t carryi,mpz_dbl_dig_t carryj,mpz_dbl_dig_t carryk)329 STATIC size_t mpn_or_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
330     mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
331     mpz_dig_t *oidig = idig;
332     mpz_dig_t imask = (0 == carryi) ? 0 : DIG_MASK;
333     mpz_dig_t jmask = (0 == carryj) ? 0 : DIG_MASK;
334     mpz_dig_t kmask = (0 == carryk) ? 0 : DIG_MASK;
335 
336     for (; jlen > 0; ++idig, ++jdig) {
337         carryj += *jdig ^ jmask;
338         carryk += (--klen <= --jlen) ? (*kdig++ ^ kmask) : kmask;
339         carryi += ((carryj | carryk) ^ imask) & DIG_MASK;
340         *idig = carryi & DIG_MASK;
341         carryk >>= DIG_SIZE;
342         carryj >>= DIG_SIZE;
343         carryi >>= DIG_SIZE;
344     }
345 
346     // See comment in above mpn_or_neg for why carryi must be 0.
347     assert(carryi == 0);
348 
349     return mpn_remove_trailing_zeros(oidig, idig);
350 }
351 
352 #endif
353 
354 #if MICROPY_OPT_MPZ_BITWISE
355 
356 /* computes i = j ^ k
357    returns number of digits in i
358    assumes enough memory in i; assumes normalised j, k; assumes jlen >= klen
359    can have i, j, k pointing to same memory
360 */
mpn_xor(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen)361 STATIC size_t mpn_xor(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen) {
362     mpz_dig_t *oidig = idig;
363 
364     jlen -= klen;
365 
366     for (; klen > 0; --klen, ++idig, ++jdig, ++kdig) {
367         *idig = *jdig ^ *kdig;
368     }
369 
370     for (; jlen > 0; --jlen, ++idig, ++jdig) {
371         *idig = *jdig;
372     }
373 
374     return mpn_remove_trailing_zeros(oidig, idig);
375 }
376 
377 #endif
378 
379 /*  i = (-j) ^ (-k) = ~(j - 1) ^ ~(k - 1)                   = (j - 1) ^ (k - 1)
380     i = -(j ^ (-k)) = -(j ^ ~(k - 1)) = ~(j ^ ~(k - 1)) + 1 = (j ^ (k - 1)) + 1
381     i = -((-j) ^ k) = -(~(j - 1) ^ k) = ~(~(j - 1) ^ k) + 1 = ((j - 1) ^ k) + 1
382    computes general form:
383    i = ((j - 1 + jc) ^ (k - 1 + kc)) + ic
384    returns number of digits in i
385    assumes enough memory in i; assumes normalised j, k; assumes length j >= length k
386    can have i, j, k pointing to same memory
387 */
mpn_xor_neg(mpz_dig_t * idig,const mpz_dig_t * jdig,size_t jlen,const mpz_dig_t * kdig,size_t klen,mpz_dbl_dig_t carryi,mpz_dbl_dig_t carryj,mpz_dbl_dig_t carryk)388 STATIC size_t mpn_xor_neg(mpz_dig_t *idig, const mpz_dig_t *jdig, size_t jlen, const mpz_dig_t *kdig, size_t klen,
389     mpz_dbl_dig_t carryi, mpz_dbl_dig_t carryj, mpz_dbl_dig_t carryk) {
390     mpz_dig_t *oidig = idig;
391 
392     for (; jlen > 0; ++idig, ++jdig) {
393         carryj += *jdig + DIG_MASK;
394         carryk += (--klen <= --jlen) ? (*kdig++ + DIG_MASK) : DIG_MASK;
395         carryi += (carryj ^ carryk) & DIG_MASK;
396         *idig = carryi & DIG_MASK;
397         carryk >>= DIG_SIZE;
398         carryj >>= DIG_SIZE;
399         carryi >>= DIG_SIZE;
400     }
401 
402     if (0 != carryi) {
403         *idig++ = carryi;
404     }
405 
406     return mpn_remove_trailing_zeros(oidig, idig);
407 }
408 
409 /* computes i = i * d1 + d2
410    returns number of digits in i
411    assumes enough memory in i; assumes normalised i; assumes dmul != 0
412 */
mpn_mul_dig_add_dig(mpz_dig_t * idig,size_t ilen,mpz_dig_t dmul,mpz_dig_t dadd)413 STATIC size_t mpn_mul_dig_add_dig(mpz_dig_t *idig, size_t ilen, mpz_dig_t dmul, mpz_dig_t dadd) {
414     mpz_dig_t *oidig = idig;
415     mpz_dbl_dig_t carry = dadd;
416 
417     for (; ilen > 0; --ilen, ++idig) {
418         carry += (mpz_dbl_dig_t)*idig * (mpz_dbl_dig_t)dmul; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
419         *idig = carry & DIG_MASK;
420         carry >>= DIG_SIZE;
421     }
422 
423     if (carry != 0) {
424         *idig++ = carry;
425     }
426 
427     return idig - oidig;
428 }
429 
430 /* computes i = j * k
431    returns number of digits in i
432    assumes enough memory in i; assumes i is zeroed; assumes normalised j, k
433    can have j, k point to same memory
434 */
mpn_mul(mpz_dig_t * idig,mpz_dig_t * jdig,size_t jlen,mpz_dig_t * kdig,size_t klen)435 STATIC size_t mpn_mul(mpz_dig_t *idig, mpz_dig_t *jdig, size_t jlen, mpz_dig_t *kdig, size_t klen) {
436     mpz_dig_t *oidig = idig;
437     size_t ilen = 0;
438 
439     for (; klen > 0; --klen, ++idig, ++kdig) {
440         mpz_dig_t *id = idig;
441         mpz_dbl_dig_t carry = 0;
442 
443         size_t jl = jlen;
444         for (mpz_dig_t *jd = jdig; jl > 0; --jl, ++jd, ++id) {
445             carry += (mpz_dbl_dig_t)*id + (mpz_dbl_dig_t)*jd * (mpz_dbl_dig_t)*kdig; // will never overflow so long as DIG_SIZE <= 8*sizeof(mpz_dbl_dig_t)/2
446             *id = carry & DIG_MASK;
447             carry >>= DIG_SIZE;
448         }
449 
450         if (carry != 0) {
451             *id++ = carry;
452         }
453 
454         ilen = id - oidig;
455     }
456 
457     return ilen;
458 }
459 
460 /* natural_div - quo * den + new_num = old_num (ie num is replaced with rem)
461    assumes den != 0
462    assumes num_dig has enough memory to be extended by 1 digit
463    assumes quo_dig has enough memory (as many digits as num)
464    assumes quo_dig is filled with zeros
465 */
mpn_div(mpz_dig_t * num_dig,size_t * num_len,const mpz_dig_t * den_dig,size_t den_len,mpz_dig_t * quo_dig,size_t * quo_len)466 STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_dig, size_t den_len, mpz_dig_t *quo_dig, size_t *quo_len) {
467     mpz_dig_t *orig_num_dig = num_dig;
468     mpz_dig_t *orig_quo_dig = quo_dig;
469     mpz_dig_t norm_shift = 0;
470     mpz_dbl_dig_t lead_den_digit;
471 
472     // handle simple cases
473     {
474         int cmp = mpn_cmp(num_dig, *num_len, den_dig, den_len);
475         if (cmp == 0) {
476             *num_len = 0;
477             quo_dig[0] = 1;
478             *quo_len = 1;
479             return;
480         } else if (cmp < 0) {
481             // numerator remains the same
482             *quo_len = 0;
483             return;
484         }
485     }
486 
487     // We need to normalise the denominator (leading bit of leading digit is 1)
488     // so that the division routine works.  Since the denominator memory is
489     // read-only we do the normalisation on the fly, each time a digit of the
490     // denominator is needed.  We need to know is how many bits to shift by.
491 
492     // count number of leading zeros in leading digit of denominator
493     {
494         mpz_dig_t d = den_dig[den_len - 1];
495         while ((d & DIG_MSB) == 0) {
496             d <<= 1;
497             ++norm_shift;
498         }
499     }
500 
501     // now need to shift numerator by same amount as denominator
502     // first, increase length of numerator in case we need more room to shift
503     num_dig[*num_len] = 0;
504     ++(*num_len);
505     for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) {
506         mpz_dig_t n = *num;
507         *num = ((n << norm_shift) | carry) & DIG_MASK;
508         carry = (mpz_dbl_dig_t)n >> (DIG_SIZE - norm_shift);
509     }
510 
511     // cache the leading digit of the denominator
512     lead_den_digit = (mpz_dbl_dig_t)den_dig[den_len - 1] << norm_shift;
513     if (den_len >= 2) {
514         lead_den_digit |= (mpz_dbl_dig_t)den_dig[den_len - 2] >> (DIG_SIZE - norm_shift);
515     }
516 
517     // point num_dig to last digit in numerator
518     num_dig += *num_len - 1;
519 
520     // calculate number of digits in quotient
521     *quo_len = *num_len - den_len;
522 
523     // point to last digit to store for quotient
524     quo_dig += *quo_len - 1;
525 
526     // keep going while we have enough digits to divide
527     while (*num_len > den_len) {
528         mpz_dbl_dig_t quo = ((mpz_dbl_dig_t)*num_dig << DIG_SIZE) | num_dig[-1];
529 
530         // get approximate quotient
531         quo /= lead_den_digit;
532 
533         // Multiply quo by den and subtract from num to get remainder.
534         // Must be careful with overflow of the borrow variable.  Both
535         // borrow and low_digs are signed values and need signed right-shift,
536         // but x is unsigned and may take a full-range value.
537         const mpz_dig_t *d = den_dig;
538         mpz_dbl_dig_t d_norm = 0;
539         mpz_dbl_dig_signed_t borrow = 0;
540         for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
541             // Get the next digit in (den).
542             d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
543             // Multiply the next digit in (quo * den).
544             mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
545             // Compute the low DIG_MASK bits of the next digit in (num - quo * den)
546             mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
547             // Store the digit result for (num).
548             *n = low_digs & DIG_MASK;
549             // Compute the borrow, shifted right before summing to avoid overflow.
550             borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
551         }
552 
553         // At this point we have either:
554         //
555         //   1. quo was the correct value and the most-sig-digit of num is exactly
556         //      cancelled by borrow (borrow + *num_dig == 0).  In this case there is
557         //      nothing more to do.
558         //
559         //   2. quo was too large, we subtracted too many den from num, and the
560         //      most-sig-digit of num is less than needed (borrow + *num_dig < 0).
561         //      In this case we must reduce quo and add back den to num until the
562         //      carry from this operation cancels out the borrow.
563         //
564         borrow += *num_dig;
565         for (; borrow != 0; --quo) {
566             d = den_dig;
567             d_norm = 0;
568             mpz_dbl_dig_t carry = 0;
569             for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
570                 d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
571                 carry += (mpz_dbl_dig_t)*n + (d_norm & DIG_MASK);
572                 *n = carry & DIG_MASK;
573                 carry >>= DIG_SIZE;
574             }
575             borrow += carry;
576         }
577 
578         // store this digit of the quotient
579         *quo_dig = quo & DIG_MASK;
580         --quo_dig;
581 
582         // move down to next digit of numerator
583         --num_dig;
584         --(*num_len);
585     }
586 
587     // unnormalise numerator (remainder now)
588     for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) {
589         mpz_dig_t n = *num;
590         *num = ((n >> norm_shift) | carry) & DIG_MASK;
591         carry = (mpz_dbl_dig_t)n << (DIG_SIZE - norm_shift);
592     }
593 
594     // strip trailing zeros
595 
596     while (*quo_len > 0 && orig_quo_dig[*quo_len - 1] == 0) {
597         --(*quo_len);
598     }
599 
600     while (*num_len > 0 && orig_num_dig[*num_len - 1] == 0) {
601         --(*num_len);
602     }
603 }
604 
// Smallest digit count that mpz_need_dig will request for a digit buffer.
#define MIN_ALLOC (2)
606 
// Initialises z to zero: no digit buffer, length 0, positive sign, and the
// buffer marked as not fixed.
void mpz_init_zero(mpz_t *z) {
    z->dig = NULL;
    z->alloc = 0;
    z->len = 0;
    z->fixed_dig = 0;
    z->neg = 0;
}
614 
// Initialises z and gives it the value val.
void mpz_init_from_int(mpz_t *z, mp_int_t val) {
    // start from a well-defined empty state, then store the value
    mpz_init_zero(z);
    mpz_set_from_int(z, val);
}
619 
// Initialises z to use the caller-supplied buffer dig, which holds alloc
// digits, and gives it the value val.  The buffer is marked as fixed.
void mpz_init_fixed_from_int(mpz_t *z, mpz_dig_t *dig, size_t alloc, mp_int_t val) {
    z->dig = dig;
    z->alloc = alloc;
    z->fixed_dig = 1;
    z->len = 0;
    z->neg = 0;
    mpz_set_from_int(z, val);
}
628 
// Releases the digit buffer of z.  Does nothing when z is NULL or when its
// buffer is marked as fixed.  The mpz_t itself is not freed.
void mpz_deinit(mpz_t *z) {
    if (z == NULL || z->fixed_dig) {
        return;
    }
    m_del(mpz_dig_t, z->dig, z->alloc);
}
634 
#if 0
// These constructors are unused and therefore compiled out.

// Allocates a new mpz holding zero.
mpz_t *mpz_zero(void) {
    mpz_t *fresh = m_new_obj(mpz_t);
    mpz_init_zero(fresh);
    return fresh;
}

// Allocates a new mpz holding val.
mpz_t *mpz_from_int(mp_int_t val) {
    mpz_t *fresh = mpz_zero();
    mpz_set_from_int(fresh, val);
    return fresh;
}

// Allocates a new mpz holding val, read as signed or unsigned per is_signed.
mpz_t *mpz_from_ll(long long val, bool is_signed) {
    mpz_t *fresh = mpz_zero();
    mpz_set_from_ll(fresh, val, is_signed);
    return fresh;
}

#if MICROPY_PY_BUILTINS_FLOAT
// Allocates a new mpz set from the float val.
mpz_t *mpz_from_float(mp_float_t val) {
    mpz_t *fresh = mpz_zero();
    mpz_set_from_float(fresh, val);
    return fresh;
}
#endif

// Allocates a new mpz parsed from the len characters at str in the given base.
mpz_t *mpz_from_str(const char *str, size_t len, bool neg, unsigned int base) {
    mpz_t *fresh = mpz_zero();
    mpz_set_from_str(fresh, str, len, neg, base);
    return fresh;
}
#endif
670 
mpz_free(mpz_t * z)671 STATIC void mpz_free(mpz_t *z) {
672     if (z != NULL) {
673         m_del(mpz_dig_t, z->dig, z->alloc);
674         m_del_obj(mpz_t, z);
675     }
676 }
677 
mpz_need_dig(mpz_t * z,size_t need)678 STATIC void mpz_need_dig(mpz_t *z, size_t need) {
679     if (need < MIN_ALLOC) {
680         need = MIN_ALLOC;
681     }
682 
683     if (z->dig == NULL || z->alloc < need) {
684         // if z has fixed digit buffer there's not much we can do as the caller will
685         // be expecting a buffer with at least "need" bytes (but it shouldn't happen)
686         assert(!z->fixed_dig);
687         z->dig = m_renew(mpz_dig_t, z->dig, z->alloc, need);
688         z->alloc = need;
689     }
690 }
691 
mpz_clone(const mpz_t * src)692 STATIC mpz_t *mpz_clone(const mpz_t *src) {
693     assert(src->alloc != 0);
694     mpz_t *z = m_new_obj(mpz_t);
695     z->neg = src->neg;
696     z->fixed_dig = 0;
697     z->alloc = src->alloc;
698     z->len = src->len;
699     z->dig = m_new(mpz_dig_t, z->alloc);
700     memcpy(z->dig, src->dig, src->alloc * sizeof(mpz_dig_t));
701     return z;
702 }
703 
704 /* sets dest = src
705    can have dest, src the same
706 */
mpz_set(mpz_t * dest,const mpz_t * src)707 void mpz_set(mpz_t *dest, const mpz_t *src) {
708     mpz_need_dig(dest, src->len);
709     dest->neg = src->neg;
710     dest->len = src->len;
711     memcpy(dest->dig, src->dig, src->len * sizeof(mpz_dig_t));
712 }
713 
// Sets z to the value val.  Zero is always stored with a positive sign.
void mpz_set_from_int(mpz_t *z, mp_int_t val) {
    if (val == 0) {
        // clear the sign as well, so that a previously negative z cannot
        // turn into a "negative zero"
        z->neg = 0;
        z->len = 0;
        return;
    }

    mpz_need_dig(z, MPZ_NUM_DIG_FOR_INT);

    mp_uint_t uval;
    if (val < 0) {
        z->neg = 1;
        // negate in the unsigned domain: -val would overflow (undefined
        // behaviour) for the most negative mp_int_t
        uval = -(mp_uint_t)val;
    } else {
        z->neg = 0;
        uval = val;
    }

    // store the magnitude, least significant digit first
    z->len = 0;
    while (uval > 0) {
        z->dig[z->len++] = uval & DIG_MASK;
        uval >>= DIG_SIZE;
    }
}
737 
// Sets z to val.  When is_signed is false the bits of val are read as an
// unsigned value, so the result is never negative.
void mpz_set_from_ll(mpz_t *z, long long val, bool is_signed) {
    mpz_need_dig(z, MPZ_NUM_DIG_FOR_LL);

    // split into sign and magnitude; the negation is done in unsigned arithmetic
    const bool negative = is_signed && val < 0;
    unsigned long long mag = negative ? -(unsigned long long)val : (unsigned long long)val;
    z->neg = negative ? 1 : 0;

    // store the magnitude, least significant digit first
    size_t ndig = 0;
    for (; mag > 0; mag >>= DIG_SIZE) {
        z->dig[ndig++] = mag & DIG_MASK;
    }
    z->len = ndig;
}
756 
757 #if MICROPY_PY_BUILTINS_FLOAT
// Sets z to the integer part of src (truncating towards zero).
// Infinities and NaN are stored as 0; they should be handled by the caller.
void mpz_set_from_float(mpz_t *z, mp_float_t src) {
    mp_float_union_t u = {src};
    z->neg = u.p.sgn;

    if (u.p.exp == 0) {
        // value == 0 || value < 1
        // a zero result is always positive, whatever the sign of src
        z->neg = 0;
        mpz_set_from_int(z, 0);
        return;
    }
    if (u.p.exp == ((1 << MP_FLOAT_EXP_BITS) - 1)) {
        // u.p.frc == 0 indicates inf, else NaN
        // should be handled by caller
        z->neg = 0;
        mpz_set_from_int(z, 0);
        return;
    }

    const int adj_exp = (int)u.p.exp - MP_FLOAT_EXP_BIAS;
    if (adj_exp < 0) {
        // |value| < 1 , truncates to 0
        z->neg = 0;
        mpz_set_from_int(z, 0);
        return;
    }
    if (adj_exp == 0) {
        // 1 <= |value| < 2 , so truncates to 1 or -1.  mpz_set_from_int sets
        // the sign from its argument, so the sign has to be passed in the value.
        mpz_set_from_int(z, u.p.sgn ? -1 : 1);
        return;
    }

    // 2 <= |value|
    const int dig_cnt = (adj_exp + 1 + (DIG_SIZE - 1)) / DIG_SIZE;
    const unsigned int rem = adj_exp % DIG_SIZE;
    int dig_ind, shft;
    // the fraction bits with the implicit leading 1 made explicit
    mp_float_uint_t frc = u.p.frc | ((mp_float_uint_t)1 << MP_FLOAT_FRAC_BITS);

    if (adj_exp < MP_FLOAT_FRAC_BITS) {
        // some fraction bits lie below the units position: discard them
        shft = 0;
        dig_ind = 0;
        frc >>= MP_FLOAT_FRAC_BITS - adj_exp;
    } else {
        // the value is scaled up: work out how many whole zero digits come
        // first and the bit offset inside the first non-zero digit
        shft = (rem - MP_FLOAT_FRAC_BITS) % DIG_SIZE;
        dig_ind = (adj_exp - MP_FLOAT_FRAC_BITS) / DIG_SIZE;
    }
    mpz_need_dig(z, dig_cnt);
    z->len = dig_cnt;
    if (dig_ind != 0) {
        memset(z->dig, 0, dig_ind * sizeof(mpz_dig_t));
    }
    if (shft != 0) {
        z->dig[dig_ind++] = (frc << shft) & DIG_MASK;
        frc >>= DIG_SIZE - shft;
    }
    #if DIG_SIZE < (MP_FLOAT_FRAC_BITS + 1)
    while (dig_ind != dig_cnt) {
        z->dig[dig_ind++] = frc & DIG_MASK;
        frc >>= DIG_SIZE;
    }
    #else
    if (dig_ind != dig_cnt) {
        z->dig[dig_ind] = frc;
    }
    #endif
}
813 #endif
814 
815 // returns number of bytes from str that were processed
mpz_set_from_str(mpz_t * z,const char * str,size_t len,bool neg,unsigned int base)816 size_t mpz_set_from_str(mpz_t *z, const char *str, size_t len, bool neg, unsigned int base) {
817     assert(base <= 36);
818 
819     const char *cur = str;
820     const char *top = str + len;
821 
822     mpz_need_dig(z, len * 8 / DIG_SIZE + 1);
823 
824     if (neg) {
825         z->neg = 1;
826     } else {
827         z->neg = 0;
828     }
829 
830     z->len = 0;
831     for (; cur < top; ++cur) { // XXX UTF8 next char
832         // mp_uint_t v = char_to_numeric(cur#); // XXX UTF8 get char
833         mp_uint_t v = *cur;
834         if ('0' <= v && v <= '9') {
835             v -= '0';
836         } else if ('A' <= v && v <= 'Z') {
837             v -= 'A' - 10;
838         } else if ('a' <= v && v <= 'z') {
839             v -= 'a' - 10;
840         } else {
841             break;
842         }
843         if (v >= base) {
844             break;
845         }
846         z->len = mpn_mul_dig_add_dig(z->dig, z->len, base, v);
847     }
848 
849     return cur - str;
850 }
851 
// Sets z to the non-negative integer encoded by the len bytes at buf, which
// are in big- or little-endian order according to big_endian.
void mpz_set_from_bytes(mpz_t *z, bool big_endian, size_t len, const byte *buf) {
    // Bytes are consumed least significant first, so big-endian input is
    // walked backwards starting from its final byte.
    int step;
    if (big_endian) {
        buf += len - 1;
        step = -1;
    } else {
        step = 1;
    }

    mpz_need_dig(z, (len * 8 + DIG_SIZE - 1) / DIG_SIZE);

    z->neg = 0;
    z->len = 0;

    mpz_dig_t pending = 0; // bits collected for the digit being assembled
    int pending_bits = 0;  // how many bits of pending are in use
    while (len) {
        // gather bytes until there are enough bits for a whole digit
        while (len && pending_bits < DIG_SIZE) {
            pending |= *buf << pending_bits;
            pending_bits += 8;
            buf += step;
            len--;
        }
        z->dig[z->len++] = pending & DIG_MASK;
        // Need this #if because it's C undefined behavior to do: uint32_t >> 32
        #if DIG_SIZE != 8 && DIG_SIZE != 16 && DIG_SIZE != 32
        pending >>= DIG_SIZE;
        #else
        pending = 0;
        #endif
        pending_bits -= DIG_SIZE;
    }

    // leading zero bytes in the input must not leave zero digits at the top
    z->len = mpn_remove_trailing_zeros(z->dig, z->dig + z->len);
}
884 
#if 0
these functions are unused

// Returns true if z is strictly positive: it has at least one digit and its
// sign flag is clear.
bool mpz_is_pos(const mpz_t *z) {
    return z->len > 0 && z->neg == 0;
}

// Returns true if z is odd: it is non-zero and the lowest bit of its least
// significant digit is set.
bool mpz_is_odd(const mpz_t *z) {
    return z->len > 0 && (z->dig[0] & 1) != 0;
}

// Returns true if z is even; a value with no digits (zero) counts as even.
bool mpz_is_even(const mpz_t *z) {
    return z->len == 0 || (z->dig[0] & 1) == 0;
}
#endif
900 
mpz_cmp(const mpz_t * z1,const mpz_t * z2)901 int mpz_cmp(const mpz_t *z1, const mpz_t *z2) {
902     // to catch comparison of -0 with +0
903     if (z1->len == 0 && z2->len == 0) {
904         return 0;
905     }
906     int cmp = (int)z2->neg - (int)z1->neg;
907     if (cmp != 0) {
908         return cmp;
909     }
910     cmp = mpn_cmp(z1->dig, z1->len, z2->dig, z2->len);
911     if (z1->neg != 0) {
912         cmp = -cmp;
913     }
914     return cmp;
915 }
916 
#if 0
// obsolete
// compares mpz with an integer that fits within DIG_SIZE bits
// returns -1, 0 or 1 for z < sml_int, z == sml_int, z > sml_int respectively
mp_int_t mpz_cmp_sml_int(const mpz_t *z, mp_int_t sml_int) {
    mp_int_t cmp;
    if (z->neg == 0) {
        // z is non-negative
        if (sml_int < 0) {
            return 1;
        }
        if (sml_int == 0) {
            if (z->len == 0) {
                return 0;
            }
            return 1;
        }
        // from here on sml_int > 0
        if (z->len == 0) {
            return -1;
        }
        assert(sml_int < (1 << DIG_SIZE));
        if (z->len != 1) {
            // z has more than one digit, so it exceeds any single-digit value
            return 1;
        }
        cmp = z->dig[0] - sml_int;
    } else {
        // z has its sign flag set
        if (sml_int > 0) {
            return -1;
        }
        if (sml_int == 0) {
            if (z->len == 0) {
                // negative zero equals zero
                return 0;
            }
            return -1;
        }
        // from here on sml_int < 0
        if (z->len == 0) {
            return 1;
        }
        assert(sml_int > -(1 << DIG_SIZE));
        if (z->len != 1) {
            // z has more than one digit, so it is below any single-digit negative value
            return -1;
        }
        cmp = -z->dig[0] - sml_int;
    }
    // reduce the difference to its sign
    if (cmp < 0) {
        return -1;
    }
    if (cmp > 0) {
        return 1;
    }
    return 0;
}
#endif
968 
#if 0
these functions are unused

/* returns abs(z)
   the result is a new clone of z with its sign flag cleared
*/
mpz_t *mpz_abs(const mpz_t *z) {
    // TODO: handle case of z->alloc=0
    mpz_t *z2 = mpz_clone(z);
    z2->neg = 0;
    return z2;
}

/* returns -z
   the result is a new clone of z with its sign flag inverted
*/
mpz_t *mpz_neg(const mpz_t *z) {
    // TODO: handle case of z->alloc=0
    mpz_t *z2 = mpz_clone(z);
    z2->neg = 1 - z2->neg;
    return z2;
}

/* returns lhs + rhs
   can have lhs, rhs the same
   the result is a new mpz obtained from mpz_zero()
*/
mpz_t *mpz_add(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_add_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs - rhs
   can have lhs, rhs the same
   the result is a new mpz obtained from mpz_zero()
*/
mpz_t *mpz_sub(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_sub_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs * rhs
   can have lhs, rhs the same
   the result is a new mpz obtained from mpz_zero()
*/
mpz_t *mpz_mul(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_mul_inpl(z, lhs, rhs);
    return z;
}

/* returns lhs ** rhs
   can have lhs, rhs the same
   the result is a new mpz obtained from mpz_zero()
*/
mpz_t *mpz_pow(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *z = mpz_zero();
    mpz_pow_inpl(z, lhs, rhs);
    return z;
}

/* computes new integers in quo and rem such that:
       quo * rhs + rem = lhs
       0 <= rem < rhs
   can have lhs, rhs the same
   *quo and *rem are set to new mpz objects obtained from mpz_zero()
   NOTE(review): the range given for rem presumably holds only for rhs > 0,
   since mpz_divmod_inpl adjusts for operands of differing sign -- confirm
*/
void mpz_divmod(const mpz_t *lhs, const mpz_t *rhs, mpz_t **quo, mpz_t **rem) {
    *quo = mpz_zero();
    *rem = mpz_zero();
    mpz_divmod_inpl(*quo, *rem, lhs, rhs);
}
#endif
1037 
1038 /* computes dest = abs(z)
1039    can have dest, z the same
1040 */
mpz_abs_inpl(mpz_t * dest,const mpz_t * z)1041 void mpz_abs_inpl(mpz_t *dest, const mpz_t *z) {
1042     if (dest != z) {
1043         mpz_set(dest, z);
1044     }
1045     dest->neg = 0;
1046 }
1047 
1048 /* computes dest = -z
1049    can have dest, z the same
1050 */
mpz_neg_inpl(mpz_t * dest,const mpz_t * z)1051 void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) {
1052     if (dest != z) {
1053         mpz_set(dest, z);
1054     }
1055     dest->neg = 1 - dest->neg;
1056 }
1057 
1058 /* computes dest = ~z (= -z - 1)
1059    can have dest, z the same
1060 */
mpz_not_inpl(mpz_t * dest,const mpz_t * z)1061 void mpz_not_inpl(mpz_t *dest, const mpz_t *z) {
1062     if (dest != z) {
1063         mpz_set(dest, z);
1064     }
1065     if (dest->len == 0) {
1066         mpz_need_dig(dest, 1);
1067         dest->dig[0] = 1;
1068         dest->len = 1;
1069         dest->neg = 1;
1070     } else if (dest->neg) {
1071         dest->neg = 0;
1072         mpz_dig_t k = 1;
1073         dest->len = mpn_sub(dest->dig, dest->dig, dest->len, &k, 1);
1074     } else {
1075         mpz_need_dig(dest, dest->len + 1);
1076         mpz_dig_t k = 1;
1077         dest->len = mpn_add(dest->dig, dest->dig, dest->len, &k, 1);
1078         dest->neg = 1;
1079     }
1080 }
1081 
1082 /* computes dest = lhs << rhs
1083    can have dest, lhs the same
1084 */
mpz_shl_inpl(mpz_t * dest,const mpz_t * lhs,mp_uint_t rhs)1085 void mpz_shl_inpl(mpz_t *dest, const mpz_t *lhs, mp_uint_t rhs) {
1086     if (lhs->len == 0 || rhs == 0) {
1087         mpz_set(dest, lhs);
1088     } else {
1089         mpz_need_dig(dest, lhs->len + (rhs + DIG_SIZE - 1) / DIG_SIZE);
1090         dest->len = mpn_shl(dest->dig, lhs->dig, lhs->len, rhs);
1091         dest->neg = lhs->neg;
1092     }
1093 }
1094 
1095 /* computes dest = lhs >> rhs
1096    can have dest, lhs the same
1097 */
mpz_shr_inpl(mpz_t * dest,const mpz_t * lhs,mp_uint_t rhs)1098 void mpz_shr_inpl(mpz_t *dest, const mpz_t *lhs, mp_uint_t rhs) {
1099     if (lhs->len == 0 || rhs == 0) {
1100         mpz_set(dest, lhs);
1101     } else {
1102         mpz_need_dig(dest, lhs->len);
1103         dest->len = mpn_shr(dest->dig, lhs->dig, lhs->len, rhs);
1104         dest->neg = lhs->neg;
1105         if (dest->neg) {
1106             // arithmetic shift right, rounding to negative infinity
1107             mp_uint_t n_whole = rhs / DIG_SIZE;
1108             mp_uint_t n_part = rhs % DIG_SIZE;
1109             mpz_dig_t round_up = 0;
1110             for (size_t i = 0; i < lhs->len && i < n_whole; i++) {
1111                 if (lhs->dig[i] != 0) {
1112                     round_up = 1;
1113                     break;
1114                 }
1115             }
1116             if (n_whole < lhs->len && (lhs->dig[n_whole] & ((1 << n_part) - 1)) != 0) {
1117                 round_up = 1;
1118             }
1119             if (round_up) {
1120                 if (dest->len == 0) {
1121                     // dest == 0, so need to add 1 by hand (answer will be -1)
1122                     dest->dig[0] = 1;
1123                     dest->len = 1;
1124                 } else {
1125                     // dest > 0, so can use mpn_add to add 1
1126                     dest->len = mpn_add(dest->dig, dest->dig, dest->len, &round_up, 1);
1127                 }
1128             }
1129         }
1130     }
1131 }
1132 
1133 /* computes dest = lhs + rhs
1134    can have dest, lhs, rhs the same
1135 */
mpz_add_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1136 void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1137     if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1138         const mpz_t *temp = lhs;
1139         lhs = rhs;
1140         rhs = temp;
1141     }
1142 
1143     if (lhs->neg == rhs->neg) {
1144         mpz_need_dig(dest, lhs->len + 1);
1145         dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1146     } else {
1147         mpz_need_dig(dest, lhs->len);
1148         dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1149     }
1150 
1151     dest->neg = lhs->neg;
1152 }
1153 
1154 /* computes dest = lhs - rhs
1155    can have dest, lhs, rhs the same
1156 */
mpz_sub_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1157 void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1158     bool neg = false;
1159 
1160     if (mpn_cmp(lhs->dig, lhs->len, rhs->dig, rhs->len) < 0) {
1161         const mpz_t *temp = lhs;
1162         lhs = rhs;
1163         rhs = temp;
1164         neg = true;
1165     }
1166 
1167     if (lhs->neg != rhs->neg) {
1168         mpz_need_dig(dest, lhs->len + 1);
1169         dest->len = mpn_add(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1170     } else {
1171         mpz_need_dig(dest, lhs->len);
1172         dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1173     }
1174 
1175     if (neg) {
1176         dest->neg = 1 - lhs->neg;
1177     } else {
1178         dest->neg = lhs->neg;
1179     }
1180 }
1181 
1182 /* computes dest = lhs & rhs
1183    can have dest, lhs, rhs the same
1184 */
mpz_and_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1185 void mpz_and_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1186     // make sure lhs has the most digits
1187     if (lhs->len < rhs->len) {
1188         const mpz_t *temp = lhs;
1189         lhs = rhs;
1190         rhs = temp;
1191     }
1192 
1193     #if MICROPY_OPT_MPZ_BITWISE
1194 
1195     if ((0 == lhs->neg) && (0 == rhs->neg)) {
1196         mpz_need_dig(dest, lhs->len);
1197         dest->len = mpn_and(dest->dig, lhs->dig, rhs->dig, rhs->len);
1198         dest->neg = 0;
1199     } else {
1200         mpz_need_dig(dest, lhs->len + 1);
1201         dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1202             lhs->neg == rhs->neg, 0 != lhs->neg, 0 != rhs->neg);
1203         dest->neg = lhs->neg & rhs->neg;
1204     }
1205 
1206     #else
1207 
1208     mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1209     dest->len = mpn_and_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1210         (lhs->neg == rhs->neg) ? lhs->neg : 0, lhs->neg, rhs->neg);
1211     dest->neg = lhs->neg & rhs->neg;
1212 
1213     #endif
1214 }
1215 
1216 /* computes dest = lhs | rhs
1217    can have dest, lhs, rhs the same
1218 */
mpz_or_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1219 void mpz_or_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1220     // make sure lhs has the most digits
1221     if (lhs->len < rhs->len) {
1222         const mpz_t *temp = lhs;
1223         lhs = rhs;
1224         rhs = temp;
1225     }
1226 
1227     #if MICROPY_OPT_MPZ_BITWISE
1228 
1229     if ((0 == lhs->neg) && (0 == rhs->neg)) {
1230         mpz_need_dig(dest, lhs->len);
1231         dest->len = mpn_or(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1232         dest->neg = 0;
1233     } else {
1234         mpz_need_dig(dest, lhs->len + 1);
1235         dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1236             0 != lhs->neg, 0 != rhs->neg);
1237         dest->neg = 1;
1238     }
1239 
1240     #else
1241 
1242     mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1243     dest->len = mpn_or_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1244         (lhs->neg || rhs->neg), lhs->neg, rhs->neg);
1245     dest->neg = lhs->neg | rhs->neg;
1246 
1247     #endif
1248 }
1249 
1250 /* computes dest = lhs ^ rhs
1251    can have dest, lhs, rhs the same
1252 */
mpz_xor_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1253 void mpz_xor_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1254     // make sure lhs has the most digits
1255     if (lhs->len < rhs->len) {
1256         const mpz_t *temp = lhs;
1257         lhs = rhs;
1258         rhs = temp;
1259     }
1260 
1261     #if MICROPY_OPT_MPZ_BITWISE
1262 
1263     if (lhs->neg == rhs->neg) {
1264         mpz_need_dig(dest, lhs->len);
1265         if (lhs->neg == 0) {
1266             dest->len = mpn_xor(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1267         } else {
1268             dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 0, 0, 0);
1269         }
1270         dest->neg = 0;
1271     } else {
1272         mpz_need_dig(dest, lhs->len + 1);
1273         dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len, 1,
1274             0 == lhs->neg, 0 == rhs->neg);
1275         dest->neg = 1;
1276     }
1277 
1278     #else
1279 
1280     mpz_need_dig(dest, lhs->len + (lhs->neg || rhs->neg));
1281     dest->len = mpn_xor_neg(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len,
1282         (lhs->neg != rhs->neg), 0 == lhs->neg, 0 == rhs->neg);
1283     dest->neg = lhs->neg ^ rhs->neg;
1284 
1285     #endif
1286 }
1287 
1288 /* computes dest = lhs * rhs
1289    can have dest, lhs, rhs the same
1290 */
mpz_mul_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1291 void mpz_mul_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1292     if (lhs->len == 0 || rhs->len == 0) {
1293         mpz_set_from_int(dest, 0);
1294         return;
1295     }
1296 
1297     mpz_t *temp = NULL;
1298     if (lhs == dest) {
1299         lhs = temp = mpz_clone(lhs);
1300         if (rhs == dest) {
1301             rhs = lhs;
1302         }
1303     } else if (rhs == dest) {
1304         rhs = temp = mpz_clone(rhs);
1305     }
1306 
1307     mpz_need_dig(dest, lhs->len + rhs->len); // min mem l+r-1, max mem l+r
1308     memset(dest->dig, 0, dest->alloc * sizeof(mpz_dig_t));
1309     dest->len = mpn_mul(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len);
1310 
1311     if (lhs->neg == rhs->neg) {
1312         dest->neg = 0;
1313     } else {
1314         dest->neg = 1;
1315     }
1316 
1317     mpz_free(temp);
1318 }
1319 
1320 /* computes dest = lhs ** rhs
1321    can have dest, lhs, rhs the same
1322 */
mpz_pow_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs)1323 void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) {
1324     if (lhs->len == 0 || rhs->neg != 0) {
1325         mpz_set_from_int(dest, 0);
1326         return;
1327     }
1328 
1329     if (rhs->len == 0) {
1330         mpz_set_from_int(dest, 1);
1331         return;
1332     }
1333 
1334     mpz_t *x = mpz_clone(lhs);
1335     mpz_t *n = mpz_clone(rhs);
1336 
1337     mpz_set_from_int(dest, 1);
1338 
1339     while (n->len > 0) {
1340         if ((n->dig[0] & 1) != 0) {
1341             mpz_mul_inpl(dest, dest, x);
1342         }
1343         n->len = mpn_shr(n->dig, n->dig, n->len, 1);
1344         if (n->len == 0) {
1345             break;
1346         }
1347         mpz_mul_inpl(x, x, x);
1348     }
1349 
1350     mpz_free(x);
1351     mpz_free(n);
1352 }
1353 
1354 /* computes dest = (lhs ** rhs) % mod
1355    can have dest, lhs, rhs the same; mod can't be the same as dest
1356 */
mpz_pow3_inpl(mpz_t * dest,const mpz_t * lhs,const mpz_t * rhs,const mpz_t * mod)1357 void mpz_pow3_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs, const mpz_t *mod) {
1358     if (lhs->len == 0 || rhs->neg != 0 || (mod->len == 1 && mod->dig[0] == 1)) {
1359         mpz_set_from_int(dest, 0);
1360         return;
1361     }
1362 
1363     mpz_set_from_int(dest, 1);
1364 
1365     if (rhs->len == 0) {
1366         return;
1367     }
1368 
1369     mpz_t *x = mpz_clone(lhs);
1370     mpz_t *n = mpz_clone(rhs);
1371     mpz_t quo;
1372     mpz_init_zero(&quo);
1373 
1374     while (n->len > 0) {
1375         if ((n->dig[0] & 1) != 0) {
1376             mpz_mul_inpl(dest, dest, x);
1377             mpz_divmod_inpl(&quo, dest, dest, mod);
1378         }
1379         n->len = mpn_shr(n->dig, n->dig, n->len, 1);
1380         if (n->len == 0) {
1381             break;
1382         }
1383         mpz_mul_inpl(x, x, x);
1384         mpz_divmod_inpl(&quo, x, x, mod);
1385     }
1386 
1387     mpz_deinit(&quo);
1388     mpz_free(x);
1389     mpz_free(n);
1390 }
1391 
#if 0
these functions are unused

/* computes gcd(z1, z2)
   based on Knuth's modified gcd algorithm (I think?)
   gcd(z1, z2) >= 0
   gcd(0, 0) = 0
   gcd(z, 0) = abs(z)
   the result is a newly cloned mpz
*/
mpz_t *mpz_gcd(const mpz_t *z1, const mpz_t *z2) {
    // if either argument is zero the answer is the magnitude of the other
    if (z1->len == 0) {
        // TODO: handle case of z2->alloc=0
        mpz_t *a = mpz_clone(z2);
        a->neg = 0;
        return a;
    } else if (z2->len == 0) {
        mpz_t *a = mpz_clone(z1);
        a->neg = 0;
        return a;
    }

    // work on non-negative copies of the arguments; c is scratch space
    mpz_t *a = mpz_clone(z1);
    mpz_t *b = mpz_clone(z2);
    mpz_t c;
    mpz_init_zero(&c);
    a->neg = 0;
    b->neg = 0;

    for (;;) {
        // keep a >= b, exchanging the two when needed
        if (mpz_cmp(a, b) < 0) {
            if (a->len == 0) {
                // a has reached zero, so b is returned as the gcd
                mpz_free(a);
                mpz_deinit(&c);
                return b;
            }
            mpz_t *t = a;
            a = b;
            b = t;
        }
        if (!(b->len >= 2 || (b->len == 1 && b->dig[0] > 1))) { // compute b > 0; could be mpz_cmp_small_int(b, 1) > 0
            break;
        }
        // c = b doubled until it exceeds a, then halved once;
        // this c is then subtracted from a
        mpz_set(&c, b);
        do {
            mpz_add_inpl(&c, &c, &c);
        } while (mpz_cmp(&c, a) <= 0);
        c.len = mpn_shr(c.dig, c.dig, c.len, 1);
        mpz_sub_inpl(a, a, &c);
    }

    mpz_deinit(&c);

    // the loop ended with b equal to 0 or 1: return b if it is 1, otherwise a
    if (b->len == 1 && b->dig[0] == 1) { // compute b == 1; could be mpz_cmp_small_int(b, 1) == 0
        mpz_free(a);
        return b;
    } else {
        mpz_free(b);
        return a;
    }
}

/* computes lcm(z1, z2)
     = abs(z1) / gcd(z1, z2) * abs(z2)
  lcm(z1, z1) >= 0
  lcm(0, 0) = 0
  lcm(z, 0) = 0
  the result is a newly allocated mpz
*/
mpz_t *mpz_lcm(const mpz_t *z1, const mpz_t *z2) {
    if (z1->len == 0 || z2->len == 0) {
        return mpz_zero();
    }

    mpz_t *gcd = mpz_gcd(z1, z2);
    mpz_t *quo = mpz_zero();
    mpz_t *rem = mpz_zero();
    // quo = z1 / gcd; rem is then reused to hold the product quo * z2
    mpz_divmod_inpl(quo, rem, z1, gcd);
    mpz_mul_inpl(rem, quo, z2);
    mpz_free(gcd);
    mpz_free(quo);
    // the lcm is reported as non-negative
    rem->neg = 0;
    return rem;
}
#endif
1475 
1476 /* computes new integers in quo and rem such that:
1477        quo * rhs + rem = lhs
1478        0 <= rem < rhs
1479    can have lhs, rhs the same
1480    assumes rhs != 0 (undefined behaviour if it is)
1481 */
mpz_divmod_inpl(mpz_t * dest_quo,mpz_t * dest_rem,const mpz_t * lhs,const mpz_t * rhs)1482 void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const mpz_t *rhs) {
1483     assert(!mpz_is_zero(rhs));
1484 
1485     mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary?
1486     memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t));
1487     dest_quo->len = 0;
1488     mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary?
1489     mpz_set(dest_rem, lhs);
1490     mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len);
1491 
1492     // check signs and do Python style modulo
1493     if (lhs->neg != rhs->neg) {
1494         dest_quo->neg = 1;
1495         if (!mpz_is_zero(dest_rem)) {
1496             mpz_t mpzone;
1497             mpz_init_from_int(&mpzone, -1);
1498             mpz_add_inpl(dest_quo, dest_quo, &mpzone);
1499             mpz_add_inpl(dest_rem, dest_rem, rhs);
1500         }
1501     }
1502 }
1503 
#if 0
these functions are unused

/* computes floor(lhs / rhs)
   can have lhs, rhs the same
   returns the quotient from mpz_divmod_inpl as a new mpz; the remainder is discarded
*/
mpz_t *mpz_div(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t *quo = mpz_zero();
    mpz_t rem;
    mpz_init_zero(&rem);
    mpz_divmod_inpl(quo, &rem, lhs, rhs);
    mpz_deinit(&rem);
    return quo;
}

/* computes lhs % rhs ( >= 0)
   can have lhs, rhs the same
   returns the remainder from mpz_divmod_inpl as a new mpz; the quotient is discarded
   NOTE(review): ">= 0" presumably holds only for rhs > 0 -- confirm
*/
mpz_t *mpz_mod(const mpz_t *lhs, const mpz_t *rhs) {
    mpz_t quo;
    mpz_init_zero(&quo);
    mpz_t *rem = mpz_zero();
    mpz_divmod_inpl(&quo, rem, lhs, rhs);
    mpz_deinit(&quo);
    return rem;
}
#endif
1531 
1532 // must return actual int value if it fits in mp_int_t
mpz_hash(const mpz_t * z)1533 mp_int_t mpz_hash(const mpz_t *z) {
1534     mp_uint_t val = 0;
1535     mpz_dig_t *d = z->dig + z->len;
1536 
1537     while (d-- > z->dig) {
1538         val = (val << DIG_SIZE) | *d;
1539     }
1540 
1541     if (z->neg != 0) {
1542         val = -val;
1543     }
1544 
1545     return val;
1546 }
1547 
// Converts i to a machine integer.
// Returns true and stores the result in *value on success.  Returns false,
// leaving *value untouched, if accumulating the digits of i would exceed the
// limit derived from ~MP_OBJ_WORD_MSBIT_HIGH.
bool mpz_as_int_checked(const mpz_t *i, mp_int_t *value) {
    mp_uint_t acc = 0;

    // accumulate the magnitude, most significant digit first
    for (size_t k = i->len; k > 0; --k) {
        // refuse to shift in another digit if the shifted value could no
        // longer stay within ~MP_OBJ_WORD_MSBIT_HIGH
        if (acc > (~(MP_OBJ_WORD_MSBIT_HIGH) >> DIG_SIZE)) {
            // will overflow
            return false;
        }
        acc = (acc << DIG_SIZE) | i->dig[k - 1];
    }

    // apply the sign
    if (i->neg != 0) {
        acc = -acc;
    }

    *value = acc;
    return true;
}
1567 
// Converts i to an unsigned machine integer.
// Returns true and stores the result in *value on success.  Returns false,
// leaving *value untouched, if i has its sign flag set or if accumulating
// its digits would overflow.
bool mpz_as_uint_checked(const mpz_t *i, mp_uint_t *value) {
    if (i->neg != 0) {
        // can't represent signed values
        return false;
    }

    mp_uint_t acc = 0;

    // accumulate the magnitude, most significant digit first
    for (size_t k = i->len; k > 0; --k) {
        // refuse to shift in another digit if that would push bits out of the word
        if (acc > (~(MP_OBJ_WORD_MSBIT_HIGH) >> (DIG_SIZE - 1))) {
            // will overflow
            return false;
        }
        acc = (acc << DIG_SIZE) | i->dig[k - 1];
    }

    *value = acc;
    return true;
}
1588 
// Writes z into the len bytes at buf as a two's complement integer.
//   big_endian: true to store the most significant byte at buf[0],
//               false to store the least significant byte at buf[0]
// If z needs more than len bytes only the least significant len bytes are
// stored.  If it needs fewer, the remaining bytes are filled with 0x00 for
// non-negative z or 0xff for negative z.  Nothing is written when len is 0.
void mpz_as_bytes(const mpz_t *z, bool big_endian, size_t len, byte *buf) {
    if (len == 0) {
        // The loop below stores a byte before testing for the end of the
        // buffer, so it must never run on an empty buffer.
        return;
    }

    // b is the write cursor: it moves down from the end of the buffer for
    // big endian output, and up from the start for little endian output
    byte *b = buf;
    if (big_endian) {
        b += len;
    }
    mpz_dig_t *zdig = z->dig;
    int bits = 0;        // number of not-yet-emitted bits held in d
    mpz_dbl_dig_t d = 0; // bit buffer, refilled one digit at a time
    // Two's complement of the magnitude is (~magnitude + 1): carry holds the
    // "+ 1" and propagates from byte to byte.  It is kept in an unsigned int
    // so that it cannot be truncated, whatever the width of mpz_dig_t.
    unsigned int carry = 1;
    for (size_t zlen = z->len; zlen > 0; --zlen) {
        bits += DIG_SIZE;
        d = (d << DIG_SIZE) | *zdig++;
        for (; bits >= 8; bits -= 8, d >>= 8) {
            unsigned int val = (unsigned int)(d & 0xff);
            if (z->neg) {
                val = (val ^ 0xff) + carry;
                carry = val >> 8;
            }
            if (big_endian) {
                *--b = (byte)val;
                if (b == buf) {
                    // buffer full
                    return;
                }
            } else {
                *b++ = (byte)val;
                if (b == buf + len) {
                    // buffer full
                    return;
                }
            }
        }
    }

    // fill remainder of buf with zero/sign extension of the integer
    if (big_endian) {
        // unfilled bytes are buf[0 .. b - buf)
        len = b - buf;
    } else {
        // unfilled bytes are b[0 .. buf + len - b)
        len = buf + len - b;
        buf = b;
    }
    memset(buf, z->neg ? 0xff : 0x00, len);
}
1630 
#if MICROPY_PY_BUILTINS_FLOAT
// Converts i to a floating point value by evaluating its digits in base
// DIG_BASE, most significant digit first, and then applying the sign.
mp_float_t mpz_as_float(const mpz_t *i) {
    mp_float_t acc = 0;

    // Horner evaluation of the digit sequence
    for (size_t k = i->len; k > 0; --k) {
        acc = acc * DIG_BASE + i->dig[k - 1];
    }

    return (i->neg != 0) ? -acc : acc;
}
#endif
1647 
#if 0
this function is unused
// Returns a newly allocated (m_new) string holding i formatted in the given
// base by mpz_as_str_inpl, with no prefix, no grouping character and 'a' as
// the base character.  The buffer is sized with mp_int_format_size.
char *mpz_as_str(const mpz_t *i, unsigned int base) {
    char *s = m_new(char, mp_int_format_size(mpz_max_num_bits(i), base, NULL, '\0'));
    mpz_as_str_inpl(i, base, NULL, 'a', '\0', s);
    return s;
}
#endif
1656 
1657 // assumes enough space in str as calculated by mp_int_format_size
1658 // base must be between 2 and 32 inclusive
1659 // returns length of string, not including null byte
mpz_as_str_inpl(const mpz_t * i,unsigned int base,const char * prefix,char base_char,char comma,char * str)1660 size_t mpz_as_str_inpl(const mpz_t *i, unsigned int base, const char *prefix, char base_char, char comma, char *str) {
1661     assert(str != NULL);
1662     assert(2 <= base && base <= 32);
1663 
1664     size_t ilen = i->len;
1665 
1666     char *s = str;
1667     if (ilen == 0) {
1668         if (prefix) {
1669             while (*prefix) {
1670                 *s++ = *prefix++;
1671             }
1672         }
1673         *s++ = '0';
1674         *s = '\0';
1675         return s - str;
1676     }
1677 
1678     // make a copy of mpz digits, so we can do the div/mod calculation
1679     mpz_dig_t *dig = m_new(mpz_dig_t, ilen);
1680     memcpy(dig, i->dig, ilen * sizeof(mpz_dig_t));
1681 
1682     // convert
1683     char *last_comma = str;
1684     bool done;
1685     do {
1686         mpz_dig_t *d = dig + ilen;
1687         mpz_dbl_dig_t a = 0;
1688 
1689         // compute next remainder
1690         while (--d >= dig) {
1691             a = (a << DIG_SIZE) | *d;
1692             *d = a / base;
1693             a %= base;
1694         }
1695 
1696         // convert to character
1697         a += '0';
1698         if (a > '9') {
1699             a += base_char - '9' - 1;
1700         }
1701         *s++ = a;
1702 
1703         // check if number is zero
1704         done = true;
1705         for (d = dig; d < dig + ilen; ++d) {
1706             if (*d != 0) {
1707                 done = false;
1708                 break;
1709             }
1710         }
1711         if (comma && (s - last_comma) == 3) {
1712             *s++ = comma;
1713             last_comma = s;
1714         }
1715     }
1716     while (!done);
1717 
1718     // free the copy of the digits array
1719     m_del(mpz_dig_t, dig, ilen);
1720 
1721     if (prefix) {
1722         const char *p = &prefix[strlen(prefix)];
1723         while (p > prefix) {
1724             *s++ = *--p;
1725         }
1726     }
1727     if (i->neg != 0) {
1728         *s++ = '-';
1729     }
1730 
1731     // reverse string
1732     for (char *u = str, *v = s - 1; u < v; ++u, --v) {
1733         char temp = *u;
1734         *u = *v;
1735         *v = temp;
1736     }
1737 
1738     *s = '\0'; // null termination
1739 
1740     return s - str;
1741 }
1742 
1743 #endif // MICROPY_LONGINT_IMPL == MICROPY_LONGINT_IMPL_MPZ
1744