1 /*
2   Name:     imath.c
3   Purpose:  Arbitrary precision integer arithmetic routines.
4   Author:   M. J. Fromberger
5 
6   Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
7 
8   Permission is hereby granted, free of charge, to any person obtaining a copy
9   of this software and associated documentation files (the "Software"), to deal
10   in the Software without restriction, including without limitation the rights
11   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12   copies of the Software, and to permit persons to whom the Software is
13   furnished to do so, subject to the following conditions:
14 
15   The above copyright notice and this permission notice shall be included in
16   all copies or substantial portions of the Software.
17 
18   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
21   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24   SOFTWARE.
25  */
26 
27 #include "imath.h"
28 
29 #include <assert.h>
30 #include <ctype.h>
31 #include <stdlib.h>
32 #include <string.h>
33 
/* Result codes returned by the library.  MP_OK and MP_FALSE share the value
   0; every other code is a distinct negative number. */
const mp_result MP_OK = 0;      /* no error, all is well  */
const mp_result MP_FALSE = 0;   /* boolean false          */
const mp_result MP_TRUE = -1;   /* boolean true           */
const mp_result MP_MEMORY = -2; /* out of memory          */
const mp_result MP_RANGE = -3;  /* argument out of range  */
const mp_result MP_UNDEF = -4;  /* result undefined       */
const mp_result MP_TRUNC = -5;  /* output truncated       */
const mp_result MP_BADARG = -6; /* invalid null argument  */
/* Same value as MP_BADARG, the most negative code defined above.
   NOTE(review): presumably the lower bound used when range-checking a result
   code -- confirm against the code that reads it. */
const mp_result MP_MINERR = -6;
43 
/* Values stored in the sign field of an mp_int. */
const mp_sign MP_NEG = 1;  /* value is strictly negative */
const mp_sign MP_ZPOS = 0; /* value is non-negative      */
46 
/* Human-readable text for result codes.  The entries are listed in the same
   order as the result-code constants, so entry k lines up with code -k (for
   example entry 2 is "out of memory" and MP_MEMORY is -2).  The table is
   NULL-terminated.
   NOTE(review): the lookup code is outside this view -- confirm that it
   indexes by the negated code and falls back to s_unknown_err. */
static const char *s_unknown_err = "unknown result code";
static const char *s_error_msg[] = {"error code 0",     "boolean true",
                                    "out of memory",    "argument out of range",
                                    "result undefined", "output truncated",
                                    "invalid argument", NULL};
52 
/* The ith entry of this table gives the value of log_i(2).

   An integer value n > 0 requires floor(log_i(n)) + 1 digits to be
   represented in base i, which is approximately ceil(log_i(n)).  Since it is
   easy to compute lg(n), by counting bits, we can compute
   log_i(n) = lg(n) * log_i(2).

   The use of this table eliminates a dependency upon linkage against
   the standard math libraries.

   Entries 0 and 1 are placeholders (marked "(D)" below) holding 0.0; the
   first meaningful entry is for radix 2.

   If MP_MAX_RADIX is increased, this table should be expanded too.
 */
static const double s_log2[] = {
    0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2  3 */
    0.500000000, 0.430676558, 0.386852807, 0.356207187, /*  4  5  6  7 */
    0.333333333, 0.315464877, 0.301029996, 0.289064826, /*  8  9 10 11 */
    0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */
    0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */
    0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */
    0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */
    0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */
    0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */
    0.193426404,                                        /* 36          */
};
76 
/* Number of mp_digit elements needed to hold an object of V's type: sizeof(V)
   rounded up to a whole number of digits.  V appears only as a sizeof
   operand.  Used in this file to size the stack buffers handed to s_fake()
   and s_ufake(). */
#define MP_VALUE_DIGITS(V) \
  ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit))
80 
81 /* Round precision P to nearest word boundary */
s_round_prec(mp_size P)82 static inline mp_size s_round_prec(mp_size P) { return 2 * ((P + 1) / 2); }
83 
84 /* Set array P of S digits to zero */
ZERO(mp_digit * P,mp_size S)85 static inline void ZERO(mp_digit *P, mp_size S) {
86   mp_size i__ = S * sizeof(mp_digit);
87   mp_digit *p__ = P;
88   memset(p__, 0, i__);
89 }
90 
91 /* Copy S digits from array P to array Q */
COPY(mp_digit * P,mp_digit * Q,mp_size S)92 static inline void COPY(mp_digit *P, mp_digit *Q, mp_size S) {
93   mp_size i__ = S * sizeof(mp_digit);
94   mp_digit *p__ = P;
95   mp_digit *q__ = Q;
96   memcpy(q__, p__, i__);
97 }
98 
/* Reverse N elements of unsigned char in A, in place.

   Buffers of fewer than two elements are returned untouched.  The early
   return also avoids forming the pointer A - 1 when N is 0, which is
   undefined behaviour even if the pointer is never dereferenced. */
static inline void REV(unsigned char *A, int N) {
  if (N < 2) return;

  unsigned char *lo = A;
  unsigned char *hi = A + (N - 1);
  while (lo < hi) {
    unsigned char tmp = *lo;
    *lo++ = *hi;
    *hi-- = tmp;
  }
}
109 
110 /* Strip leading zeroes from z_ in-place. */
CLAMP(mp_int z_)111 static inline void CLAMP(mp_int z_) {
112   mp_size uz_ = MP_USED(z_);
113   mp_digit *dz_ = MP_DIGITS(z_) + uz_ - 1;
114   while (uz_ > 1 && (*dz_-- == 0)) --uz_;
115   z_->used = uz_;
116 }
117 
/* Select min/max. */
/* Return the smaller of A and B (A when they are equal). */
static inline int MIN(int A, int B) {
  if (B < A) return B;
  return A;
}
/* Return the larger of A and B (A when they are equal). */
static inline mp_size MAX(mp_size A, mp_size B) {
  if (B > A) return B;
  return A;
}
121 
/* Exchange lvalues A and B of type T, e.g.
   SWAP(int, x, y) where x and y are variables of type int.

   A and B are each expanded twice, so they should be simple lvalues without
   side effects.  The temporary is named t_; do not pass an expression that
   refers to a variable of that name. */
#define SWAP(T, A, B) \
  do {                \
    T t_ = (A);       \
    A = (B);          \
    B = t_;           \
  } while (0)
130 
/* Declare a block of N temporary mpz_t values.
   These values are initialized to zero (each one is passed to mp_int_init).
   You must add CLEANUP_TEMP() at the end of the function.
   Use TEMP(i) to access a pointer to the ith value.

   The macro declares a local struct variable named temp_ that also carries
   the count (len) and the cached error code (err, initially MP_OK) used by
   REQUIRE and CLEANUP_TEMP.  Because it expands to a declaration, it can
   appear only once per scope.
 */
#define DECLARE_TEMP(N)                   \
  struct {                                \
    mpz_t value[(N)];                     \
    int len;                              \
    mp_result err;                        \
  } temp_ = {                             \
      .len = (N),                         \
      .err = MP_OK,                       \
  };                                      \
  do {                                    \
    for (int i = 0; i < temp_.len; i++) { \
      mp_int_init(TEMP(i));               \
    }                                     \
  } while (0)
150 
/* Clear all allocated temp values.

   This macro defines the CLEANUP label that REQUIRE jumps to, so it must
   appear exactly once in any function that uses REQUIRE.  After clearing the
   temporaries it returns the cached error if one was recorded; otherwise
   control falls through to whatever follows the macro, which is where the
   function returns its success value. */
#define CLEANUP_TEMP()                    \
  CLEANUP:                                \
  do {                                    \
    for (int i = 0; i < temp_.len; i++) { \
      mp_int_clear(TEMP(i));              \
    }                                     \
    if (temp_.err != MP_OK) {             \
      return temp_.err;                   \
    }                                     \
  } while (0)
162 
/* A pointer to the kth temp value declared by DECLARE_TEMP.  K is not range
   checked against the declared count. */
#define TEMP(K) (temp_.value + (K))
165 
/* Evaluate E, an expression of type mp_result expected to return MP_OK.  If
   the value is not MP_OK, the error is cached and control resumes at the
   cleanup handler, which returns it.

   E is evaluated exactly once.  The enclosing function must contain both
   DECLARE_TEMP (for temp_) and CLEANUP_TEMP (for the CLEANUP label).
*/
#define REQUIRE(E)                        \
  do {                                    \
    temp_.err = (E);                      \
    if (temp_.err != MP_OK) goto CLEANUP; \
  } while (0)
175 
176 /* Compare value to zero. */
CMPZ(mp_int Z)177 static inline int CMPZ(mp_int Z) {
178   if (Z->used == 1 && Z->digits[0] == 0) return 0;
179   return (Z->sign == MP_NEG) ? -1 : 1;
180 }
181 
/* Return W shifted right by MP_DIGIT_BIT, i.e. the part of W above its
   low-order digit. */
static inline mp_word UPPER_HALF(mp_word W) {
  mp_word high = W >> MP_DIGIT_BIT;
  return high;
}
LOWER_HALF(mp_word W)183 static inline mp_digit LOWER_HALF(mp_word W) { return (mp_digit)(W); }
184 
185 /* Report whether the highest-order bit of W is 1. */
HIGH_BIT_SET(mp_word W)186 static inline bool HIGH_BIT_SET(mp_word W) {
187   return (W >> (MP_WORD_BIT - 1)) != 0;
188 }
189 
190 /* Report whether adding W + V will carry out. */
ADD_WILL_OVERFLOW(mp_word W,mp_word V)191 static inline bool ADD_WILL_OVERFLOW(mp_word W, mp_word V) {
192   return ((MP_WORD_MAX - V) < W);
193 }
194 
/* Default number of digits allocated to a new mp_int.  Used by
   mp_int_init_size() when called with a precision of 0, and as a lower bound
   on the buffer size in mp_int_init_copy(), mp_int_mul() and mp_int_sqr(). */
static mp_size default_precision = 8;

/* Set the default precision to size digits.  size must be positive; this is
   enforced only by assert().  The value is stored as given, without
   rounding. */
void mp_int_default_precision(mp_size size) {
  assert(size > 0);
  default_precision = size;
}
202 
/* Minimum number of digits to invoke recursive multiply */
static mp_size multiply_threshold = 32;

/* Set the recursive-multiply threshold to thresh digits.  The assertion
   requires thresh to be at least sizeof(mp_word).
   NOTE(review): this compares a digit count with a byte size; the reason for
   that particular bound is not visible from here. */
void mp_int_multiply_threshold(mp_size thresh) {
  assert(thresh >= sizeof(mp_word));
  multiply_threshold = thresh;
}
210 
/* Allocate a buffer of (at least) num digits, or return
   NULL if that couldn't be done.  */
static mp_digit *s_alloc(mp_size num);

/* Release a buffer of digits allocated by s_alloc(). */
static void s_free(void *ptr);

/* Ensure that z has at least min digits allocated, resizing if
   necessary.  Returns true if successful, false if out of memory. */
static bool s_pad(mp_int z, mp_size min);
221 
222 /* Ensure Z has at least N digits allocated. */
GROW(mp_int Z,mp_size N)223 static inline mp_result GROW(mp_int Z, mp_size N) {
224   return s_pad(Z, N) ? MP_OK : MP_MEMORY;
225 }
226 
/* Fill in a "fake" mp_int on the stack with a given value.  vbuf supplies the
   digit storage; callers in this file size it with MP_VALUE_DIGITS(value). */
static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);

/* Compare two runs of digits of given length, returns <0, 0, >0 */
static int s_cdig(mp_digit *da, mp_digit *db, mp_size len);

/* Pack the unsigned digits of v into array t */
static int s_uvpack(mp_usmall v, mp_digit t[]);

/* Compare magnitudes of a and b, returns <0, 0, >0 */
static int s_ucmp(mp_int a, mp_int b);

/* Compare magnitudes of a and v, returns <0, 0, >0 */
static int s_vcmp(mp_int a, mp_small v);
static int s_uvcmp(mp_int a, mp_usmall uv);

/* Unsigned magnitude addition; assumes dc is big enough.
   Carry out is returned (no memory allocated). */
static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                       mp_size size_b);

/* Unsigned magnitude subtraction.  Assumes dc is big enough. */
static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                   mp_size size_b);

/* Unsigned recursive multiplication.  Assumes dc is big enough.
   mp_int_mul() treats a zero return as an out-of-memory failure. */
static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                  mp_size size_b);

/* Unsigned magnitude multiplication.  Assumes dc is big enough. */
static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                   mp_size size_b);

/* Unsigned recursive squaring.  Assumes dc is big enough. */
static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);

/* Unsigned magnitude squaring.  Assumes dc is big enough. */
static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);

/* Single digit addition.  Assumes a is big enough. */
static void s_dadd(mp_int a, mp_digit b);

/* Single digit multiplication.  Assumes a is big enough. */
static void s_dmul(mp_int a, mp_digit b);

/* Single digit multiplication on buffers; assumes dc is big enough. */
static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a);

/* Single digit division.  Replaces a with the quotient,
   returns the remainder.  */
static mp_digit s_ddiv(mp_int a, mp_digit b);

/* Quick division by a power of 2, replaces z (no allocation) */
static void s_qdiv(mp_int z, mp_size p2);

/* Quick remainder by a power of 2, replaces z (no allocation) */
static void s_qmod(mp_int z, mp_size p2);

/* Quick multiplication by a power of 2, replaces z.
   Allocates if necessary; returns false in case this fails. */
static int s_qmul(mp_int z, mp_size p2);

/* Quick subtraction from a power of 2, replaces z.
   Allocates if necessary; returns false in case this fails. */
static int s_qsub(mp_int z, mp_size p2);

/* Return maximum k such that 2^k divides z. */
static int s_dp2k(mp_int z);

/* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
static int s_isp2(mp_int z);

/* Set z to 2^k.  May allocate; returns false in case this fails. */
static int s_2expt(mp_int z, mp_small k);

/* Normalize a and b for division, returns normalization constant */
static int s_norm(mp_int a, mp_int b);

/* Compute constant mu for Barrett reduction, given modulus m, result
   replaces z, m is untouched. */
static mp_result s_brmu(mp_int z, mp_int m);

/* Reduce a modulo m, using Barrett's algorithm. */
static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);

/* Modular exponentiation, using Barrett reduction */
static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);

/* Unsigned magnitude division.  Assumes |a| > |b|.  Allocates temporaries;
   overwrites a with quotient, b with remainder. */
static mp_result s_udiv_knuth(mp_int a, mp_int b);

/* Compute the number of digits in radix r required to represent the given
   value.  Does not account for sign flags, terminators, etc. */
static int s_outlen(mp_int z, mp_size r);

/* Guess how many digits of precision will be needed to represent a radix r
   value of the specified number of digits.  Returns a value guaranteed to be
   no smaller than the actual number required. */
static mp_size s_inlen(int len, mp_size r);

/* Convert a character to a digit value in radix r, or
   -1 if out of range */
static int s_ch2val(char c, int r);

/* Convert a digit value to a character */
static char s_val2ch(int v, int caps);

/* Take 2's complement of a buffer in place */
static void s_2comp(unsigned char *buf, int len);

/* Convert a value to binary, ignoring sign.  On input, *limpos is the bound on
   how many bytes should be written to buf; on output, *limpos is set to the
   number of bytes actually written. */
static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
343 
344 /* Multiply X by Y into Z, ignoring signs.  Requires that Z have enough storage
345    preallocated to hold the result. */
UMUL(mp_int X,mp_int Y,mp_int Z)346 static inline void UMUL(mp_int X, mp_int Y, mp_int Z) {
347   mp_size ua_ = MP_USED(X);
348   mp_size ub_ = MP_USED(Y);
349   mp_size o_ = ua_ + ub_;
350   ZERO(MP_DIGITS(Z), o_);
351   (void)s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_);
352   Z->used = o_;
353   CLAMP(Z);
354 }
355 
356 /* Square X into Z.  Requires that Z have enough storage to hold the result. */
USQR(mp_int X,mp_int Z)357 static inline void USQR(mp_int X, mp_int Z) {
358   mp_size ua_ = MP_USED(X);
359   mp_size o_ = ua_ + ua_;
360   ZERO(MP_DIGITS(Z), o_);
361   (void)s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_);
362   Z->used = o_;
363   CLAMP(Z);
364 }
365 
mp_int_init(mp_int z)366 mp_result mp_int_init(mp_int z) {
367   if (z == NULL) return MP_BADARG;
368 
369   z->single = 0;
370   z->digits = &(z->single);
371   z->alloc = 1;
372   z->used = 1;
373   z->sign = MP_ZPOS;
374 
375   return MP_OK;
376 }
377 
mp_int_alloc(void)378 mp_int mp_int_alloc(void) {
379   mp_int out = malloc(sizeof(mpz_t));
380 
381   if (out != NULL) mp_int_init(out);
382 
383   return out;
384 }
385 
mp_int_init_size(mp_int z,mp_size prec)386 mp_result mp_int_init_size(mp_int z, mp_size prec) {
387   assert(z != NULL);
388 
389   if (prec == 0) {
390     prec = default_precision;
391   } else if (prec == 1) {
392     return mp_int_init(z);
393   } else {
394     prec = s_round_prec(prec);
395   }
396 
397   z->digits = s_alloc(prec);
398   if (MP_DIGITS(z) == NULL) return MP_MEMORY;
399 
400   z->digits[0] = 0;
401   z->used = 1;
402   z->alloc = prec;
403   z->sign = MP_ZPOS;
404 
405   return MP_OK;
406 }
407 
mp_int_init_copy(mp_int z,mp_int old)408 mp_result mp_int_init_copy(mp_int z, mp_int old) {
409   assert(z != NULL && old != NULL);
410 
411   mp_size uold = MP_USED(old);
412   if (uold == 1) {
413     mp_int_init(z);
414   } else {
415     mp_size target = MAX(uold, default_precision);
416     mp_result res = mp_int_init_size(z, target);
417     if (res != MP_OK) return res;
418   }
419 
420   z->used = uold;
421   z->sign = old->sign;
422   COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
423 
424   return MP_OK;
425 }
426 
mp_int_init_value(mp_int z,mp_small value)427 mp_result mp_int_init_value(mp_int z, mp_small value) {
428   mpz_t vtmp;
429   mp_digit vbuf[MP_VALUE_DIGITS(value)];
430 
431   s_fake(&vtmp, value, vbuf);
432   return mp_int_init_copy(z, &vtmp);
433 }
434 
mp_int_init_uvalue(mp_int z,mp_usmall uvalue)435 mp_result mp_int_init_uvalue(mp_int z, mp_usmall uvalue) {
436   mpz_t vtmp;
437   mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
438 
439   s_ufake(&vtmp, uvalue, vbuf);
440   return mp_int_init_copy(z, &vtmp);
441 }
442 
mp_int_set_value(mp_int z,mp_small value)443 mp_result mp_int_set_value(mp_int z, mp_small value) {
444   mpz_t vtmp;
445   mp_digit vbuf[MP_VALUE_DIGITS(value)];
446 
447   s_fake(&vtmp, value, vbuf);
448   return mp_int_copy(&vtmp, z);
449 }
450 
mp_int_set_uvalue(mp_int z,mp_usmall uvalue)451 mp_result mp_int_set_uvalue(mp_int z, mp_usmall uvalue) {
452   mpz_t vtmp;
453   mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
454 
455   s_ufake(&vtmp, uvalue, vbuf);
456   return mp_int_copy(&vtmp, z);
457 }
458 
/* Release the digit storage owned by z and mark z as cleared by setting its
   digit pointer to NULL.  The embedded single digit is never passed to
   s_free().  A NULL z, or a z that is already cleared, is ignored. */
void mp_int_clear(mp_int z) {
  if (z == NULL || MP_DIGITS(z) == NULL) {
    return;
  }

  if (MP_DIGITS(z) != &(z->single)) {
    s_free(MP_DIGITS(z));
  }
  z->digits = NULL;
}
468 
/* Clear z and then release the mpz_t structure itself with free().  z must
   not be NULL (checked by assert only).  mp_int_alloc() obtains the structure
   with malloc(), which is why plain free() is used here. */
void mp_int_free(mp_int z) {
  assert(z != NULL);

  mp_int_clear(z);
  free(z); /* note: NOT s_free() */
}
475 
mp_int_copy(mp_int a,mp_int c)476 mp_result mp_int_copy(mp_int a, mp_int c) {
477   assert(a != NULL && c != NULL);
478 
479   if (a != c) {
480     mp_size ua = MP_USED(a);
481     mp_digit *da, *dc;
482 
483     if (!s_pad(c, ua)) return MP_MEMORY;
484 
485     da = MP_DIGITS(a);
486     dc = MP_DIGITS(c);
487     COPY(da, dc, ua);
488 
489     c->used = ua;
490     c->sign = a->sign;
491   }
492 
493   return MP_OK;
494 }
495 
/* Exchange the contents of a and c by swapping the structures.  After the
   swap, a digit pointer that still refers to the other structure's embedded
   digit is redirected to the owner's embedded digit. */
void mp_int_swap(mp_int a, mp_int c) {
  if (a == c) {
    return;
  }

  mpz_t saved = *c;
  *c = *a;
  *a = saved;

  if (MP_DIGITS(a) == &(c->single)) {
    a->digits = &(a->single);
  }
  if (MP_DIGITS(c) == &(a->single)) {
    c->digits = &(c->single);
  }
}
507 
/* Set z to zero: one used digit holding 0, with a non-negative sign.  The
   existing allocation is kept. */
void mp_int_zero(mp_int z) {
  assert(z != NULL);

  z->sign = MP_ZPOS;
  z->used = 1;
  z->digits[0] = 0;
}
515 
mp_int_abs(mp_int a,mp_int c)516 mp_result mp_int_abs(mp_int a, mp_int c) {
517   assert(a != NULL && c != NULL);
518 
519   mp_result res;
520   if ((res = mp_int_copy(a, c)) != MP_OK) return res;
521 
522   c->sign = MP_ZPOS;
523   return MP_OK;
524 }
525 
mp_int_neg(mp_int a,mp_int c)526 mp_result mp_int_neg(mp_int a, mp_int c) {
527   assert(a != NULL && c != NULL);
528 
529   mp_result res;
530   if ((res = mp_int_copy(a, c)) != MP_OK) return res;
531 
532   if (CMPZ(c) != 0) c->sign = 1 - MP_SIGN(a);
533 
534   return MP_OK;
535 }
536 
mp_int_add(mp_int a,mp_int b,mp_int c)537 mp_result mp_int_add(mp_int a, mp_int b, mp_int c) {
538   assert(a != NULL && b != NULL && c != NULL);
539 
540   mp_size ua = MP_USED(a);
541   mp_size ub = MP_USED(b);
542   mp_size max = MAX(ua, ub);
543 
544   if (MP_SIGN(a) == MP_SIGN(b)) {
545     /* Same sign -- add magnitudes, preserve sign of addends */
546     if (!s_pad(c, max)) return MP_MEMORY;
547 
548     mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
549     mp_size uc = max;
550 
551     if (carry) {
552       if (!s_pad(c, max + 1)) return MP_MEMORY;
553 
554       c->digits[max] = carry;
555       ++uc;
556     }
557 
558     c->used = uc;
559     c->sign = a->sign;
560 
561   } else {
562     /* Different signs -- subtract magnitudes, preserve sign of greater */
563     int cmp = s_ucmp(a, b); /* magnitude comparison, sign ignored */
564 
565     /* Set x to max(a, b), y to min(a, b) to simplify later code.
566        A special case yields zero for equal magnitudes.
567     */
568     mp_int x, y;
569     if (cmp == 0) {
570       mp_int_zero(c);
571       return MP_OK;
572     } else if (cmp < 0) {
573       x = b;
574       y = a;
575     } else {
576       x = a;
577       y = b;
578     }
579 
580     if (!s_pad(c, MP_USED(x))) return MP_MEMORY;
581 
582     /* Subtract smaller from larger */
583     s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
584     c->used = x->used;
585     CLAMP(c);
586 
587     /* Give result the sign of the larger */
588     c->sign = x->sign;
589   }
590 
591   return MP_OK;
592 }
593 
mp_int_add_value(mp_int a,mp_small value,mp_int c)594 mp_result mp_int_add_value(mp_int a, mp_small value, mp_int c) {
595   mpz_t vtmp;
596   mp_digit vbuf[MP_VALUE_DIGITS(value)];
597 
598   s_fake(&vtmp, value, vbuf);
599 
600   return mp_int_add(a, &vtmp, c);
601 }
602 
mp_int_sub(mp_int a,mp_int b,mp_int c)603 mp_result mp_int_sub(mp_int a, mp_int b, mp_int c) {
604   assert(a != NULL && b != NULL && c != NULL);
605 
606   mp_size ua = MP_USED(a);
607   mp_size ub = MP_USED(b);
608   mp_size max = MAX(ua, ub);
609 
610   if (MP_SIGN(a) != MP_SIGN(b)) {
611     /* Different signs -- add magnitudes and keep sign of a */
612     if (!s_pad(c, max)) return MP_MEMORY;
613 
614     mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
615     mp_size uc = max;
616 
617     if (carry) {
618       if (!s_pad(c, max + 1)) return MP_MEMORY;
619 
620       c->digits[max] = carry;
621       ++uc;
622     }
623 
624     c->used = uc;
625     c->sign = a->sign;
626 
627   } else {
628     /* Same signs -- subtract magnitudes */
629     if (!s_pad(c, max)) return MP_MEMORY;
630     mp_int x, y;
631     mp_sign osign;
632 
633     int cmp = s_ucmp(a, b);
634     if (cmp >= 0) {
635       x = a;
636       y = b;
637       osign = MP_ZPOS;
638     } else {
639       x = b;
640       y = a;
641       osign = MP_NEG;
642     }
643 
644     if (MP_SIGN(a) == MP_NEG && cmp != 0) osign = 1 - osign;
645 
646     s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
647     c->used = x->used;
648     CLAMP(c);
649 
650     c->sign = osign;
651   }
652 
653   return MP_OK;
654 }
655 
mp_int_sub_value(mp_int a,mp_small value,mp_int c)656 mp_result mp_int_sub_value(mp_int a, mp_small value, mp_int c) {
657   mpz_t vtmp;
658   mp_digit vbuf[MP_VALUE_DIGITS(value)];
659 
660   s_fake(&vtmp, value, vbuf);
661 
662   return mp_int_sub(a, &vtmp, c);
663 }
664 
mp_int_mul(mp_int a,mp_int b,mp_int c)665 mp_result mp_int_mul(mp_int a, mp_int b, mp_int c) {
666   assert(a != NULL && b != NULL && c != NULL);
667 
668   /* If either input is zero, we can shortcut multiplication */
669   if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) {
670     mp_int_zero(c);
671     return MP_OK;
672   }
673 
674   /* Output is positive if inputs have same sign, otherwise negative */
675   mp_sign osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
676 
677   /* If the output is not identical to any of the inputs, we'll write the
678      results directly; otherwise, allocate a temporary space. */
679   mp_size ua = MP_USED(a);
680   mp_size ub = MP_USED(b);
681   mp_size osize = MAX(ua, ub);
682   osize = 4 * ((osize + 1) / 2);
683 
684   mp_digit *out;
685   mp_size p = 0;
686   if (c == a || c == b) {
687     p = MAX(s_round_prec(osize), default_precision);
688 
689     if ((out = s_alloc(p)) == NULL) return MP_MEMORY;
690   } else {
691     if (!s_pad(c, osize)) return MP_MEMORY;
692 
693     out = MP_DIGITS(c);
694   }
695   ZERO(out, osize);
696 
697   if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub)) return MP_MEMORY;
698 
699   /* If we allocated a new buffer, get rid of whatever memory c was already
700      using, and fix up its fields to reflect that.
701    */
702   if (out != MP_DIGITS(c)) {
703     if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c));
704     c->digits = out;
705     c->alloc = p;
706   }
707 
708   c->used = osize; /* might not be true, but we'll fix it ... */
709   CLAMP(c);        /* ... right here */
710   c->sign = osign;
711 
712   return MP_OK;
713 }
714 
mp_int_mul_value(mp_int a,mp_small value,mp_int c)715 mp_result mp_int_mul_value(mp_int a, mp_small value, mp_int c) {
716   mpz_t vtmp;
717   mp_digit vbuf[MP_VALUE_DIGITS(value)];
718 
719   s_fake(&vtmp, value, vbuf);
720 
721   return mp_int_mul(a, &vtmp, c);
722 }
723 
mp_int_mul_pow2(mp_int a,mp_small p2,mp_int c)724 mp_result mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c) {
725   assert(a != NULL && c != NULL && p2 >= 0);
726 
727   mp_result res = mp_int_copy(a, c);
728   if (res != MP_OK) return res;
729 
730   if (s_qmul(c, (mp_size)p2)) {
731     return MP_OK;
732   } else {
733     return MP_MEMORY;
734   }
735 }
736 
mp_int_sqr(mp_int a,mp_int c)737 mp_result mp_int_sqr(mp_int a, mp_int c) {
738   assert(a != NULL && c != NULL);
739 
740   /* Get a temporary buffer big enough to hold the result */
741   mp_size osize = (mp_size)4 * ((MP_USED(a) + 1) / 2);
742   mp_size p = 0;
743   mp_digit *out;
744   if (a == c) {
745     p = s_round_prec(osize);
746     p = MAX(p, default_precision);
747 
748     if ((out = s_alloc(p)) == NULL) return MP_MEMORY;
749   } else {
750     if (!s_pad(c, osize)) return MP_MEMORY;
751 
752     out = MP_DIGITS(c);
753   }
754   ZERO(out, osize);
755 
756   s_ksqr(MP_DIGITS(a), out, MP_USED(a));
757 
758   /* Get rid of whatever memory c was already using, and fix up its fields to
759      reflect the new digit array it's using
760    */
761   if (out != MP_DIGITS(c)) {
762     if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c));
763     c->digits = out;
764     c->alloc = p;
765   }
766 
767   c->used = osize; /* might not be true, but we'll fix it ... */
768   CLAMP(c);        /* ... right here */
769   c->sign = MP_ZPOS;
770 
771   return MP_OK;
772 }
773 
/* Divide a by b, storing the quotient in q and the remainder in r.  Either q
   or r may be NULL when that result is not wanted, but they must not be the
   same pointer (so at least one of them is non-NULL).  q and r may overlap
   with a or b.

   Returns MP_UNDEF when b is zero, the error from a failed copy or division
   step, and MP_OK otherwise.  The remainder takes the sign of a; the quotient
   is negative exactly when a and b differ in sign; a zero result is always
   marked MP_ZPOS. */
mp_result mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r) {
  assert(a != NULL && b != NULL && q != r);

  int cmp;
  mp_result res = MP_OK;
  mp_int qout, rout;
  /* Capture the input signs now; a or b may be overwritten below when they
     alias q or r. */
  mp_sign sa = MP_SIGN(a);
  mp_sign sb = MP_SIGN(b);
  if (CMPZ(b) == 0) {
    return MP_UNDEF;
  } else if ((cmp = s_ucmp(a, b)) < 0) {
    /* If |a| < |b|, no division is required:
       q = 0, r = a
     */
    if (r && (res = mp_int_copy(a, r)) != MP_OK) return res;

    if (q) mp_int_zero(q);

    return MP_OK;
  } else if (cmp == 0) {
    /* If |a| = |b|, no division is required:
       q = 1 or -1, r = 0
     */
    if (r) mp_int_zero(r);

    if (q) {
      mp_int_zero(q);
      q->digits[0] = 1;

      if (sa != sb) q->sign = MP_NEG;
    }

    return MP_OK;
  }

  /* When |a| > |b|, real division is required.  We need someplace to store
     quotient and remainder, but q and r are allowed to be NULL or to overlap
     with the inputs.
   */
  DECLARE_TEMP(2);
  int lg;
  if ((lg = s_isp2(b)) < 0) {
    /* b is not a power of two: use the general division routine, which
       overwrites its first argument with the quotient and its second with the
       remainder.  A temporary stands in whenever the output is missing or
       would clobber the other input. */
    if (q && b != q) {
      REQUIRE(mp_int_copy(a, q));
      qout = q;
    } else {
      REQUIRE(mp_int_copy(a, TEMP(0)));
      qout = TEMP(0);
    }

    if (r && a != r) {
      REQUIRE(mp_int_copy(b, r));
      rout = r;
    } else {
      REQUIRE(mp_int_copy(b, TEMP(1)));
      rout = TEMP(1);
    }

    REQUIRE(s_udiv_knuth(qout, rout));
  } else {
    /* b is 2^lg: shift and mask instead of dividing.  Both copies of a are
       taken before either output is modified, so q or r may alias a. */
    if (q) REQUIRE(mp_int_copy(a, q));
    if (r) REQUIRE(mp_int_copy(a, r));

    if (q) s_qdiv(q, (mp_size)lg);
    qout = q;
    if (r) s_qmod(r, (mp_size)lg);
    rout = r;
  }

  /* Recompute signs for output */
  if (rout) {
    rout->sign = sa;
    if (CMPZ(rout) == 0) rout->sign = MP_ZPOS;
  }
  if (qout) {
    qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG;
    if (CMPZ(qout) == 0) qout->sign = MP_ZPOS;
  }

  /* Move results out of the temporaries; these are no-ops when the output
     was written in place. */
  if (q) REQUIRE(mp_int_copy(qout, q));
  if (r) REQUIRE(mp_int_copy(rout, r));
  CLEANUP_TEMP();
  return res;
}
858 
mp_int_mod(mp_int a,mp_int m,mp_int c)859 mp_result mp_int_mod(mp_int a, mp_int m, mp_int c) {
860   DECLARE_TEMP(1);
861   mp_int out = (m == c) ? TEMP(0) : c;
862   REQUIRE(mp_int_div(a, m, NULL, out));
863   if (CMPZ(out) < 0) {
864     REQUIRE(mp_int_add(out, m, c));
865   } else {
866     REQUIRE(mp_int_copy(out, c));
867   }
868   CLEANUP_TEMP();
869   return MP_OK;
870 }
871 
mp_int_div_value(mp_int a,mp_small value,mp_int q,mp_small * r)872 mp_result mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r) {
873   mpz_t vtmp;
874   mp_digit vbuf[MP_VALUE_DIGITS(value)];
875   s_fake(&vtmp, value, vbuf);
876 
877   DECLARE_TEMP(1);
878   REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0)));
879 
880   if (r) (void)mp_int_to_int(TEMP(0), r); /* can't fail */
881 
882   CLEANUP_TEMP();
883   return MP_OK;
884 }
885 
mp_int_div_pow2(mp_int a,mp_small p2,mp_int q,mp_int r)886 mp_result mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r) {
887   assert(a != NULL && p2 >= 0 && q != r);
888 
889   mp_result res = MP_OK;
890   if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK) {
891     s_qdiv(q, (mp_size)p2);
892   }
893 
894   if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK) {
895     s_qmod(r, (mp_size)p2);
896   }
897 
898   return res;
899 }
900 
mp_int_expt(mp_int a,mp_small b,mp_int c)901 mp_result mp_int_expt(mp_int a, mp_small b, mp_int c) {
902   assert(c != NULL);
903   if (b < 0) return MP_RANGE;
904 
905   DECLARE_TEMP(1);
906   REQUIRE(mp_int_copy(a, TEMP(0)));
907 
908   (void)mp_int_set_value(c, 1);
909   unsigned int v = labs(b);
910   while (v != 0) {
911     if (v & 1) {
912       REQUIRE(mp_int_mul(c, TEMP(0), c));
913     }
914 
915     v >>= 1;
916     if (v == 0) break;
917 
918     REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
919   }
920 
921   CLEANUP_TEMP();
922   return MP_OK;
923 }
924 
mp_int_expt_value(mp_small a,mp_small b,mp_int c)925 mp_result mp_int_expt_value(mp_small a, mp_small b, mp_int c) {
926   assert(c != NULL);
927   if (b < 0) return MP_RANGE;
928 
929   DECLARE_TEMP(1);
930   REQUIRE(mp_int_set_value(TEMP(0), a));
931 
932   (void)mp_int_set_value(c, 1);
933   unsigned int v = labs(b);
934   while (v != 0) {
935     if (v & 1) {
936       REQUIRE(mp_int_mul(c, TEMP(0), c));
937     }
938 
939     v >>= 1;
940     if (v == 0) break;
941 
942     REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
943   }
944 
945   CLEANUP_TEMP();
946   return MP_OK;
947 }
948 
mp_int_expt_full(mp_int a,mp_int b,mp_int c)949 mp_result mp_int_expt_full(mp_int a, mp_int b, mp_int c) {
950   assert(a != NULL && b != NULL && c != NULL);
951   if (MP_SIGN(b) == MP_NEG) return MP_RANGE;
952 
953   DECLARE_TEMP(1);
954   REQUIRE(mp_int_copy(a, TEMP(0)));
955 
956   (void)mp_int_set_value(c, 1);
957   for (unsigned ix = 0; ix < MP_USED(b); ++ix) {
958     mp_digit d = b->digits[ix];
959 
960     for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx) {
961       if (d & 1) {
962         REQUIRE(mp_int_mul(c, TEMP(0), c));
963       }
964 
965       d >>= 1;
966       if (d == 0 && ix + 1 == MP_USED(b)) break;
967       REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
968     }
969   }
970 
971   CLEANUP_TEMP();
972   return MP_OK;
973 }
974 
/* Signed comparison of a and b.  The result is negative, zero or positive
   according to whether a is less than, equal to or greater than b. */
int mp_int_compare(mp_int a, mp_int b) {
  assert(a != NULL && b != NULL);

  const mp_sign sign_a = MP_SIGN(a);

  /* Different signs: the operand that is zero-or-positive is the larger. */
  if (sign_a != MP_SIGN(b)) {
    return (sign_a == MP_ZPOS) ? 1 : -1;
  }

  /* Same sign: compare magnitudes; for two negatives the order reverses. */
  const int magnitude_order = s_ucmp(a, b);
  return (sign_a == MP_ZPOS) ? magnitude_order : -magnitude_order;
}
995 
/* Compares a and b through s_ucmp, which looks only at the digit counts and
   digit values of its operands. */
int mp_int_compare_unsigned(mp_int a, mp_int b) {
  assert(a != NULL && b != NULL);

  const int magnitude_order = s_ucmp(a, b);
  return magnitude_order;
}
1001 
/* Returns 0 when z is zero, 1 when it is positive and -1 when negative. */
int mp_int_compare_zero(mp_int z) {
  assert(z != NULL);

  /* Zero is a single digit whose value is 0. */
  const bool is_zero = (MP_USED(z) == 1 && z->digits[0] == 0);
  if (is_zero) return 0;

  return (MP_SIGN(z) == MP_ZPOS) ? 1 : -1;
}
1013 
/* Signed comparison of z against a small value.  The result is negative,
   zero or positive according to whether z is less than, equal to or greater
   than value. */
int mp_int_compare_value(mp_int z, mp_small value) {
  assert(z != NULL);

  const bool value_negative = (value < 0);
  const mp_sign vsign = value_negative ? MP_NEG : MP_ZPOS;

  /* Different signs: z is the larger one exactly when value is negative. */
  if (vsign != MP_SIGN(z)) {
    return value_negative ? 1 : -1;
  }

  /* Same sign: compare magnitudes; for two negatives the order reverses. */
  const int magnitude_order = s_vcmp(z, value);
  return value_negative ? -magnitude_order : magnitude_order;
}
1026 
/* Compares z against an unsigned small value.  A negative z is always
   smaller; otherwise the magnitudes are compared. */
int mp_int_compare_uvalue(mp_int z, mp_usmall uv) {
  assert(z != NULL);

  return (MP_SIGN(z) == MP_NEG) ? -1 : s_uvcmp(z, uv);
}
1036 
mp_int_exptmod(mp_int a,mp_int b,mp_int m,mp_int c)1037 mp_result mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c) {
1038   assert(a != NULL && b != NULL && c != NULL && m != NULL);
1039 
1040   /* Zero moduli and negative exponents are not considered. */
1041   if (CMPZ(m) == 0) return MP_UNDEF;
1042   if (CMPZ(b) < 0) return MP_RANGE;
1043 
1044   mp_size um = MP_USED(m);
1045   DECLARE_TEMP(3);
1046   REQUIRE(GROW(TEMP(0), 2 * um));
1047   REQUIRE(GROW(TEMP(1), 2 * um));
1048 
1049   mp_int s;
1050   if (c == b || c == m) {
1051     REQUIRE(GROW(TEMP(2), 2 * um));
1052     s = TEMP(2);
1053   } else {
1054     s = c;
1055   }
1056 
1057   REQUIRE(mp_int_mod(a, m, TEMP(0)));
1058   REQUIRE(s_brmu(TEMP(1), m));
1059   REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s));
1060   REQUIRE(mp_int_copy(s, c));
1061 
1062   CLEANUP_TEMP();
1063   return MP_OK;
1064 }
1065 
mp_int_exptmod_evalue(mp_int a,mp_small value,mp_int m,mp_int c)1066 mp_result mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c) {
1067   mpz_t vtmp;
1068   mp_digit vbuf[MP_VALUE_DIGITS(value)];
1069 
1070   s_fake(&vtmp, value, vbuf);
1071 
1072   return mp_int_exptmod(a, &vtmp, m, c);
1073 }
1074 
mp_int_exptmod_bvalue(mp_small value,mp_int b,mp_int m,mp_int c)1075 mp_result mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c) {
1076   mpz_t vtmp;
1077   mp_digit vbuf[MP_VALUE_DIGITS(value)];
1078 
1079   s_fake(&vtmp, value, vbuf);
1080 
1081   return mp_int_exptmod(&vtmp, b, m, c);
1082 }
1083 
mp_int_exptmod_known(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)1084 mp_result mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu,
1085                                mp_int c) {
1086   assert(a && b && m && c);
1087 
1088   /* Zero moduli and negative exponents are not considered. */
1089   if (CMPZ(m) == 0) return MP_UNDEF;
1090   if (CMPZ(b) < 0) return MP_RANGE;
1091 
1092   DECLARE_TEMP(2);
1093   mp_size um = MP_USED(m);
1094   REQUIRE(GROW(TEMP(0), 2 * um));
1095 
1096   mp_int s;
1097   if (c == b || c == m) {
1098     REQUIRE(GROW(TEMP(1), 2 * um));
1099     s = TEMP(1);
1100   } else {
1101     s = c;
1102   }
1103 
1104   REQUIRE(mp_int_mod(a, m, TEMP(0)));
1105   REQUIRE(s_embar(TEMP(0), b, m, mu, s));
1106   REQUIRE(mp_int_copy(s, c));
1107 
1108   CLEANUP_TEMP();
1109   return MP_OK;
1110 }
1111 
mp_int_redux_const(mp_int m,mp_int c)1112 mp_result mp_int_redux_const(mp_int m, mp_int c) {
1113   assert(m != NULL && c != NULL && m != c);
1114 
1115   return s_brmu(c, m);
1116 }
1117 
mp_int_invmod(mp_int a,mp_int m,mp_int c)1118 mp_result mp_int_invmod(mp_int a, mp_int m, mp_int c) {
1119   assert(a != NULL && m != NULL && c != NULL);
1120 
1121   if (CMPZ(a) == 0 || CMPZ(m) <= 0) return MP_RANGE;
1122 
1123   DECLARE_TEMP(2);
1124 
1125   REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL));
1126 
1127   if (mp_int_compare_value(TEMP(0), 1) != 0) {
1128     REQUIRE(MP_UNDEF);
1129   }
1130 
1131   /* It is first necessary to constrain the value to the proper range */
1132   REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1)));
1133 
1134   /* Now, if 'a' was originally negative, the value we have is actually the
1135      magnitude of the negative representative; to get the positive value we
1136      have to subtract from the modulus.  Otherwise, the value is okay as it
1137      stands.
1138    */
1139   if (MP_SIGN(a) == MP_NEG) {
1140     REQUIRE(mp_int_sub(m, TEMP(1), c));
1141   } else {
1142     REQUIRE(mp_int_copy(TEMP(1), c));
1143   }
1144 
1145   CLEANUP_TEMP();
1146   return MP_OK;
1147 }
1148 
1149 /* Binary GCD algorithm due to Josef Stein, 1961 */
mp_int_gcd(mp_int a,mp_int b,mp_int c)1150 mp_result mp_int_gcd(mp_int a, mp_int b, mp_int c) {
1151   assert(a != NULL && b != NULL && c != NULL);
1152 
1153   int ca = CMPZ(a);
1154   int cb = CMPZ(b);
1155   if (ca == 0 && cb == 0) {
1156     return MP_UNDEF;
1157   } else if (ca == 0) {
1158     return mp_int_abs(b, c);
1159   } else if (cb == 0) {
1160     return mp_int_abs(a, c);
1161   }
1162 
1163   DECLARE_TEMP(3);
1164   REQUIRE(mp_int_copy(a, TEMP(0)));
1165   REQUIRE(mp_int_copy(b, TEMP(1)));
1166 
1167   TEMP(0)->sign = MP_ZPOS;
1168   TEMP(1)->sign = MP_ZPOS;
1169 
1170   int k = 0;
1171   { /* Divide out common factors of 2 from u and v */
1172     int div2_u = s_dp2k(TEMP(0));
1173     int div2_v = s_dp2k(TEMP(1));
1174 
1175     k = MIN(div2_u, div2_v);
1176     s_qdiv(TEMP(0), (mp_size)k);
1177     s_qdiv(TEMP(1), (mp_size)k);
1178   }
1179 
1180   if (mp_int_is_odd(TEMP(0))) {
1181     REQUIRE(mp_int_neg(TEMP(1), TEMP(2)));
1182   } else {
1183     REQUIRE(mp_int_copy(TEMP(0), TEMP(2)));
1184   }
1185 
1186   for (;;) {
1187     s_qdiv(TEMP(2), s_dp2k(TEMP(2)));
1188 
1189     if (CMPZ(TEMP(2)) > 0) {
1190       REQUIRE(mp_int_copy(TEMP(2), TEMP(0)));
1191     } else {
1192       REQUIRE(mp_int_neg(TEMP(2), TEMP(1)));
1193     }
1194 
1195     REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2)));
1196 
1197     if (CMPZ(TEMP(2)) == 0) break;
1198   }
1199 
1200   REQUIRE(mp_int_abs(TEMP(0), c));
1201   if (!s_qmul(c, (mp_size)k)) REQUIRE(MP_MEMORY);
1202 
1203   CLEANUP_TEMP();
1204   return MP_OK;
1205 }
1206 
1207 /* This is the binary GCD algorithm again, but this time we keep track of the
1208    elementary matrix operations as we go, so we can get values x and y
1209    satisfying c = ax + by.
1210  */
mp_int_egcd(mp_int a,mp_int b,mp_int c,mp_int x,mp_int y)1211 mp_result mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y) {
1212   assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL));
1213 
1214   mp_result res = MP_OK;
1215   int ca = CMPZ(a);
1216   int cb = CMPZ(b);
1217   if (ca == 0 && cb == 0) {
1218     return MP_UNDEF;
1219   } else if (ca == 0) {
1220     if ((res = mp_int_abs(b, c)) != MP_OK) return res;
1221     mp_int_zero(x);
1222     (void)mp_int_set_value(y, 1);
1223     return MP_OK;
1224   } else if (cb == 0) {
1225     if ((res = mp_int_abs(a, c)) != MP_OK) return res;
1226     (void)mp_int_set_value(x, 1);
1227     mp_int_zero(y);
1228     return MP_OK;
1229   }
1230 
1231   /* Initialize temporaries:
1232      A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 */
1233   DECLARE_TEMP(8);
1234   REQUIRE(mp_int_set_value(TEMP(0), 1));
1235   REQUIRE(mp_int_set_value(TEMP(3), 1));
1236   REQUIRE(mp_int_copy(a, TEMP(4)));
1237   REQUIRE(mp_int_copy(b, TEMP(5)));
1238 
1239   /* We will work with absolute values here */
1240   TEMP(4)->sign = MP_ZPOS;
1241   TEMP(5)->sign = MP_ZPOS;
1242 
1243   int k = 0;
1244   { /* Divide out common factors of 2 from u and v */
1245     int div2_u = s_dp2k(TEMP(4)), div2_v = s_dp2k(TEMP(5));
1246 
1247     k = MIN(div2_u, div2_v);
1248     s_qdiv(TEMP(4), k);
1249     s_qdiv(TEMP(5), k);
1250   }
1251 
1252   REQUIRE(mp_int_copy(TEMP(4), TEMP(6)));
1253   REQUIRE(mp_int_copy(TEMP(5), TEMP(7)));
1254 
1255   for (;;) {
1256     while (mp_int_is_even(TEMP(4))) {
1257       s_qdiv(TEMP(4), 1);
1258 
1259       if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) {
1260         REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0)));
1261         REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1)));
1262       }
1263 
1264       s_qdiv(TEMP(0), 1);
1265       s_qdiv(TEMP(1), 1);
1266     }
1267 
1268     while (mp_int_is_even(TEMP(5))) {
1269       s_qdiv(TEMP(5), 1);
1270 
1271       if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) {
1272         REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2)));
1273         REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3)));
1274       }
1275 
1276       s_qdiv(TEMP(2), 1);
1277       s_qdiv(TEMP(3), 1);
1278     }
1279 
1280     if (mp_int_compare(TEMP(4), TEMP(5)) >= 0) {
1281       REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4)));
1282       REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0)));
1283       REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1)));
1284     } else {
1285       REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5)));
1286       REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1287       REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3)));
1288     }
1289 
1290     if (CMPZ(TEMP(4)) == 0) {
1291       if (x) REQUIRE(mp_int_copy(TEMP(2), x));
1292       if (y) REQUIRE(mp_int_copy(TEMP(3), y));
1293       if (c) {
1294         if (!s_qmul(TEMP(5), k)) {
1295           REQUIRE(MP_MEMORY);
1296         }
1297         REQUIRE(mp_int_copy(TEMP(5), c));
1298       }
1299 
1300       break;
1301     }
1302   }
1303 
1304   CLEANUP_TEMP();
1305   return MP_OK;
1306 }
1307 
mp_int_lcm(mp_int a,mp_int b,mp_int c)1308 mp_result mp_int_lcm(mp_int a, mp_int b, mp_int c) {
1309   assert(a != NULL && b != NULL && c != NULL);
1310 
1311   /* Since a * b = gcd(a, b) * lcm(a, b), we can compute
1312      lcm(a, b) = (a / gcd(a, b)) * b.
1313 
1314      This formulation insures everything works even if the input
1315      variables share space.
1316    */
1317   DECLARE_TEMP(1);
1318   REQUIRE(mp_int_gcd(a, b, TEMP(0)));
1319   REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL));
1320   REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0)));
1321   REQUIRE(mp_int_copy(TEMP(0), c));
1322 
1323   CLEANUP_TEMP();
1324   return MP_OK;
1325 }
1326 
/* Reports whether a is evenly divisible by the small value v.  A failed
   division is reported as "not divisible". */
bool mp_int_divisible_value(mp_int a, mp_small v) {
  mp_small remainder = 0;
  const mp_result status = mp_int_div_value(a, v, NULL, &remainder);

  return status == MP_OK && remainder == 0;
}
1335 
/* Public wrapper around s_isp2; returns that helper's result unchanged. */
int mp_int_is_pow2(mp_int z) {
  assert(z != NULL);

  const int answer = s_isp2(z);
  return answer;
}
1341 
1342 /* Implementation of Newton's root finding method, based loosely on a patch
1343    contributed by Hal Finkel <half@halssoftware.com>
1344    modified by M. J. Fromberger.
1345  */
mp_int_root(mp_int a,mp_small b,mp_int c)1346 mp_result mp_int_root(mp_int a, mp_small b, mp_int c) {
1347   assert(a != NULL && c != NULL && b > 0);
1348 
1349   if (b == 1) {
1350     return mp_int_copy(a, c);
1351   }
1352   bool flips = false;
1353   if (MP_SIGN(a) == MP_NEG) {
1354     if (b % 2 == 0) {
1355       return MP_UNDEF; /* root does not exist for negative a with even b */
1356     } else {
1357       flips = true;
1358     }
1359   }
1360 
1361   DECLARE_TEMP(5);
1362   REQUIRE(mp_int_copy(a, TEMP(0)));
1363   REQUIRE(mp_int_copy(a, TEMP(1)));
1364   TEMP(0)->sign = MP_ZPOS;
1365   TEMP(1)->sign = MP_ZPOS;
1366 
1367   for (;;) {
1368     REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2)));
1369 
1370     if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0) break;
1371 
1372     REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1373     REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3)));
1374     REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3)));
1375     REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL));
1376     REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4)));
1377 
1378     if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) {
1379       REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4)));
1380     }
1381     REQUIRE(mp_int_copy(TEMP(4), TEMP(1)));
1382   }
1383 
1384   REQUIRE(mp_int_copy(TEMP(1), c));
1385 
1386   /* If the original value of a was negative, flip the output sign. */
1387   if (flips) (void)mp_int_neg(c, c); /* cannot fail */
1388 
1389   CLEANUP_TEMP();
1390   return MP_OK;
1391 }
1392 
mp_int_to_int(mp_int z,mp_small * out)1393 mp_result mp_int_to_int(mp_int z, mp_small *out) {
1394   assert(z != NULL);
1395 
1396   /* Make sure the value is representable as a small integer */
1397   mp_sign sz = MP_SIGN(z);
1398   if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1399       mp_int_compare_value(z, MP_SMALL_MIN) < 0) {
1400     return MP_RANGE;
1401   }
1402 
1403   mp_usmall uz = MP_USED(z);
1404   mp_digit *dz = MP_DIGITS(z) + uz - 1;
1405   mp_small uv = 0;
1406   while (uz > 0) {
1407     uv <<= MP_DIGIT_BIT / 2;
1408     uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1409     --uz;
1410   }
1411 
1412   if (out) *out = (mp_small)((sz == MP_NEG) ? -uv : uv);
1413 
1414   return MP_OK;
1415 }
1416 
mp_int_to_uint(mp_int z,mp_usmall * out)1417 mp_result mp_int_to_uint(mp_int z, mp_usmall *out) {
1418   assert(z != NULL);
1419 
1420   /* Make sure the value is representable as an unsigned small integer */
1421   mp_size sz = MP_SIGN(z);
1422   if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0) {
1423     return MP_RANGE;
1424   }
1425 
1426   mp_size uz = MP_USED(z);
1427   mp_digit *dz = MP_DIGITS(z) + uz - 1;
1428   mp_usmall uv = 0;
1429 
1430   while (uz > 0) {
1431     uv <<= MP_DIGIT_BIT / 2;
1432     uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1433     --uz;
1434   }
1435 
1436   if (out) *out = uv;
1437 
1438   return MP_OK;
1439 }
1440 
mp_int_to_string(mp_int z,mp_size radix,char * str,int limit)1441 mp_result mp_int_to_string(mp_int z, mp_size radix, char *str, int limit) {
1442   assert(z != NULL && str != NULL && limit >= 2);
1443   assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1444 
1445   int cmp = 0;
1446   if (CMPZ(z) == 0) {
1447     *str++ = s_val2ch(0, 1);
1448   } else {
1449     mp_result res;
1450     mpz_t tmp;
1451     char *h, *t;
1452 
1453     if ((res = mp_int_init_copy(&tmp, z)) != MP_OK) return res;
1454 
1455     if (MP_SIGN(z) == MP_NEG) {
1456       *str++ = '-';
1457       --limit;
1458     }
1459     h = str;
1460 
1461     /* Generate digits in reverse order until finished or limit reached */
1462     for (/* */; limit > 0; --limit) {
1463       mp_digit d;
1464 
1465       if ((cmp = CMPZ(&tmp)) == 0) break;
1466 
1467       d = s_ddiv(&tmp, (mp_digit)radix);
1468       *str++ = s_val2ch(d, 1);
1469     }
1470     t = str - 1;
1471 
1472     /* Put digits back in correct output order */
1473     while (h < t) {
1474       char tc = *h;
1475       *h++ = *t;
1476       *t-- = tc;
1477     }
1478 
1479     mp_int_clear(&tmp);
1480   }
1481 
1482   *str = '\0';
1483   if (cmp == 0) {
1484     return MP_OK;
1485   } else {
1486     return MP_TRUNC;
1487   }
1488 }
1489 
mp_int_string_len(mp_int z,mp_size radix)1490 mp_result mp_int_string_len(mp_int z, mp_size radix) {
1491   assert(z != NULL);
1492   assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1493 
1494   int len = s_outlen(z, radix) + 1; /* for terminator */
1495 
1496   /* Allow for sign marker on negatives */
1497   if (MP_SIGN(z) == MP_NEG) len += 1;
1498 
1499   return len;
1500 }
1501 
1502 /* Read zero-terminated string into z */
mp_int_read_string(mp_int z,mp_size radix,const char * str)1503 mp_result mp_int_read_string(mp_int z, mp_size radix, const char *str) {
1504   return mp_int_read_cstring(z, radix, str, NULL);
1505 }
1506 
mp_int_read_cstring(mp_int z,mp_size radix,const char * str,char ** end)1507 mp_result mp_int_read_cstring(mp_int z, mp_size radix, const char *str,
1508                               char **end) {
1509   assert(z != NULL && str != NULL);
1510   assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1511 
1512   /* Skip leading whitespace */
1513   while (isspace((unsigned char)*str)) ++str;
1514 
1515   /* Handle leading sign tag (+/-, positive default) */
1516   switch (*str) {
1517     case '-':
1518       z->sign = MP_NEG;
1519       ++str;
1520       break;
1521     case '+':
1522       ++str; /* fallthrough */
1523     default:
1524       z->sign = MP_ZPOS;
1525       break;
1526   }
1527 
1528   /* Skip leading zeroes */
1529   int ch;
1530   while ((ch = s_ch2val(*str, radix)) == 0) ++str;
1531 
1532   /* Make sure there is enough space for the value */
1533   if (!s_pad(z, s_inlen(strlen(str), radix))) return MP_MEMORY;
1534 
1535   z->used = 1;
1536   z->digits[0] = 0;
1537 
1538   while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) {
1539     s_dmul(z, (mp_digit)radix);
1540     s_dadd(z, (mp_digit)ch);
1541     ++str;
1542   }
1543 
1544   CLAMP(z);
1545 
1546   /* Override sign for zero, even if negative specified. */
1547   if (CMPZ(z) == 0) z->sign = MP_ZPOS;
1548 
1549   if (end != NULL) *end = (char *)str;
1550 
1551   /* Return a truncation error if the string has unprocessed characters
1552      remaining, so the caller can tell if the whole string was done */
1553   if (*str != '\0') {
1554     return MP_TRUNC;
1555   } else {
1556     return MP_OK;
1557   }
1558 }
1559 
mp_int_count_bits(mp_int z)1560 mp_result mp_int_count_bits(mp_int z) {
1561   assert(z != NULL);
1562 
1563   mp_size uz = MP_USED(z);
1564   if (uz == 1 && z->digits[0] == 0) return 1;
1565 
1566   --uz;
1567   mp_size nbits = uz * MP_DIGIT_BIT;
1568   mp_digit d = z->digits[uz];
1569 
1570   while (d != 0) {
1571     d >>= 1;
1572     ++nbits;
1573   }
1574 
1575   return nbits;
1576 }
1577 
mp_int_to_binary(mp_int z,unsigned char * buf,int limit)1578 mp_result mp_int_to_binary(mp_int z, unsigned char *buf, int limit) {
1579   static const int PAD_FOR_2C = 1;
1580 
1581   assert(z != NULL && buf != NULL);
1582 
1583   int limpos = limit;
1584   mp_result res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
1585 
1586   if (MP_SIGN(z) == MP_NEG) s_2comp(buf, limpos);
1587 
1588   return res;
1589 }
1590 
mp_int_read_binary(mp_int z,unsigned char * buf,int len)1591 mp_result mp_int_read_binary(mp_int z, unsigned char *buf, int len) {
1592   assert(z != NULL && buf != NULL && len > 0);
1593 
1594   /* Figure out how many digits are needed to represent this value */
1595   mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1596   if (!s_pad(z, need)) return MP_MEMORY;
1597 
1598   mp_int_zero(z);
1599 
1600   /* If the high-order bit is set, take the 2's complement before reading the
1601      value (it will be restored afterward) */
1602   if (buf[0] >> (CHAR_BIT - 1)) {
1603     z->sign = MP_NEG;
1604     s_2comp(buf, len);
1605   }
1606 
1607   mp_digit *dz = MP_DIGITS(z);
1608   unsigned char *tmp = buf;
1609   for (int i = len; i > 0; --i, ++tmp) {
1610     s_qmul(z, (mp_size)CHAR_BIT);
1611     *dz |= *tmp;
1612   }
1613 
1614   /* Restore 2's complement if we took it before */
1615   if (MP_SIGN(z) == MP_NEG) s_2comp(buf, len);
1616 
1617   return MP_OK;
1618 }
1619 
mp_int_binary_len(mp_int z)1620 mp_result mp_int_binary_len(mp_int z) {
1621   mp_result res = mp_int_count_bits(z);
1622   if (res <= 0) return res;
1623 
1624   int bytes = mp_int_unsigned_len(z);
1625 
1626   /* If the highest-order bit falls exactly on a byte boundary, we need to pad
1627      with an extra byte so that the sign will be read correctly when reading it
1628      back in. */
1629   if (bytes * CHAR_BIT == res) ++bytes;
1630 
1631   return bytes;
1632 }
1633 
mp_int_to_unsigned(mp_int z,unsigned char * buf,int limit)1634 mp_result mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit) {
1635   static const int NO_PADDING = 0;
1636 
1637   assert(z != NULL && buf != NULL);
1638 
1639   return s_tobin(z, buf, &limit, NO_PADDING);
1640 }
1641 
mp_int_read_unsigned(mp_int z,unsigned char * buf,int len)1642 mp_result mp_int_read_unsigned(mp_int z, unsigned char *buf, int len) {
1643   assert(z != NULL && buf != NULL && len > 0);
1644 
1645   /* Figure out how many digits are needed to represent this value */
1646   mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1647   if (!s_pad(z, need)) return MP_MEMORY;
1648 
1649   mp_int_zero(z);
1650 
1651   unsigned char *tmp = buf;
1652   for (int i = len; i > 0; --i, ++tmp) {
1653     (void)s_qmul(z, CHAR_BIT);
1654     *MP_DIGITS(z) |= *tmp;
1655   }
1656 
1657   return MP_OK;
1658 }
1659 
mp_int_unsigned_len(mp_int z)1660 mp_result mp_int_unsigned_len(mp_int z) {
1661   mp_result res = mp_int_count_bits(z);
1662   if (res <= 0) return res;
1663 
1664   int bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1665   return bytes;
1666 }
1667 
/* Maps a result code to its message.  Codes are zero or negative, and the
   negated code is the index into the NULL-terminated s_error_msg table.
   Positive codes and codes past the end of the table map to s_unknown_err. */
const char *mp_error_string(mp_result res) {
  if (res > 0) return s_unknown_err;

  /* Walk toward the wanted entry without running past the terminator. */
  const int wanted = -res;
  int ix = 0;
  while (ix < wanted && s_error_msg[ix] != NULL) {
    ++ix;
  }

  return (s_error_msg[ix] != NULL) ? s_error_msg[ix] : s_unknown_err;
}
1682 
1683 /*------------------------------------------------------------------------*/
1684 /* Private functions for internal use.  These make assumptions.           */
1685 
#if DEBUG
/* Recognisable bit pattern that s_alloc and s_realloc write into every digit
   of a newly allocated buffer in DEBUG builds. */
static const mp_digit fill = (mp_digit)0xdeadbeefabad1dea;
#endif
1689 
s_alloc(mp_size num)1690 static mp_digit *s_alloc(mp_size num) {
1691   mp_digit *out = malloc(num * sizeof(mp_digit));
1692   assert(out != NULL);
1693 
1694 #if DEBUG
1695   for (mp_size ix = 0; ix < num; ++ix) out[ix] = fill;
1696 #endif
1697   return out;
1698 }
1699 
/* Resizes a digit buffer from osize to nsize digits and returns the new
   buffer, which replaces old.  An allocation failure trips the assertion;
   when assertions are compiled out NULL is returned and old is left intact.

   DEBUG builds always move the data to a fresh buffer whose unused digits
   hold the fill pattern; other builds call realloc.
 */
static mp_digit *s_realloc(mp_digit *old, mp_size osize, mp_size nsize) {
#if DEBUG
  /* s_alloc has already written the fill pattern into every digit. */
  mp_digit *new = s_alloc(nsize);
  assert(new != NULL);
  if (new == NULL) return NULL;

  if (old != NULL) {
    /* Never copy more than the new buffer can hold. */
    mp_size keep = (osize < nsize) ? osize : nsize;
    memcpy(new, old, keep * sizeof(mp_digit));

    /* As with realloc, the old buffer belongs to this function once the
       resize has succeeded; without this every DEBUG resize leaks it. */
    free(old);
  }
#else
  mp_digit *new = realloc(old, nsize * sizeof(mp_digit));
  assert(new != NULL);
#endif

  return new;
}
1714 
/* Releases a buffer through the standard free(). */
static void s_free(void *ptr) {
  free(ptr);
}
1716 
/* Makes sure z has room for at least min digits, enlarging its storage when
   needed.  Returns false only when the allocation fails, in which case z is
   left as it was. */
static bool s_pad(mp_int z, mp_size min) {
  /* Already big enough: nothing to do. */
  if (MP_ALLOC(z) >= min) return true;

  const mp_size nsize = s_round_prec(min);
  mp_digit *grown;

  if (z->digits == &(z->single)) {
    /* The value lives in the built-in single digit, which cannot be
       resized: allocate a buffer and carry the digit over. */
    grown = s_alloc(nsize);
    if (grown == NULL) return false;
    grown[0] = z->single;
  } else {
    grown = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize);
    if (grown == NULL) return false;
  }

  z->digits = grown;
  z->alloc = nsize;
  return true;
}
1735 
1736 /* Note: This will not work correctly when value == MP_SMALL_MIN */
s_fake(mp_int z,mp_small value,mp_digit vbuf[])1737 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]) {
1738   mp_usmall uv = (mp_usmall)(value < 0) ? -value : value;
1739   s_ufake(z, uv, vbuf);
1740   if (value < 0) z->sign = MP_NEG;
1741 }
1742 
/* Sets up z as a non-negative view of the small unsigned value, using vbuf as
   its digit storage.  The digits are written by s_uvpack. */
static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]) {
  const mp_size digits_used = (mp_size)s_uvpack(value, vbuf);

  z->digits = vbuf;
  z->used = digits_used;
  z->alloc = MP_VALUE_DIGITS(value);
  z->sign = MP_ZPOS;
}
1751 
/* Compares two digit arrays of the same length, starting from the highest
   index.  Returns 1, -1 or 0 at the first difference found, or 0 when all
   digits match. */
static int s_cdig(mp_digit *da, mp_digit *db, mp_size len) {
  for (mp_size ix = len; ix-- > 0;) {
    if (da[ix] != db[ix]) {
      return (da[ix] > db[ix]) ? 1 : -1;
    }
  }

  return 0;
}
1765 
/* Splits uv into digits, least significant first, stores them in t and
   returns how many were written.  Zero is written as a single 0 digit. */
static int s_uvpack(mp_usmall uv, mp_digit t[]) {
  int count = 0;

  /* At least one digit is always produced, which covers uv == 0. */
  do {
    t[count++] = (mp_digit)uv;

    /* Shift out one digit in two halves, so that the shift count is never
       the full width of uv. */
    uv >>= MP_DIGIT_BIT / 2;
    uv >>= MP_DIGIT_BIT / 2;
  } while (uv != 0);

  return count;
}
1781 
/* Compares the magnitudes of a and b: the value with more digits in use is
   the larger; with equal counts the digits decide. */
static int s_ucmp(mp_int a, mp_int b) {
  const mp_size used_a = MP_USED(a);
  const mp_size used_b = MP_USED(b);

  if (used_a != used_b) {
    return (used_a > used_b) ? 1 : -1;
  }

  return s_cdig(MP_DIGITS(a), MP_DIGITS(b), used_a);
}
1793 
/* Compares the magnitude of a with the magnitude of the small value v. */
static int s_vcmp(mp_int a, mp_small v) {
  /* Convert first, then negate in unsigned arithmetic. */
  mp_usmall magnitude = (mp_usmall)v;
  if (v < 0) {
    magnitude = -magnitude;
  }

  return s_uvcmp(a, magnitude);
}
1798 
/* Compares the magnitude of a with the unsigned small value uv, by wrapping
   uv in a stack-backed mpz_t and handing both to s_ucmp. */
static int s_uvcmp(mp_int a, mp_usmall uv) {
  mp_digit digit_store[MP_VALUE_DIGITS(uv)];
  mpz_t wrapped;

  s_ufake(&wrapped, uv, digit_store);
  return s_ucmp(a, &wrapped);
}
1806 
/* Adds the digit arrays da (size_a digits) and db (size_b digits), writing
   max(size_a, size_b) digits to dc, and returns the carry out of the top
   digit.  Each position is read before it is written, so dc may be the same
   array as da or db. */
static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                       mp_size size_b) {
  /* Arrange for da to be the longer operand. */
  if (size_b > size_a) {
    SWAP(mp_digit *, da, db);
    SWAP(mp_size, size_a, size_b);
  }

  mp_word carry = 0;
  mp_size ix = 0;

  /* Positions where both operands have a digit. */
  for (/* */; ix < size_b; ++ix) {
    carry = carry + (mp_word)da[ix] + (mp_word)db[ix];
    dc[ix] = LOWER_HALF(carry);
    carry = UPPER_HALF(carry);
  }

  /* Remaining digits of the longer operand, still carrying. */
  for (/* */; ix < size_a; ++ix) {
    carry = carry + da[ix];
    dc[ix] = LOWER_HALF(carry);
    carry = UPPER_HALF(carry);
  }

  return (mp_digit)carry;
}
1836 
/* Subtracts the digit array db (size_b digits) from da (size_a digits),
   writing size_a digits to dc.  The value in da must not be smaller than the
   value in db; both the length and the final borrow are asserted.  Each
   position is read before it is written, so dc may be the same array as
   da. */
static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                   mp_size size_b) {
  /* |a| >= |b| is assumed, so a cannot be the shorter operand. */
  assert(size_a >= size_b);

  /* Adding the radix up front keeps every intermediate non-negative; an
     empty upper half afterwards means a borrow was needed. */
  const mp_word radix = (mp_word)MP_DIGIT_MAX + 1;
  mp_word borrow = 0;
  mp_size ix = 0;

  /* Positions where both operands have a digit. */
  for (/* */; ix < size_b; ++ix) {
    const mp_word diff = (radix + (mp_word)da[ix]) - borrow - (mp_word)db[ix];

    dc[ix] = LOWER_HALF(diff);
    borrow = (UPPER_HALF(diff) == 0);
  }

  /* Remaining upper digits of da, still passing the borrow along. */
  for (/* */; ix < size_a; ++ix) {
    const mp_word diff = (radix + (mp_word)da[ix]) - borrow;

    dc[ix] = LOWER_HALF(diff);
    borrow = (UPPER_HALF(diff) == 0);
  }

  /* A borrow left over here means the precondition was violated. */
  assert(borrow == 0);
}
1868 
/* Multiplies the digit arrays da (size_a digits) and db (size_b digits) into
   dc.  Large operands are split and multiplied recursively (Karatsuba);
   everything else goes to s_umul.

   Returns 1 on success and 0 when a temporary buffer could not be allocated,
   whether at this level or in a recursive call.  On failure dc has not been
   written by this level.
 */
static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                  mp_size size_b) {
  mp_size bot_size;

  /* Make sure b is the smaller of the two input values */
  if (size_b > size_a) {
    SWAP(mp_digit *, da, db);
    SWAP(mp_size, size_a, size_b);
  }

  /* Insure that the bottom is the larger half in an odd-length split; the code
     below relies on this being true.
   */
  bot_size = (size_a + 1) / 2;

  /* If the values are big enough to bother with recursion, use the Karatsuba
     algorithm to compute the product; otherwise use the normal multiplication
     algorithm
   */
  if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size) {
    mp_digit *t1, *t2, *t3, carry;

    mp_digit *a_top = da + bot_size;
    mp_digit *b_top = db + bot_size;

    mp_size at_size = size_a - bot_size;
    mp_size bt_size = size_b - bot_size;
    mp_size buf_size = 2 * bot_size;

    /* Do a single allocation for all three temporary buffers needed; each
       buffer must be big enough to hold the product of two bottom halves, and
       one buffer needs space for the completed product; twice the space is
       plenty.
     */
    if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
    t2 = t1 + buf_size;
    t3 = t2 + buf_size;
    ZERO(t1, 4 * buf_size);

    /* t1 and t2 are initially used as temporaries to compute the inner product
       (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
     */
    carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */
    t1[bot_size] = carry;

    carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */
    t2[bot_size] = carry;

    /* A failure further down must not be mistaken for a finished product, so
       every recursive result is checked and passed up. */
    if (!s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1)) { /* t3 = t1 * t2 */
      s_free(t1);
      return 0;
    }

    /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that
       we're left with only the pieces we want:  t3 = a1b0 + a0b1
     */
    ZERO(t1, buf_size);
    ZERO(t2, buf_size);
    if (!s_kmul(da, db, t1, bot_size, bot_size) ||     /* t1 = a0 * b0 */
        !s_kmul(a_top, b_top, t2, at_size, bt_size)) { /* t2 = a1 * b1 */
      s_free(t1);
      return 0;
    }

    /* Subtract out t1 and t2 to get the inner product */
    s_usub(t3, t1, t3, buf_size + 2, buf_size);
    s_usub(t3, t2, t3, buf_size + 2, buf_size);

    /* Assemble the output value */
    COPY(t1, dc, buf_size);
    carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
    assert(carry == 0);

    carry =
        s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
    assert(carry == 0);

    s_free(t1); /* note t2 and t3 are just internal pointers to t1 */
  } else {
    s_umul(da, db, dc, size_a, size_b);
  }

  return 1;
}
1947 
/* Schoolbook multiplication of da (size_a digits) by db (size_b digits).
   Each partial product is added to what dc already holds, so the result is
   the product only when dc starts out zeroed. */
static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
                   mp_size size_b) {
  for (mp_size ia = 0; ia < size_a; ++ia) {
    /* A zero digit contributes nothing to the product. */
    if (da[ia] == 0) continue;

    /* This row of the product starts ia digits up in the output. */
    mp_digit *row = dc + ia;
    mp_word carry = 0;

    for (mp_size ib = 0; ib < size_b; ++ib) {
      carry = (mp_word)da[ia] * (mp_word)db[ib] + carry + (mp_word)row[ib];

      row[ib] = LOWER_HALF(carry);
      carry = UPPER_HALF(carry);
    }

    row[size_b] = (mp_digit)carry;
  }
}
1970 
/* Squares the digit array da (size_a digits) into dc.  Large operands are
   split as a = a1 * R + a0 and squared recursively using
   a^2 = a1^2 * R^2 + 2 * a0 * a1 * R + a0^2; everything else goes to s_usqr.

   Returns 1 on success and 0 when a temporary buffer could not be allocated,
   whether at this level or in a recursive call.  On failure dc has not been
   written by this level.
 */
static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) {
  if (multiply_threshold && size_a > multiply_threshold) {
    mp_size bot_size = (size_a + 1) / 2;
    mp_digit *a_top = da + bot_size;
    mp_digit *t1, *t2, *t3, carry;
    mp_size at_size = size_a - bot_size;
    mp_size buf_size = 2 * bot_size;

    /* One allocation serves all three temporaries. */
    if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
    t2 = t1 + buf_size;
    t3 = t2 + buf_size;
    ZERO(t1, 4 * buf_size);

    /* A failure further down must not be mistaken for a finished square, so
       every recursive result is checked and passed up. */
    if (!s_ksqr(da, t1, bot_size) ||                 /* t1 = a0 ^ 2 */
        !s_ksqr(a_top, t2, at_size) ||               /* t2 = a1 ^ 2 */
        !s_kmul(da, a_top, t3, bot_size, at_size)) { /* t3 = a0 * a1 */
      s_free(t1);
      return 0;
    }

    /* Quick multiply t3 by 2, shifting left (can't overflow) */
    {
      int i, top = bot_size + at_size;
      mp_word w, save = 0;

      for (i = 0; i < top; ++i) {
        w = t3[i];
        w = (w << 1) | save;
        t3[i] = LOWER_HALF(w);
        save = UPPER_HALF(w);
      }
      t3[i] = LOWER_HALF(save);
    }

    /* Assemble the output value */
    COPY(t1, dc, 2 * bot_size);
    carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
    assert(carry == 0);

    carry =
        s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
    assert(carry == 0);

    s_free(t1); /* note that t2 and t3 are internal pointers only */

  } else {
    s_usqr(da, dc, size_a);
  }

  return 1;
}
2020 
/* Schoolbook squaring of the size_a-digit value at da, accumulated into dc.

   For each digit a_i the square a_i^2 is added at output position 2i, and
   each cross product a_i * a_j (j > i) is added twice at position i + j.
   Existing contents of dc are added into the result, so the caller
   presumably supplies a zeroed buffer -- TODO confirm against callers.
 */
static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a) {
  mp_size i, j;
  mp_word w;

  /* dc advances two digits per input digit, so it tracks position 2i */
  for (i = 0; i < size_a; ++i, dc += 2, ++da) {
    mp_digit *dct = dc, *dat = da;

    /* A zero digit contributes nothing to any term of this row */
    if (*da == 0) continue;

    /* Take care of the first digit, no rollover */
    w = (mp_word)*dat * (mp_word)*dat + (mp_word)*dct;
    *dct = LOWER_HALF(w);
    w = UPPER_HALF(w);
    ++dat;
    ++dct;

    for (j = i + 1; j < size_a; ++j, ++dat, ++dct) {
      mp_word t = (mp_word)*da * (mp_word)*dat;
      mp_word u = w + (mp_word)*dct, ov = 0;

      /* Check if doubling t will overflow a word */
      if (HIGH_BIT_SET(t)) ov = 1;

      w = t + t;

      /* Check if adding u to w will overflow a word */
      if (ADD_WILL_OVERFLOW(w, u)) ov = 1;

      w += u;

      *dct = LOWER_HALF(w);
      w = UPPER_HALF(w);
      if (ov) {
        /* Restore the bit lost to word overflow: it sits one position above
           the upper half, i.e. it is worth MP_DIGIT_MAX + 1 in the carry. */
        w += MP_DIGIT_MAX; /* MP_RADIX */
        ++w;
      }
    }

    /* Fold the remaining carry into the following output digits until it
       has been fully absorbed */
    w = w + *dct;
    *dct = (mp_digit)w;
    while ((w = UPPER_HALF(w)) != 0) {
      ++dct;
      w = w + *dct;
      *dct = LOWER_HALF(w);
    }

    assert(w == 0);
  }
}
2070 
/* Add the single digit b to a in place (magnitude only).

   The carry is rippled through every used digit; if a carry remains after
   the top digit it is stored in the next slot and a->used grows by one.  No
   allocation happens here, so that slot must already exist.
 */
static void s_dadd(mp_int a, mp_digit b) {
  mp_digit *cur = MP_DIGITS(a);
  mp_size left = MP_USED(a) - 1; /* digits remaining after the lowest one */
  mp_word acc = (mp_word)*cur + b;

  *cur++ = LOWER_HALF(acc);
  acc = UPPER_HALF(acc);

  while (left > 0) {
    acc = (mp_word)*cur + acc;

    *cur = LOWER_HALF(acc);
    acc = UPPER_HALF(acc);

    ++cur;
    --left;
  }

  if (acc != 0) {
    *cur = (mp_digit)acc;
    a->used += 1;
  }
}
2092 
/* Multiply a in place by the single digit b (magnitude only).

   A nonzero carry out of the top digit is stored in the next slot and
   a->used grows by one.  No allocation happens here, so that slot must
   already exist.
 */
static void s_dmul(mp_int a, mp_digit b) {
  mp_digit *cur = MP_DIGITS(a);
  mp_word carry = 0;

  for (mp_size left = MP_USED(a); left > 0; --left, ++cur) {
    mp_word prod = (mp_word)*cur * b + carry;

    *cur = LOWER_HALF(prod);
    carry = UPPER_HALF(prod);
  }

  if (carry != 0) {
    *cur = (mp_digit)carry;
    a->used += 1;
  }
}
2110 
/* Multiply the size_a-digit array da by the single digit b, writing the
   product to dc.

   dc[0..size_a) is always written.  dc[size_a] is written only when the
   final carry is nonzero, so it keeps its previous value otherwise.
 */
static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a) {
  mp_word carry = 0;
  mp_size i;

  for (i = 0; i < size_a; ++i) {
    mp_word prod = (mp_word)da[i] * (mp_word)b + carry;

    dc[i] = LOWER_HALF(prod);
    carry = UPPER_HALF(prod);
  }

  if (carry != 0) dc[i] = LOWER_HALF(carry);
}
2124 
/* Divide a in place by the single digit b.

   Works from the most significant digit downward, carrying the running
   remainder into the next lower digit.  The quotient replaces a (clamped);
   the remainder is returned.
 */
static mp_digit s_ddiv(mp_int a, mp_digit b) {
  mp_digit *digits = MP_DIGITS(a);
  mp_word rem = 0;

  for (mp_size i = MP_USED(a); i > 0; --i) {
    mp_word quot = 0;

    rem = (rem << MP_DIGIT_BIT) | digits[i - 1];

    if (rem >= b) {
      quot = rem / b;
      rem = rem % b;
    }

    digits[i - 1] = (mp_digit)quot;
  }

  CLAMP(a);
  return (mp_digit)rem;
}
2146 
/* Shift z right by p2 bits in place, i.e. divide its magnitude by 2^p2.

   Whole digits are dropped first, then the remaining sub-digit shift is
   applied from the top digit downward.  If the result is a single zero
   digit the sign is reset to MP_ZPOS.
 */
static void s_qdiv(mp_int z, mp_size p2) {
  const mp_size digit_shift = p2 / MP_DIGIT_BIT;
  const mp_size bit_shift = p2 % MP_DIGIT_BIT;
  const mp_size used = MP_USED(z);

  if (digit_shift != 0) {
    mp_digit *digits;
    mp_size i;

    /* Everything is shifted out */
    if (digit_shift >= used) {
      mp_int_zero(z);
      return;
    }

    digits = MP_DIGITS(z);
    for (i = digit_shift; i < used; ++i) {
      digits[i - digit_shift] = digits[i];
    }

    z->used = used - digit_shift;
  }

  if (bit_shift != 0) {
    const mp_size up = MP_DIGIT_BIT - bit_shift;
    mp_digit *digits = MP_DIGITS(z);
    mp_digit above = 0; /* unshifted value of the next higher digit */

    for (mp_size i = MP_USED(z); i > 0; --i) {
      const mp_digit cur = digits[i - 1];

      digits[i - 1] = (cur >> bit_shift) | (above << up);
      above = cur;
    }

    CLAMP(z);
  }

  if (MP_USED(z) == 1 && z->digits[0] == 0) z->sign = MP_ZPOS;
}
2189 
/* Reduce z in place to its low p2 bits, i.e. z mod 2^p2 on the magnitude.

   If z has fewer digits than the digit containing bit p2, it is already in
   range and is left untouched.
 */
static void s_qmod(mp_int z, mp_size p2) {
  const mp_size keep = p2 / MP_DIGIT_BIT + 1; /* digits to retain */
  const mp_size rest = p2 % MP_DIGIT_BIT;
  const mp_size used = MP_USED(z);
  const mp_digit mask = (1u << rest) - 1;

  if (keep > used) return;

  z->used = keep;
  z->digits[keep - 1] &= mask;
  CLAMP(z);
}
2201 
/* Shift z left by p2 bits in place, i.e. multiply its magnitude by 2^p2.

   Returns 1 on success, 0 if z could not be enlarged via s_pad().
 */
static int s_qmul(mp_int z, mp_size p2) {
  mp_size used, digit_shift, bit_shift, grow = 0, i;

  if (p2 == 0) return 1;

  used = MP_USED(z);
  digit_shift = p2 / MP_DIGIT_BIT;
  bit_shift = p2 % MP_DIGIT_BIT;

  /* One more digit is needed at the top when the high `bit_shift' bits of
     the current top digit are not all zero: they would otherwise be shifted
     off the end. */
  if (bit_shift != 0) {
    mp_digit *top = MP_DIGITS(z) + used - 1;

    if ((*top >> (MP_DIGIT_BIT - bit_shift)) != 0) grow = 1;
  }

  if (!s_pad(z, used + digit_shift + grow)) return 0;

  /* Whole-digit part: move digits upward starting from the top, then clear
     the vacated low digits. */
  if (digit_shift > 0) {
    mp_digit *digits = MP_DIGITS(z);

    for (i = used; i > 0; --i) digits[i - 1 + digit_shift] = digits[i - 1];

    ZERO(digits, digit_shift);
    used += digit_shift;
  }

  /* Sub-digit part: shift each digit, pulling in the bits that fell out of
     the digit below it. */
  if (bit_shift != 0) {
    const mp_size down = MP_DIGIT_BIT - bit_shift;
    mp_digit *digits = MP_DIGITS(z);
    mp_digit below = 0; /* unshifted value of the next lower digit */

    for (i = digit_shift; i < used; ++i) {
      const mp_digit cur = digits[i];

      digits[i] = (cur << bit_shift) | (below >> down);
      below = cur;
    }

    below >>= down;
    if (below != 0) {
      digits[used] = below;
      used += grow;
    }
  }

  z->used = used;
  CLAMP(z);

  return 1;
}
2258 
2259 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2260    The sign of the result is always zero/positive.
2261  */
s_qsub(mp_int z,mp_size p2)2262 static int s_qsub(mp_int z, mp_size p2) {
2263   mp_digit hi = (1u << (p2 % MP_DIGIT_BIT)), *zp;
2264   mp_size tdig = (p2 / MP_DIGIT_BIT), pos;
2265   mp_word w = 0;
2266 
2267   if (!s_pad(z, tdig + 1)) return 0;
2268 
2269   for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) {
2270     w = ((mp_word)MP_DIGIT_MAX + 1) - w - (mp_word)*zp;
2271 
2272     *zp = LOWER_HALF(w);
2273     w = UPPER_HALF(w) ? 0 : 1;
2274   }
2275 
2276   w = ((mp_word)MP_DIGIT_MAX + 1 + hi) - w - (mp_word)*zp;
2277   *zp = LOWER_HALF(w);
2278 
2279   assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2280 
2281   z->sign = MP_ZPOS;
2282   CLAMP(z);
2283 
2284   return 1;
2285 }
2286 
/* Count the trailing zero bits of z, i.e. the largest k with 2^k dividing z.

   A value consisting of a single zero digit yields 1.
 */
static int s_dp2k(mp_int z) {
  mp_digit *dp = MP_DIGITS(z);
  mp_digit low;
  int shift = 0;

  if (MP_USED(z) == 1 && *dp == 0) return 1;

  /* Whole zero digits at the bottom */
  for (; *dp == 0; ++dp) shift += MP_DIGIT_BIT;

  /* Zero bits within the first nonzero digit */
  for (low = *dp; (low & 1) == 0; low >>= 1) ++shift;

  return shift;
}
2306 
/* Test whether the magnitude of z is a power of two.

   Returns k when every digit below the top one is zero and the top digit
   reduces to 0 or 1 by shifting out only zero bits; returns -1 otherwise.
 */
static int s_isp2(mp_int z) {
  mp_digit *dz = MP_DIGITS(z);
  mp_digit top;
  mp_size k = 0, left;

  /* Every digit below the top one must be zero */
  for (left = MP_USED(z); left > 1; --left, ++dz) {
    if (*dz != 0) return -1;
    k += MP_DIGIT_BIT;
  }

  /* Shift the top digit down; any set bit seen before reaching 1 (or 0)
     disqualifies the value */
  for (top = *dz; top > 1; top >>= 1) {
    if (top & 1) return -1;
    ++k;
  }

  return (int)k;
}
2326 
/* Set the digits of z to 2^k.

   Returns 1 on success, 0 if z could not be padded to the required size.
   Only the digits and the used count are written; the sign field is not
   touched here.
 */
static int s_2expt(mp_int z, mp_small k) {
  const mp_size ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
  const mp_size rest = k % MP_DIGIT_BIT;
  mp_digit *dz;

  if (!s_pad(z, ndig)) return 0;

  dz = MP_DIGITS(z);
  ZERO(dz, ndig);
  dz[ndig - 1] = (1u << rest); /* the single set bit lives in the top digit */
  z->used = ndig;

  return 1;
}
2343 
/* Normalize a divisor/dividend pair for long division.

   Finds the shift k that brings the top digit of b up to at least half the
   radix, shifts both a and b left by k bits, and returns k.  The return
   values of the two shifts are not checked.
 */
static int s_norm(mp_int a, mp_int b) {
  const unsigned int half_radix = (1u << (mp_digit)(MP_DIGIT_BIT - 1));
  mp_digit top = b->digits[MP_USED(b) - 1];
  int shift;

  for (shift = 0; top < half_radix; ++shift) top <<= 1;

  if (shift == 0) return 0;

  /* These multiplications can't fail */
  (void)s_qmul(a, (mp_size)shift);
  (void)s_qmul(b, (mp_size)shift);

  return shift;
}
2361 
s_brmu(mp_int z,mp_int m)2362 static mp_result s_brmu(mp_int z, mp_int m) {
2363   mp_size um = MP_USED(m) * 2;
2364 
2365   if (!s_pad(z, um)) return MP_MEMORY;
2366 
2367   s_2expt(z, MP_DIGIT_BIT * um);
2368   return mp_int_div(z, m, z, NULL);
2369 }
2370 
/* Barrett reduction: replace x with x mod m, using the precomputed constant
   mu for m.  q1 and q2 are caller-supplied scratch values.

   Returns 1 on success, 0 if copying x into q1 or the corrective s_qsub()
   fails.
 */
static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2) {
  const mp_size k = MP_USED(m);
  const mp_size hi_bits = (k + 1) * MP_DIGIT_BIT; /* bits in b^(k+1) */
  const mp_size lo_bits = (k - 1) * MP_DIGIT_BIT; /* bits in b^(k-1) */

  if (mp_int_copy(x, q1) != MP_OK) return 0;

  /* q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
  s_qdiv(q1, lo_bits);
  UMUL(q1, mu, q2);
  s_qdiv(q2, hi_bits);

  /* x = x mod b^(k+1) */
  s_qmod(x, hi_bits);

  /* q2 estimates the quotient x / m.  Subtract q2 * m, working modulo
     b^(k+1); the difference may still exceed the true residue by up to 2m. */
  UMUL(q2, m, q1);
  s_qmod(q1, hi_bits);
  (void)mp_int_sub(x, q1, x); /* can't fail */

  /* A negative difference is brought back into range by adding b^(k+1) */
  if (CMPZ(x) < 0) {
    if (!s_qsub(x, hi_bits)) return 0;
  }

  /* Back x off by m while it is still >= m; at most two passes are needed */
  for (int pass = 0; pass < 2 && mp_int_compare(x, m) >= 0; ++pass) {
    (void)mp_int_sub(x, m, x);
  }

  /* At this point, x has been properly reduced. */
  return 1;
}
2411 
2412 /* Perform modular exponentiation using Barrett's method, where mu is the
2413    reduction constant for m.  Assumes a < m, b > 0. */
s_embar(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)2414 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c) {
2415   mp_digit umu = MP_USED(mu);
2416   mp_digit *db = MP_DIGITS(b);
2417   mp_digit *dbt = db + MP_USED(b) - 1;
2418 
2419   DECLARE_TEMP(3);
2420   REQUIRE(GROW(TEMP(0), 4 * umu));
2421   REQUIRE(GROW(TEMP(1), 4 * umu));
2422   REQUIRE(GROW(TEMP(2), 4 * umu));
2423   ZERO(TEMP(0)->digits, TEMP(0)->alloc);
2424   ZERO(TEMP(1)->digits, TEMP(1)->alloc);
2425   ZERO(TEMP(2)->digits, TEMP(2)->alloc);
2426 
2427   (void)mp_int_set_value(c, 1);
2428 
2429   /* Take care of low-order digits */
2430   while (db < dbt) {
2431     mp_digit d = *db;
2432 
2433     for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) {
2434       if (d & 1) {
2435         /* The use of a second temporary avoids allocation */
2436         UMUL(c, a, TEMP(0));
2437         if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2438           REQUIRE(MP_MEMORY);
2439         }
2440         mp_int_copy(TEMP(0), c);
2441       }
2442 
2443       USQR(a, TEMP(0));
2444       assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2445       if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2446         REQUIRE(MP_MEMORY);
2447       }
2448       assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2449       mp_int_copy(TEMP(0), a);
2450     }
2451 
2452     ++db;
2453   }
2454 
2455   /* Take care of highest-order digit */
2456   mp_digit d = *dbt;
2457   for (;;) {
2458     if (d & 1) {
2459       UMUL(c, a, TEMP(0));
2460       if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2461         REQUIRE(MP_MEMORY);
2462       }
2463       mp_int_copy(TEMP(0), c);
2464     }
2465 
2466     d >>= 1;
2467     if (!d) break;
2468 
2469     USQR(a, TEMP(0));
2470     if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2471       REQUIRE(MP_MEMORY);
2472     }
2473     (void)mp_int_copy(TEMP(0), a);
2474   }
2475 
2476   CLEANUP_TEMP();
2477   return MP_OK;
2478 }
2479 
/* Division of nonnegative integers

   This function implements the division algorithm for unsigned
   multi-precision integers. The algorithm is based on Algorithm D from
   Knuth's "The Art of Computer Programming", 3rd ed. 1998, pg 272-273.

   We diverge from Knuth's algorithm in that we do not perform the subtraction
   from the remainder until we have determined that we have the correct
   quotient digit. This makes our algorithm less efficient than Knuth's
   because we might have to perform multiple multiplication and comparison
   steps before the subtraction. The advantage is that it is easy to implement
   and ensure correctness without worrying about underflow from the
   subtraction.

   inputs: u   a n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
           v   a n   digit integer in base b (b is 2^MP_DIGIT_BIT)
           n >= 1
           m >= 0
  outputs: u / v stored in u
           u % v stored in v

   Both arguments are overwritten and their signs are forced to MP_ZPOS.
   Returns MP_OK on success, or MP_MEMORY if u, v, or the temporaries cannot
   be enlarged.
 */
static mp_result s_udiv_knuth(mp_int u, mp_int v) {
  /* Force signs to positive */
  u->sign = MP_ZPOS;
  v->sign = MP_ZPOS;

  /* Use simple division algorithm when v is only one digit long */
  if (MP_USED(v) == 1) {
    mp_digit d, rem;
    d = v->digits[0];
    rem = s_ddiv(u, d);
    mp_int_set_value(v, rem);
    return MP_OK;
  }

  /* Algorithm D

     The n and m variables are defined as used by Knuth.
     v is an n digit number with digits v_{n-1}..v_0.
     u is an n+m digit number with digits from u_{m+n-1}..u_0.
     We require that n > 1 and m >= 0
   */
  mp_size n = MP_USED(v);
  mp_size m = MP_USED(u) - n;
  assert(n > 1);
  /* assert(m >= 0) follows because m is unsigned.
     NOTE(review): if u had fewer digits than v the subtraction above would
     wrap around; presumably callers rule that out -- confirm. */

  /* D1: Normalize.
     The normalization step provides the necessary condition for Theorem B,
     which states that the quotient estimate for q_j, call it qhat

       qhat = u_{j+n}u_{j+n-1} / v_{n-1}

     is bounded by

      qhat - 2 <= q_j <= qhat.

     That is, qhat is never less than the actual quotient digit q,
     and it is never more than two larger than the actual quotient digit.
   */
  int k = s_norm(u, v);

  /* Extend size of u by one if needed.

     The algorithm begins with a value of u that has one more digit of input.
     The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. If the
     multiplication did not increase the number of digits of u, we need to add
     a leading zero here.
   */
  if (k == 0 || MP_USED(u) != m + n + 1) {
    if (!s_pad(u, m + n + 1)) return MP_MEMORY;
    u->digits[m + n] = 0;
    u->used = m + n + 1;
  }

  /* Add a leading 0 to v.

     The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0.  We need to
     add the leading zero to v here to ensure that the multiplication will
     produce the full n+1 digit result.
   */
  if (!s_pad(v, n + 1)) return MP_MEMORY;
  v->digits[n] = 0;

  /* Initialize temporary variables q and t.
     q (TEMP(0)) allocates space for m+1 digits to store the quotient digits
     t (TEMP(1)) allocates space for n+1 digits to hold the result of q_j*v
   */
  DECLARE_TEMP(2);
  REQUIRE(GROW(TEMP(0), m + 1));
  REQUIRE(GROW(TEMP(1), n + 1));

  /* D2: Initialize j */
  int j = m;
  mpz_t r;
  r.digits = MP_DIGITS(u) + j; /* The contents of r are shared with u */
  r.used = n + 1;
  r.sign = MP_ZPOS;
  r.alloc = MP_ALLOC(u);
  ZERO(TEMP(1)->digits, TEMP(1)->alloc);

  /* Calculate the m+1 digits of the quotient result */
  for (; j >= 0; j--) {
    /* D3: Calculate q' */
    /* r->digits is aligned to position j of the number u */
    mp_word pfx, qhat;
    pfx = r.digits[n];
    /* The shift by one full digit is performed as two half-digit shifts */
    pfx <<= MP_DIGIT_BIT / 2;
    pfx <<= MP_DIGIT_BIT / 2;
    pfx |= r.digits[n - 1]; /* pfx = u_{j+n}u_{j+n-1} */

    qhat = pfx / v->digits[n - 1];
    /* Check to see if qhat >= b, and clamp qhat to b - 1 if so.
       Theorem B guarantees that qhat is at most 2 larger than the
       actual value, so it is possible that qhat is greater than
       the maximum value that will fit in a digit */
    if (qhat > MP_DIGIT_MAX) qhat = MP_DIGIT_MAX;

    /* D4,D5,D6: Multiply qhat * v and test for a correct value of q

       We proceed a bit differently than the way described by Knuth. This way
       is simpler but less efficient. Instead of doing the multiply and
       subtract then checking for underflow, we first do the multiply of
       qhat * v and see if it is larger than the current remainder r. If it
       is larger, we decrease qhat by one and try again. We may need to
       decrease qhat one more time before we get a value that is not larger
       than r.

       This way is less efficient than Knuth's because we do more multiplies,
       but we do not need to worry about underflow this way.
     */
    /* t = qhat * v */
    s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
    TEMP(1)->used = n + 1;
    CLAMP(TEMP(1));

    /* Clamp r for the comparison. Comparisons do not like leading zeros. */
    CLAMP(&r);
    if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */
      qhat -= 1;                   /* try a smaller q */
      s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
      TEMP(1)->used = n + 1;
      CLAMP(TEMP(1));
      if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */
        assert(qhat > 0);
        qhat -= 1; /* try a smaller q */
        s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
        TEMP(1)->used = n + 1;
        CLAMP(TEMP(1));
      }
      assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us.");
    }
    /* Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be n+1
       digits long. */
    r.used = n + 1;

    /* D4: Multiply and subtract

       Note: The multiply was completed above so we only need to subtract here.
     */
    s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used);

    /* D5: Test remainder

       Note: Not needed because we always check that qhat is the correct value
             before performing the subtract.  Value cast to mp_digit to prevent
             warning, qhat has been clamped to MP_DIGIT_MAX
     */
    TEMP(0)->digits[j] = (mp_digit)qhat;

    /* D6: Add back
       Note: Not needed because we always check that qhat is the correct value
             before performing the subtract.
     */

    /* D7: Loop on j */
    r.digits--; /* slide the remainder window one digit toward u_0 */
    ZERO(TEMP(1)->digits, TEMP(1)->alloc); /* reset the scratch product */
  }

  /* Get rid of leading zeros in q */
  TEMP(0)->used = m + 1;
  CLAMP(TEMP(0));

  /* Denormalize the remainder */
  CLAMP(u); /* use u here because the r.digits pointer is off-by-one */
  if (k != 0) s_qdiv(u, k);

  mp_int_copy(u, v);       /* ok:  0 <= r < v */
  mp_int_copy(TEMP(0), u); /* ok:  q <= u     */

  CLEANUP_TEMP();
  return MP_OK;
}
2672 
/* Estimate how many radix-r characters are needed to represent z.

   The bit count of z is scaled by the s_log2 table entry for r and rounded
   up.  r must lie within [MP_MIN_RADIX, MP_MAX_RADIX].
 */
static int s_outlen(mp_int z, mp_size r) {
  assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);

  const mp_result nbits = mp_int_count_bits(z);
  const double nchars = (double)nbits * s_log2[r];

  /* Adding just under 1 before truncation rounds up */
  return (int)(nchars + 0.999999);
}
2681 
/* Estimate how many digits are needed to hold a value written as len
   characters in radix r.

   The character count is converted to bits through the s_log2 table,
   rounded to the nearest bit, rounded up to whole digits, and one spare
   digit is added.
 */
static mp_size s_inlen(int len, mp_size r) {
  const double exact_bits = (double)len / s_log2[r];
  const mp_size bits = (mp_size)(exact_bits + 0.5);
  const mp_size whole_digits =
      (mp_size)((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT);

  return whole_digits + 1;
}
2688 
/* Convert the character c to its digit value in radix r.

   Decimal digits map to 0..9; when r > 10, letters map case-insensitively to
   10 and up.  Returns -1 if c is neither, or if its value is not below r.

   In some locales isalpha() accepts characters outside A-Z, which can make
   the computed value negative or >= 36.  The range check against r rejects
   the latter; a negative value is passed through unchanged, which is
   acceptable because the caller treats every negative result alike.
 */
static int s_ch2val(char c, int r) {
  const unsigned char uc = (unsigned char)c;
  int value;

  if (isdigit(uc)) {
    value = c - '0';
  } else if (r <= 10 || !isalpha(uc)) {
    return -1;
  } else {
    value = toupper(uc) - 'A' + 10;
  }

  return (value < r) ? value : -1;
}
2707 
/* Convert the digit value v (which must be >= 0) to its character.

   Values below 10 become '0'..'9'.  Larger values become letters starting at
   'a' for 10, upper-cased when caps is nonzero.
 */
static char s_val2ch(int v, int caps) {
  assert(v >= 0);

  if (v < 10) return v + '0';

  char letter = (v - 10) + 'a';

  return caps ? toupper((unsigned char)letter) : letter;
}
2723 
/* Replace the len-byte buffer buf with its two's complement in place.

   The last byte of buf is treated as least significant: every byte is
   inverted and 1 is added starting from buf[len - 1], rippling toward
   buf[0].  The carry out of buf[0] is discarded.
 */
static void s_2comp(unsigned char *buf, int len) {
  unsigned short carry = 1;

  for (int remaining = len; remaining > 0; --remaining) {
    unsigned char *byte = &buf[remaining - 1];
    unsigned char inverted = ~*byte;

    carry = inverted + carry;
    *byte = carry & UCHAR_MAX;
    carry >>= CHAR_BIT;
  }
}
2739 
s_tobin(mp_int z,unsigned char * buf,int * limpos,int pad)2740 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad) {
2741   int pos = 0, limit = *limpos;
2742   mp_size uz = MP_USED(z);
2743   mp_digit *dz = MP_DIGITS(z);
2744 
2745   while (uz > 0 && pos < limit) {
2746     mp_digit d = *dz++;
2747     int i;
2748 
2749     for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) {
2750       buf[pos++] = (unsigned char)d;
2751       d >>= CHAR_BIT;
2752 
2753       /* Don't write leading zeroes */
2754       if (d == 0 && uz == 1) i = 0; /* exit loop without signaling truncation */
2755     }
2756 
2757     /* Detect truncation (loop exited with pos >= limit) */
2758     if (i > 0) break;
2759 
2760     --uz;
2761   }
2762 
2763   if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) {
2764     if (pos < limit) {
2765       buf[pos++] = 0;
2766     } else {
2767       uz = 1;
2768     }
2769   }
2770 
2771   /* Digits are in reverse order, fix that */
2772   REV(buf, pos);
2773 
2774   /* Return the number of bytes actually written */
2775   *limpos = pos;
2776 
2777   return (uz == 0) ? MP_OK : MP_TRUNC;
2778 }
2779 
2780 /* Here there be dragons */
2781