1 /*
2  *  mpi.c
3  *
4  *  Arbitrary precision integer arithmetic library
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
9 
10 #include "mpi-priv.h"
11 #include "mplogic.h"
12 #if defined(OSF1)
13 #include <c_asm.h>
14 #endif
15 
16 #if defined(__arm__) && \
17     ((defined(__thumb__) && !defined(__thumb2__)) || defined(__ARM_ARCH_3__))
/* The inlined assembler version doesn't work on 16-bit Thumb or ARM v3 */
19 #undef MP_ASSEMBLY_MULTIPLY
20 #undef MP_ASSEMBLY_SQUARE
21 #endif
22 
23 #if MP_LOGTAB
24 /*
25   A table of the logs of 2 for various bases (the 0 and 1 entries of
26   this table are meaningless and should not be referenced).
27 
28   This table is used to compute output lengths for the mp_toradix()
29   function.  Since a number n in radix r takes up about log_r(n)
30   digits, we estimate the output size by taking the least integer
31   greater than log_r(n), where:
32 
33   log_r(n) = log_2(n) * log_r(2)
34 
35   This table, therefore, is a table of log_r(2) for 2 <= r <= 36,
36   which are the output bases supported.
37  */
38 #include "logtab.h"
39 #endif
40 
41 #ifdef CT_VERIF
42 #include <valgrind/memcheck.h>
43 #endif
44 
45 /* {{{ Constant strings */
46 
/*
  Constant strings returned by mp_strerror().

  Both the strings and the table of pointers are read-only, so the
  table itself is declared const as well as the characters it points to.
  The trailing comments name the result code each entry describes.
 */
static const char *const mp_err_string[] = {
    "unknown result code",     /* say what?            */
    "boolean true",            /* MP_OKAY, MP_YES      */
    "boolean false",           /* MP_NO                */
    "out of memory",           /* MP_MEM               */
    "argument out of range",   /* MP_RANGE             */
    "invalid input parameter", /* MP_BADARG            */
    "result is undefined"      /* MP_UNDEF             */
};
57 
58 /* Value to digit maps for radix conversion   */
59 
/*
  s_dmap_1 - standard digits and letters.

  Value-to-character map for radix conversion: the character at index v
  is the digit used for value v.  The pointer is never reseated, so it
  is const along with the characters it refers to.
 */
static const char *const s_dmap_1 =
    "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+/";
63 
64 /* }}} */
65 
66 /* {{{ Default precision manipulation */
67 
68 /* Default precision for newly created mp_int's      */
69 static mp_size s_mp_defprec = MP_DEFPREC;
70 
71 mp_size
mp_get_prec(void)72 mp_get_prec(void)
73 {
74     return s_mp_defprec;
75 
76 } /* end mp_get_prec() */
77 
78 void
mp_set_prec(mp_size prec)79 mp_set_prec(mp_size prec)
80 {
81     if (prec == 0)
82         s_mp_defprec = MP_DEFPREC;
83     else
84         s_mp_defprec = prec;
85 
86 } /* end mp_set_prec() */
87 
88 /* }}} */
89 
90 #ifdef CT_VERIF
91 void
mp_taint(mp_int * mp)92 mp_taint(mp_int *mp)
93 {
94     size_t i;
95     for (i = 0; i < mp->used; ++i) {
96         VALGRIND_MAKE_MEM_UNDEFINED(&(mp->dp[i]), sizeof(mp_digit));
97     }
98 }
99 
100 void
mp_untaint(mp_int * mp)101 mp_untaint(mp_int *mp)
102 {
103     size_t i;
104     for (i = 0; i < mp->used; ++i) {
105         VALGRIND_MAKE_MEM_DEFINED(&(mp->dp[i]), sizeof(mp_digit));
106     }
107 }
108 #endif
109 
110 /*------------------------------------------------------------------------*/
111 /* {{{ mp_init(mp) */
112 
113 /*
114   mp_init(mp)
115 
116   Initialize a new zero-valued mp_int.  Returns MP_OKAY if successful,
117   MP_MEM if memory could not be allocated for the structure.
118  */
119 
120 mp_err
mp_init(mp_int * mp)121 mp_init(mp_int *mp)
122 {
123     return mp_init_size(mp, s_mp_defprec);
124 
125 } /* end mp_init() */
126 
127 /* }}} */
128 
129 /* {{{ mp_init_size(mp, prec) */
130 
131 /*
132   mp_init_size(mp, prec)
133 
134   Initialize a new zero-valued mp_int with at least the given
135   precision; returns MP_OKAY if successful, or MP_MEM if memory could
136   not be allocated for the structure.
137  */
138 
139 mp_err
mp_init_size(mp_int * mp,mp_size prec)140 mp_init_size(mp_int *mp, mp_size prec)
141 {
142     ARGCHK(mp != NULL && prec > 0, MP_BADARG);
143 
144     prec = MP_ROUNDUP(prec, s_mp_defprec);
145     if ((DIGITS(mp) = s_mp_alloc(prec, sizeof(mp_digit))) == NULL)
146         return MP_MEM;
147 
148     SIGN(mp) = ZPOS;
149     USED(mp) = 1;
150     ALLOC(mp) = prec;
151 
152     return MP_OKAY;
153 
154 } /* end mp_init_size() */
155 
156 /* }}} */
157 
158 /* {{{ mp_init_copy(mp, from) */
159 
160 /*
161   mp_init_copy(mp, from)
162 
163   Initialize mp as an exact copy of from.  Returns MP_OKAY if
164   successful, MP_MEM if memory could not be allocated for the new
165   structure.
166  */
167 
168 mp_err
mp_init_copy(mp_int * mp,const mp_int * from)169 mp_init_copy(mp_int *mp, const mp_int *from)
170 {
171     ARGCHK(mp != NULL && from != NULL, MP_BADARG);
172 
173     if (mp == from)
174         return MP_OKAY;
175 
176     if ((DIGITS(mp) = s_mp_alloc(ALLOC(from), sizeof(mp_digit))) == NULL)
177         return MP_MEM;
178 
179     s_mp_copy(DIGITS(from), DIGITS(mp), USED(from));
180     USED(mp) = USED(from);
181     ALLOC(mp) = ALLOC(from);
182     SIGN(mp) = SIGN(from);
183 
184     return MP_OKAY;
185 
186 } /* end mp_init_copy() */
187 
188 /* }}} */
189 
190 /* {{{ mp_copy(from, to) */
191 
192 /*
193   mp_copy(from, to)
194 
195   Copies the mp_int 'from' to the mp_int 'to'.  It is presumed that
196   'to' has already been initialized (if not, use mp_init_copy()
197   instead). If 'from' and 'to' are identical, nothing happens.
198  */
199 
200 mp_err
mp_copy(const mp_int * from,mp_int * to)201 mp_copy(const mp_int *from, mp_int *to)
202 {
203     ARGCHK(from != NULL && to != NULL, MP_BADARG);
204 
205     if (from == to)
206         return MP_OKAY;
207 
208     { /* copy */
209         mp_digit *tmp;
210 
211         /*
212           If the allocated buffer in 'to' already has enough space to hold
213           all the used digits of 'from', we'll re-use it to avoid hitting
214           the memory allocater more than necessary; otherwise, we'd have
215           to grow anyway, so we just allocate a hunk and make the copy as
216           usual
217          */
218         if (ALLOC(to) >= USED(from)) {
219             s_mp_setz(DIGITS(to) + USED(from), ALLOC(to) - USED(from));
220             s_mp_copy(DIGITS(from), DIGITS(to), USED(from));
221 
222         } else {
223             if ((tmp = s_mp_alloc(ALLOC(from), sizeof(mp_digit))) == NULL)
224                 return MP_MEM;
225 
226             s_mp_copy(DIGITS(from), tmp, USED(from));
227 
228             if (DIGITS(to) != NULL) {
229                 s_mp_setz(DIGITS(to), ALLOC(to));
230                 s_mp_free(DIGITS(to));
231             }
232 
233             DIGITS(to) = tmp;
234             ALLOC(to) = ALLOC(from);
235         }
236 
237         /* Copy the precision and sign from the original */
238         USED(to) = USED(from);
239         SIGN(to) = SIGN(from);
240     } /* end copy */
241 
242     return MP_OKAY;
243 
244 } /* end mp_copy() */
245 
246 /* }}} */
247 
248 /* {{{ mp_exch(mp1, mp2) */
249 
250 /*
251   mp_exch(mp1, mp2)
252 
253   Exchange mp1 and mp2 without allocating any intermediate memory
254   (well, unless you count the stack space needed for this call and the
255   locals it creates...).  This cannot fail.
256  */
257 
258 void
mp_exch(mp_int * mp1,mp_int * mp2)259 mp_exch(mp_int *mp1, mp_int *mp2)
260 {
261 #if MP_ARGCHK == 2
262     assert(mp1 != NULL && mp2 != NULL);
263 #else
264     if (mp1 == NULL || mp2 == NULL)
265         return;
266 #endif
267 
268     s_mp_exch(mp1, mp2);
269 
270 } /* end mp_exch() */
271 
272 /* }}} */
273 
274 /* {{{ mp_clear(mp) */
275 
276 /*
277   mp_clear(mp)
278 
279   Release the storage used by an mp_int, and void its fields so that
280   if someone calls mp_clear() again for the same int later, we won't
281   get tollchocked.
282  */
283 
284 void
mp_clear(mp_int * mp)285 mp_clear(mp_int *mp)
286 {
287     if (mp == NULL)
288         return;
289 
290     if (DIGITS(mp) != NULL) {
291         s_mp_setz(DIGITS(mp), ALLOC(mp));
292         s_mp_free(DIGITS(mp));
293         DIGITS(mp) = NULL;
294     }
295 
296     USED(mp) = 0;
297     ALLOC(mp) = 0;
298 
299 } /* end mp_clear() */
300 
301 /* }}} */
302 
303 /* {{{ mp_zero(mp) */
304 
305 /*
306   mp_zero(mp)
307 
308   Set mp to zero.  Does not change the allocated size of the structure,
309   and therefore cannot fail (except on a bad argument, which we ignore)
310  */
311 void
mp_zero(mp_int * mp)312 mp_zero(mp_int *mp)
313 {
314     if (mp == NULL)
315         return;
316 
317     s_mp_setz(DIGITS(mp), ALLOC(mp));
318     USED(mp) = 1;
319     SIGN(mp) = ZPOS;
320 
321 } /* end mp_zero() */
322 
323 /* }}} */
324 
325 /* {{{ mp_set(mp, d) */
326 
327 void
mp_set(mp_int * mp,mp_digit d)328 mp_set(mp_int *mp, mp_digit d)
329 {
330     if (mp == NULL)
331         return;
332 
333     mp_zero(mp);
334     DIGIT(mp, 0) = d;
335 
336 } /* end mp_set() */
337 
338 /* }}} */
339 
340 /* {{{ mp_set_int(mp, z) */
341 
342 mp_err
mp_set_int(mp_int * mp,long z)343 mp_set_int(mp_int *mp, long z)
344 {
345     unsigned long v = labs(z);
346     mp_err res;
347 
348     ARGCHK(mp != NULL, MP_BADARG);
349 
350     /* https://bugzilla.mozilla.org/show_bug.cgi?id=1509432 */
351     if ((res = mp_set_ulong(mp, v)) != MP_OKAY) { /* avoids duplicated code */
352         return res;
353     }
354 
355     if (z < 0) {
356         SIGN(mp) = NEG;
357     }
358 
359     return MP_OKAY;
360 } /* end mp_set_int() */
361 
362 /* }}} */
363 
364 /* {{{ mp_set_ulong(mp, z) */
365 
366 mp_err
mp_set_ulong(mp_int * mp,unsigned long z)367 mp_set_ulong(mp_int *mp, unsigned long z)
368 {
369     int ix;
370     mp_err res;
371 
372     ARGCHK(mp != NULL, MP_BADARG);
373 
374     mp_zero(mp);
375     if (z == 0)
376         return MP_OKAY; /* shortcut for zero */
377 
378     if (sizeof z <= sizeof(mp_digit)) {
379         DIGIT(mp, 0) = z;
380     } else {
381         for (ix = sizeof(long) - 1; ix >= 0; ix--) {
382             if ((res = s_mp_mul_d(mp, (UCHAR_MAX + 1))) != MP_OKAY)
383                 return res;
384 
385             res = s_mp_add_d(mp, (mp_digit)((z >> (ix * CHAR_BIT)) & UCHAR_MAX));
386             if (res != MP_OKAY)
387                 return res;
388         }
389     }
390     return MP_OKAY;
391 } /* end mp_set_ulong() */
392 
393 /* }}} */
394 
395 /*------------------------------------------------------------------------*/
396 /* {{{ Digit arithmetic */
397 
398 /* {{{ mp_add_d(a, d, b) */
399 
400 /*
401   mp_add_d(a, d, b)
402 
403   Compute the sum b = a + d, for a single digit d.  Respects the sign of
404   its primary addend (single digits are unsigned anyway).
405  */
406 
407 mp_err
mp_add_d(const mp_int * a,mp_digit d,mp_int * b)408 mp_add_d(const mp_int *a, mp_digit d, mp_int *b)
409 {
410     mp_int tmp;
411     mp_err res;
412 
413     ARGCHK(a != NULL && b != NULL, MP_BADARG);
414 
415     if ((res = mp_init_copy(&tmp, a)) != MP_OKAY)
416         return res;
417 
418     if (SIGN(&tmp) == ZPOS) {
419         if ((res = s_mp_add_d(&tmp, d)) != MP_OKAY)
420             goto CLEANUP;
421     } else if (s_mp_cmp_d(&tmp, d) >= 0) {
422         if ((res = s_mp_sub_d(&tmp, d)) != MP_OKAY)
423             goto CLEANUP;
424     } else {
425         mp_neg(&tmp, &tmp);
426 
427         DIGIT(&tmp, 0) = d - DIGIT(&tmp, 0);
428     }
429 
430     if (s_mp_cmp_d(&tmp, 0) == 0)
431         SIGN(&tmp) = ZPOS;
432 
433     s_mp_exch(&tmp, b);
434 
435 CLEANUP:
436     mp_clear(&tmp);
437     return res;
438 
439 } /* end mp_add_d() */
440 
441 /* }}} */
442 
443 /* {{{ mp_sub_d(a, d, b) */
444 
445 /*
446   mp_sub_d(a, d, b)
447 
448   Compute the difference b = a - d, for a single digit d.  Respects the
449   sign of its subtrahend (single digits are unsigned anyway).
450  */
451 
452 mp_err
mp_sub_d(const mp_int * a,mp_digit d,mp_int * b)453 mp_sub_d(const mp_int *a, mp_digit d, mp_int *b)
454 {
455     mp_int tmp;
456     mp_err res;
457 
458     ARGCHK(a != NULL && b != NULL, MP_BADARG);
459 
460     if ((res = mp_init_copy(&tmp, a)) != MP_OKAY)
461         return res;
462 
463     if (SIGN(&tmp) == NEG) {
464         if ((res = s_mp_add_d(&tmp, d)) != MP_OKAY)
465             goto CLEANUP;
466     } else if (s_mp_cmp_d(&tmp, d) >= 0) {
467         if ((res = s_mp_sub_d(&tmp, d)) != MP_OKAY)
468             goto CLEANUP;
469     } else {
470         mp_neg(&tmp, &tmp);
471 
472         DIGIT(&tmp, 0) = d - DIGIT(&tmp, 0);
473         SIGN(&tmp) = NEG;
474     }
475 
476     if (s_mp_cmp_d(&tmp, 0) == 0)
477         SIGN(&tmp) = ZPOS;
478 
479     s_mp_exch(&tmp, b);
480 
481 CLEANUP:
482     mp_clear(&tmp);
483     return res;
484 
485 } /* end mp_sub_d() */
486 
487 /* }}} */
488 
489 /* {{{ mp_mul_d(a, d, b) */
490 
491 /*
492   mp_mul_d(a, d, b)
493 
494   Compute the product b = a * d, for a single digit d.  Respects the sign
495   of its multiplicand (single digits are unsigned anyway)
496  */
497 
498 mp_err
mp_mul_d(const mp_int * a,mp_digit d,mp_int * b)499 mp_mul_d(const mp_int *a, mp_digit d, mp_int *b)
500 {
501     mp_err res;
502 
503     ARGCHK(a != NULL && b != NULL, MP_BADARG);
504 
505     if (d == 0) {
506         mp_zero(b);
507         return MP_OKAY;
508     }
509 
510     if ((res = mp_copy(a, b)) != MP_OKAY)
511         return res;
512 
513     res = s_mp_mul_d(b, d);
514 
515     return res;
516 
517 } /* end mp_mul_d() */
518 
519 /* }}} */
520 
521 /* {{{ mp_mul_2(a, c) */
522 
523 mp_err
mp_mul_2(const mp_int * a,mp_int * c)524 mp_mul_2(const mp_int *a, mp_int *c)
525 {
526     mp_err res;
527 
528     ARGCHK(a != NULL && c != NULL, MP_BADARG);
529 
530     if ((res = mp_copy(a, c)) != MP_OKAY)
531         return res;
532 
533     return s_mp_mul_2(c);
534 
535 } /* end mp_mul_2() */
536 
537 /* }}} */
538 
539 /* {{{ mp_div_d(a, d, q, r) */
540 
541 /*
542   mp_div_d(a, d, q, r)
543 
544   Compute the quotient q = a / d and remainder r = a mod d, for a
545   single digit d.  Respects the sign of its divisor (single digits are
546   unsigned anyway).
547  */
548 
549 mp_err
mp_div_d(const mp_int * a,mp_digit d,mp_int * q,mp_digit * r)550 mp_div_d(const mp_int *a, mp_digit d, mp_int *q, mp_digit *r)
551 {
552     mp_err res;
553     mp_int qp;
554     mp_digit rem = 0;
555     int pow;
556 
557     ARGCHK(a != NULL, MP_BADARG);
558 
559     if (d == 0)
560         return MP_RANGE;
561 
562     /* Shortcut for powers of two ... */
563     if ((pow = s_mp_ispow2d(d)) >= 0) {
564         mp_digit mask;
565 
566         mask = ((mp_digit)1 << pow) - 1;
567         rem = DIGIT(a, 0) & mask;
568 
569         if (q) {
570             if ((res = mp_copy(a, q)) != MP_OKAY) {
571                 return res;
572             }
573             s_mp_div_2d(q, pow);
574         }
575 
576         if (r)
577             *r = rem;
578 
579         return MP_OKAY;
580     }
581 
582     if ((res = mp_init_copy(&qp, a)) != MP_OKAY)
583         return res;
584 
585     res = s_mp_div_d(&qp, d, &rem);
586 
587     if (s_mp_cmp_d(&qp, 0) == 0)
588         SIGN(q) = ZPOS;
589 
590     if (r) {
591         *r = rem;
592     }
593 
594     if (q)
595         s_mp_exch(&qp, q);
596 
597     mp_clear(&qp);
598     return res;
599 
600 } /* end mp_div_d() */
601 
602 /* }}} */
603 
604 /* {{{ mp_div_2(a, c) */
605 
606 /*
607   mp_div_2(a, c)
608 
609   Compute c = a / 2, disregarding the remainder.
610  */
611 
612 mp_err
mp_div_2(const mp_int * a,mp_int * c)613 mp_div_2(const mp_int *a, mp_int *c)
614 {
615     mp_err res;
616 
617     ARGCHK(a != NULL && c != NULL, MP_BADARG);
618 
619     if ((res = mp_copy(a, c)) != MP_OKAY)
620         return res;
621 
622     s_mp_div_2(c);
623 
624     return MP_OKAY;
625 
626 } /* end mp_div_2() */
627 
628 /* }}} */
629 
630 /* {{{ mp_expt_d(a, d, b) */
631 
632 mp_err
mp_expt_d(const mp_int * a,mp_digit d,mp_int * c)633 mp_expt_d(const mp_int *a, mp_digit d, mp_int *c)
634 {
635     mp_int s, x;
636     mp_err res;
637 
638     ARGCHK(a != NULL && c != NULL, MP_BADARG);
639 
640     if ((res = mp_init(&s)) != MP_OKAY)
641         return res;
642     if ((res = mp_init_copy(&x, a)) != MP_OKAY)
643         goto X;
644 
645     DIGIT(&s, 0) = 1;
646 
647     while (d != 0) {
648         if (d & 1) {
649             if ((res = s_mp_mul(&s, &x)) != MP_OKAY)
650                 goto CLEANUP;
651         }
652 
653         d /= 2;
654 
655         if ((res = s_mp_sqr(&x)) != MP_OKAY)
656             goto CLEANUP;
657     }
658 
659     s_mp_exch(&s, c);
660 
661 CLEANUP:
662     mp_clear(&x);
663 X:
664     mp_clear(&s);
665 
666     return res;
667 
668 } /* end mp_expt_d() */
669 
670 /* }}} */
671 
672 /* }}} */
673 
674 /*------------------------------------------------------------------------*/
675 /* {{{ Full arithmetic */
676 
677 /* {{{ mp_abs(a, b) */
678 
679 /*
680   mp_abs(a, b)
681 
682   Compute b = |a|.  'a' and 'b' may be identical.
683  */
684 
685 mp_err
mp_abs(const mp_int * a,mp_int * b)686 mp_abs(const mp_int *a, mp_int *b)
687 {
688     mp_err res;
689 
690     ARGCHK(a != NULL && b != NULL, MP_BADARG);
691 
692     if ((res = mp_copy(a, b)) != MP_OKAY)
693         return res;
694 
695     SIGN(b) = ZPOS;
696 
697     return MP_OKAY;
698 
699 } /* end mp_abs() */
700 
701 /* }}} */
702 
703 /* {{{ mp_neg(a, b) */
704 
705 /*
706   mp_neg(a, b)
707 
708   Compute b = -a.  'a' and 'b' may be identical.
709  */
710 
711 mp_err
mp_neg(const mp_int * a,mp_int * b)712 mp_neg(const mp_int *a, mp_int *b)
713 {
714     mp_err res;
715 
716     ARGCHK(a != NULL && b != NULL, MP_BADARG);
717 
718     if ((res = mp_copy(a, b)) != MP_OKAY)
719         return res;
720 
721     if (s_mp_cmp_d(b, 0) == MP_EQ)
722         SIGN(b) = ZPOS;
723     else
724         SIGN(b) = (SIGN(b) == NEG) ? ZPOS : NEG;
725 
726     return MP_OKAY;
727 
728 } /* end mp_neg() */
729 
730 /* }}} */
731 
732 /* {{{ mp_add(a, b, c) */
733 
734 /*
735   mp_add(a, b, c)
736 
737   Compute c = a + b.  All parameters may be identical.
738  */
739 
740 mp_err
mp_add(const mp_int * a,const mp_int * b,mp_int * c)741 mp_add(const mp_int *a, const mp_int *b, mp_int *c)
742 {
743     mp_err res;
744 
745     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
746 
747     if (SIGN(a) == SIGN(b)) { /* same sign:  add values, keep sign */
748         MP_CHECKOK(s_mp_add_3arg(a, b, c));
749     } else if (s_mp_cmp(a, b) >= 0) { /* different sign: |a| >= |b|   */
750         MP_CHECKOK(s_mp_sub_3arg(a, b, c));
751     } else { /* different sign: |a|  < |b|   */
752         MP_CHECKOK(s_mp_sub_3arg(b, a, c));
753     }
754 
755     if (s_mp_cmp_d(c, 0) == MP_EQ)
756         SIGN(c) = ZPOS;
757 
758 CLEANUP:
759     return res;
760 
761 } /* end mp_add() */
762 
763 /* }}} */
764 
765 /* {{{ mp_sub(a, b, c) */
766 
767 /*
768   mp_sub(a, b, c)
769 
770   Compute c = a - b.  All parameters may be identical.
771  */
772 
773 mp_err
mp_sub(const mp_int * a,const mp_int * b,mp_int * c)774 mp_sub(const mp_int *a, const mp_int *b, mp_int *c)
775 {
776     mp_err res;
777     int magDiff;
778 
779     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
780 
781     if (a == b) {
782         mp_zero(c);
783         return MP_OKAY;
784     }
785 
786     if (MP_SIGN(a) != MP_SIGN(b)) {
787         MP_CHECKOK(s_mp_add_3arg(a, b, c));
788     } else if (!(magDiff = s_mp_cmp(a, b))) {
789         mp_zero(c);
790         res = MP_OKAY;
791     } else if (magDiff > 0) {
792         MP_CHECKOK(s_mp_sub_3arg(a, b, c));
793     } else {
794         MP_CHECKOK(s_mp_sub_3arg(b, a, c));
795         MP_SIGN(c) = !MP_SIGN(a);
796     }
797 
798     if (s_mp_cmp_d(c, 0) == MP_EQ)
799         MP_SIGN(c) = MP_ZPOS;
800 
801 CLEANUP:
802     return res;
803 
804 } /* end mp_sub() */
805 
806 /* }}} */
807 
808 /* {{{ mp_mul(a, b, c) */
809 
810 /*
811   mp_mul(a, b, c)
812 
813   Compute c = a * b.  All parameters may be identical.
814  */
815 mp_err
mp_mul(const mp_int * a,const mp_int * b,mp_int * c)816 mp_mul(const mp_int *a, const mp_int *b, mp_int *c)
817 {
818     mp_digit *pb;
819     mp_int tmp;
820     mp_err res;
821     mp_size ib;
822     mp_size useda, usedb;
823 
824     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
825 
826     if (a == c) {
827         if ((res = mp_init_copy(&tmp, a)) != MP_OKAY)
828             return res;
829         if (a == b)
830             b = &tmp;
831         a = &tmp;
832     } else if (b == c) {
833         if ((res = mp_init_copy(&tmp, b)) != MP_OKAY)
834             return res;
835         b = &tmp;
836     } else {
837         MP_DIGITS(&tmp) = 0;
838     }
839 
840     if (MP_USED(a) < MP_USED(b)) {
841         const mp_int *xch = b; /* switch a and b, to do fewer outer loops */
842         b = a;
843         a = xch;
844     }
845 
846     MP_USED(c) = 1;
847     MP_DIGIT(c, 0) = 0;
848     if ((res = s_mp_pad(c, USED(a) + USED(b))) != MP_OKAY)
849         goto CLEANUP;
850 
851 #ifdef NSS_USE_COMBA
852     if ((MP_USED(a) == MP_USED(b)) && IS_POWER_OF_2(MP_USED(b))) {
853         if (MP_USED(a) == 4) {
854             s_mp_mul_comba_4(a, b, c);
855             goto CLEANUP;
856         }
857         if (MP_USED(a) == 8) {
858             s_mp_mul_comba_8(a, b, c);
859             goto CLEANUP;
860         }
861         if (MP_USED(a) == 16) {
862             s_mp_mul_comba_16(a, b, c);
863             goto CLEANUP;
864         }
865         if (MP_USED(a) == 32) {
866             s_mp_mul_comba_32(a, b, c);
867             goto CLEANUP;
868         }
869     }
870 #endif
871 
872     pb = MP_DIGITS(b);
873     s_mpv_mul_d(MP_DIGITS(a), MP_USED(a), *pb++, MP_DIGITS(c));
874 
875     /* Outer loop:  Digits of b */
876     useda = MP_USED(a);
877     usedb = MP_USED(b);
878     for (ib = 1; ib < usedb; ib++) {
879         mp_digit b_i = *pb++;
880 
881         /* Inner product:  Digits of a */
882         if (b_i)
883             s_mpv_mul_d_add(MP_DIGITS(a), useda, b_i, MP_DIGITS(c) + ib);
884         else
885             MP_DIGIT(c, ib + useda) = b_i;
886     }
887 
888     s_mp_clamp(c);
889 
890     if (SIGN(a) == SIGN(b) || s_mp_cmp_d(c, 0) == MP_EQ)
891         SIGN(c) = ZPOS;
892     else
893         SIGN(c) = NEG;
894 
895 CLEANUP:
896     mp_clear(&tmp);
897     return res;
898 } /* end mp_mul() */
899 
900 /* }}} */
901 
902 /* {{{ mp_sqr(a, sqr) */
903 
904 #if MP_SQUARE
905 /*
906   Computes the square of a.  This can be done more
907   efficiently than a general multiplication, because many of the
908   computation steps are redundant when squaring.  The inner product
909   step is a bit more complicated, but we save a fair number of
910   iterations of the multiplication loop.
911  */
912 
913 /* sqr = a^2;   Caller provides both a and tmp; */
914 mp_err
mp_sqr(const mp_int * a,mp_int * sqr)915 mp_sqr(const mp_int *a, mp_int *sqr)
916 {
917     mp_digit *pa;
918     mp_digit d;
919     mp_err res;
920     mp_size ix;
921     mp_int tmp;
922     int count;
923 
924     ARGCHK(a != NULL && sqr != NULL, MP_BADARG);
925 
926     if (a == sqr) {
927         if ((res = mp_init_copy(&tmp, a)) != MP_OKAY)
928             return res;
929         a = &tmp;
930     } else {
931         DIGITS(&tmp) = 0;
932         res = MP_OKAY;
933     }
934 
935     ix = 2 * MP_USED(a);
936     if (ix > MP_ALLOC(sqr)) {
937         MP_USED(sqr) = 1;
938         MP_CHECKOK(s_mp_grow(sqr, ix));
939     }
940     MP_USED(sqr) = ix;
941     MP_DIGIT(sqr, 0) = 0;
942 
943 #ifdef NSS_USE_COMBA
944     if (IS_POWER_OF_2(MP_USED(a))) {
945         if (MP_USED(a) == 4) {
946             s_mp_sqr_comba_4(a, sqr);
947             goto CLEANUP;
948         }
949         if (MP_USED(a) == 8) {
950             s_mp_sqr_comba_8(a, sqr);
951             goto CLEANUP;
952         }
953         if (MP_USED(a) == 16) {
954             s_mp_sqr_comba_16(a, sqr);
955             goto CLEANUP;
956         }
957         if (MP_USED(a) == 32) {
958             s_mp_sqr_comba_32(a, sqr);
959             goto CLEANUP;
960         }
961     }
962 #endif
963 
964     pa = MP_DIGITS(a);
965     count = MP_USED(a) - 1;
966     if (count > 0) {
967         d = *pa++;
968         s_mpv_mul_d(pa, count, d, MP_DIGITS(sqr) + 1);
969         for (ix = 3; --count > 0; ix += 2) {
970             d = *pa++;
971             s_mpv_mul_d_add(pa, count, d, MP_DIGITS(sqr) + ix);
972         }                                    /* for(ix ...) */
973         MP_DIGIT(sqr, MP_USED(sqr) - 1) = 0; /* above loop stopped short of this. */
974 
975         /* now sqr *= 2 */
976         s_mp_mul_2(sqr);
977     } else {
978         MP_DIGIT(sqr, 1) = 0;
979     }
980 
981     /* now add the squares of the digits of a to sqr. */
982     s_mpv_sqr_add_prop(MP_DIGITS(a), MP_USED(a), MP_DIGITS(sqr));
983 
984     SIGN(sqr) = ZPOS;
985     s_mp_clamp(sqr);
986 
987 CLEANUP:
988     mp_clear(&tmp);
989     return res;
990 
991 } /* end mp_sqr() */
992 #endif
993 
994 /* }}} */
995 
996 /* {{{ mp_div(a, b, q, r) */
997 
998 /*
999   mp_div(a, b, q, r)
1000 
1001   Compute q = a / b and r = a mod b.  Input parameters may be re-used
1002   as output parameters.  If q or r is NULL, that portion of the
1003   computation will be discarded (although it will still be computed)
1004  */
1005 mp_err
mp_div(const mp_int * a,const mp_int * b,mp_int * q,mp_int * r)1006 mp_div(const mp_int *a, const mp_int *b, mp_int *q, mp_int *r)
1007 {
1008     mp_err res;
1009     mp_int *pQ, *pR;
1010     mp_int qtmp, rtmp, btmp;
1011     int cmp;
1012     mp_sign signA;
1013     mp_sign signB;
1014 
1015     ARGCHK(a != NULL && b != NULL, MP_BADARG);
1016 
1017     signA = MP_SIGN(a);
1018     signB = MP_SIGN(b);
1019 
1020     if (mp_cmp_z(b) == MP_EQ)
1021         return MP_RANGE;
1022 
1023     DIGITS(&qtmp) = 0;
1024     DIGITS(&rtmp) = 0;
1025     DIGITS(&btmp) = 0;
1026 
1027     /* Set up some temporaries... */
1028     if (!r || r == a || r == b) {
1029         MP_CHECKOK(mp_init_copy(&rtmp, a));
1030         pR = &rtmp;
1031     } else {
1032         MP_CHECKOK(mp_copy(a, r));
1033         pR = r;
1034     }
1035 
1036     if (!q || q == a || q == b) {
1037         MP_CHECKOK(mp_init_size(&qtmp, MP_USED(a)));
1038         pQ = &qtmp;
1039     } else {
1040         MP_CHECKOK(s_mp_pad(q, MP_USED(a)));
1041         pQ = q;
1042         mp_zero(pQ);
1043     }
1044 
1045     /*
1046       If |a| <= |b|, we can compute the solution without division;
1047       otherwise, we actually do the work required.
1048      */
1049     if ((cmp = s_mp_cmp(a, b)) <= 0) {
1050         if (cmp) {
1051             /* r was set to a above. */
1052             mp_zero(pQ);
1053         } else {
1054             mp_set(pQ, 1);
1055             mp_zero(pR);
1056         }
1057     } else {
1058         MP_CHECKOK(mp_init_copy(&btmp, b));
1059         MP_CHECKOK(s_mp_div(pR, &btmp, pQ));
1060     }
1061 
1062     /* Compute the signs for the output  */
1063     MP_SIGN(pR) = signA;        /* Sr = Sa              */
1064     /* Sq = ZPOS if Sa == Sb */ /* Sq = NEG if Sa != Sb */
1065     MP_SIGN(pQ) = (signA == signB) ? ZPOS : NEG;
1066 
1067     if (s_mp_cmp_d(pQ, 0) == MP_EQ)
1068         SIGN(pQ) = ZPOS;
1069     if (s_mp_cmp_d(pR, 0) == MP_EQ)
1070         SIGN(pR) = ZPOS;
1071 
1072     /* Copy output, if it is needed      */
1073     if (q && q != pQ)
1074         s_mp_exch(pQ, q);
1075 
1076     if (r && r != pR)
1077         s_mp_exch(pR, r);
1078 
1079 CLEANUP:
1080     mp_clear(&btmp);
1081     mp_clear(&rtmp);
1082     mp_clear(&qtmp);
1083 
1084     return res;
1085 
1086 } /* end mp_div() */
1087 
1088 /* }}} */
1089 
1090 /* {{{ mp_div_2d(a, d, q, r) */
1091 
1092 mp_err
mp_div_2d(const mp_int * a,mp_digit d,mp_int * q,mp_int * r)1093 mp_div_2d(const mp_int *a, mp_digit d, mp_int *q, mp_int *r)
1094 {
1095     mp_err res;
1096 
1097     ARGCHK(a != NULL, MP_BADARG);
1098 
1099     if (q) {
1100         if ((res = mp_copy(a, q)) != MP_OKAY)
1101             return res;
1102     }
1103     if (r) {
1104         if ((res = mp_copy(a, r)) != MP_OKAY)
1105             return res;
1106     }
1107     if (q) {
1108         s_mp_div_2d(q, d);
1109     }
1110     if (r) {
1111         s_mp_mod_2d(r, d);
1112     }
1113 
1114     return MP_OKAY;
1115 
1116 } /* end mp_div_2d() */
1117 
1118 /* }}} */
1119 
1120 /* {{{ mp_expt(a, b, c) */
1121 
1122 /*
1123   mp_expt(a, b, c)
1124 
1125   Compute c = a ** b, that is, raise a to the b power.  Uses a
1126   standard iterative square-and-multiply technique.
1127  */
1128 
1129 mp_err
mp_expt(mp_int * a,mp_int * b,mp_int * c)1130 mp_expt(mp_int *a, mp_int *b, mp_int *c)
1131 {
1132     mp_int s, x;
1133     mp_err res;
1134     mp_digit d;
1135     unsigned int dig, bit;
1136 
1137     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
1138 
1139     if (mp_cmp_z(b) < 0)
1140         return MP_RANGE;
1141 
1142     if ((res = mp_init(&s)) != MP_OKAY)
1143         return res;
1144 
1145     mp_set(&s, 1);
1146 
1147     if ((res = mp_init_copy(&x, a)) != MP_OKAY)
1148         goto X;
1149 
1150     /* Loop over low-order digits in ascending order */
1151     for (dig = 0; dig < (USED(b) - 1); dig++) {
1152         d = DIGIT(b, dig);
1153 
1154         /* Loop over bits of each non-maximal digit */
1155         for (bit = 0; bit < DIGIT_BIT; bit++) {
1156             if (d & 1) {
1157                 if ((res = s_mp_mul(&s, &x)) != MP_OKAY)
1158                     goto CLEANUP;
1159             }
1160 
1161             d >>= 1;
1162 
1163             if ((res = s_mp_sqr(&x)) != MP_OKAY)
1164                 goto CLEANUP;
1165         }
1166     }
1167 
1168     /* Consider now the last digit... */
1169     d = DIGIT(b, dig);
1170 
1171     while (d) {
1172         if (d & 1) {
1173             if ((res = s_mp_mul(&s, &x)) != MP_OKAY)
1174                 goto CLEANUP;
1175         }
1176 
1177         d >>= 1;
1178 
1179         if ((res = s_mp_sqr(&x)) != MP_OKAY)
1180             goto CLEANUP;
1181     }
1182 
1183     if (mp_iseven(b))
1184         SIGN(&s) = SIGN(a);
1185 
1186     res = mp_copy(&s, c);
1187 
1188 CLEANUP:
1189     mp_clear(&x);
1190 X:
1191     mp_clear(&s);
1192 
1193     return res;
1194 
1195 } /* end mp_expt() */
1196 
1197 /* }}} */
1198 
1199 /* {{{ mp_2expt(a, k) */
1200 
1201 /* Compute a = 2^k */
1202 
1203 mp_err
mp_2expt(mp_int * a,mp_digit k)1204 mp_2expt(mp_int *a, mp_digit k)
1205 {
1206     ARGCHK(a != NULL, MP_BADARG);
1207 
1208     return s_mp_2expt(a, k);
1209 
1210 } /* end mp_2expt() */
1211 
1212 /* }}} */
1213 
1214 /* {{{ mp_mod(a, m, c) */
1215 
1216 /*
1217   mp_mod(a, m, c)
1218 
1219   Compute c = a (mod m).  Result will always be 0 <= c < m.
1220  */
1221 
1222 mp_err
mp_mod(const mp_int * a,const mp_int * m,mp_int * c)1223 mp_mod(const mp_int *a, const mp_int *m, mp_int *c)
1224 {
1225     mp_err res;
1226     int mag;
1227 
1228     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
1229 
1230     if (SIGN(m) == NEG)
1231         return MP_RANGE;
1232 
1233     /*
1234      If |a| > m, we need to divide to get the remainder and take the
1235      absolute value.
1236 
1237      If |a| < m, we don't need to do any division, just copy and adjust
1238      the sign (if a is negative).
1239 
1240      If |a| == m, we can simply set the result to zero.
1241 
1242      This order is intended to minimize the average path length of the
1243      comparison chain on common workloads -- the most frequent cases are
1244      that |a| != m, so we do those first.
1245      */
1246     if ((mag = s_mp_cmp(a, m)) > 0) {
1247         if ((res = mp_div(a, m, NULL, c)) != MP_OKAY)
1248             return res;
1249 
1250         if (SIGN(c) == NEG) {
1251             if ((res = mp_add(c, m, c)) != MP_OKAY)
1252                 return res;
1253         }
1254 
1255     } else if (mag < 0) {
1256         if ((res = mp_copy(a, c)) != MP_OKAY)
1257             return res;
1258 
1259         if (mp_cmp_z(a) < 0) {
1260             if ((res = mp_add(c, m, c)) != MP_OKAY)
1261                 return res;
1262         }
1263 
1264     } else {
1265         mp_zero(c);
1266     }
1267 
1268     return MP_OKAY;
1269 
1270 } /* end mp_mod() */
1271 
1272 /* }}} */
1273 
1274 /* {{{ mp_mod_d(a, d, c) */
1275 
1276 /*
1277   mp_mod_d(a, d, c)
1278 
1279   Compute c = a (mod d).  Result will always be 0 <= c < d
1280  */
1281 mp_err
mp_mod_d(const mp_int * a,mp_digit d,mp_digit * c)1282 mp_mod_d(const mp_int *a, mp_digit d, mp_digit *c)
1283 {
1284     mp_err res;
1285     mp_digit rem;
1286 
1287     ARGCHK(a != NULL && c != NULL, MP_BADARG);
1288 
1289     if (s_mp_cmp_d(a, d) > 0) {
1290         if ((res = mp_div_d(a, d, NULL, &rem)) != MP_OKAY)
1291             return res;
1292 
1293     } else {
1294         if (SIGN(a) == NEG)
1295             rem = d - DIGIT(a, 0);
1296         else
1297             rem = DIGIT(a, 0);
1298     }
1299 
1300     if (c)
1301         *c = rem;
1302 
1303     return MP_OKAY;
1304 
1305 } /* end mp_mod_d() */
1306 
1307 /* }}} */
1308 
1309 /* }}} */
1310 
1311 /*------------------------------------------------------------------------*/
1312 /* {{{ Modular arithmetic */
1313 
1314 #if MP_MODARITH
1315 /* {{{ mp_addmod(a, b, m, c) */
1316 
1317 /*
1318   mp_addmod(a, b, m, c)
1319 
1320   Compute c = (a + b) mod m
1321  */
1322 
1323 mp_err
mp_addmod(const mp_int * a,const mp_int * b,const mp_int * m,mp_int * c)1324 mp_addmod(const mp_int *a, const mp_int *b, const mp_int *m, mp_int *c)
1325 {
1326     mp_err res;
1327 
1328     ARGCHK(a != NULL && b != NULL && m != NULL && c != NULL, MP_BADARG);
1329 
1330     if ((res = mp_add(a, b, c)) != MP_OKAY)
1331         return res;
1332     if ((res = mp_mod(c, m, c)) != MP_OKAY)
1333         return res;
1334 
1335     return MP_OKAY;
1336 }
1337 
1338 /* }}} */
1339 
1340 /* {{{ mp_submod(a, b, m, c) */
1341 
1342 /*
1343   mp_submod(a, b, m, c)
1344 
1345   Compute c = (a - b) mod m
1346  */
1347 
1348 mp_err
mp_submod(const mp_int * a,const mp_int * b,const mp_int * m,mp_int * c)1349 mp_submod(const mp_int *a, const mp_int *b, const mp_int *m, mp_int *c)
1350 {
1351     mp_err res;
1352 
1353     ARGCHK(a != NULL && b != NULL && m != NULL && c != NULL, MP_BADARG);
1354 
1355     if ((res = mp_sub(a, b, c)) != MP_OKAY)
1356         return res;
1357     if ((res = mp_mod(c, m, c)) != MP_OKAY)
1358         return res;
1359 
1360     return MP_OKAY;
1361 }
1362 
1363 /* }}} */
1364 
1365 /* {{{ mp_mulmod(a, b, m, c) */
1366 
1367 /*
1368   mp_mulmod(a, b, m, c)
1369 
1370   Compute c = (a * b) mod m
1371  */
1372 
1373 mp_err
mp_mulmod(const mp_int * a,const mp_int * b,const mp_int * m,mp_int * c)1374 mp_mulmod(const mp_int *a, const mp_int *b, const mp_int *m, mp_int *c)
1375 {
1376     mp_err res;
1377 
1378     ARGCHK(a != NULL && b != NULL && m != NULL && c != NULL, MP_BADARG);
1379 
1380     if ((res = mp_mul(a, b, c)) != MP_OKAY)
1381         return res;
1382     if ((res = mp_mod(c, m, c)) != MP_OKAY)
1383         return res;
1384 
1385     return MP_OKAY;
1386 }
1387 
1388 /* }}} */
1389 
1390 /* {{{ mp_sqrmod(a, m, c) */
1391 
1392 #if MP_SQUARE
1393 mp_err
mp_sqrmod(const mp_int * a,const mp_int * m,mp_int * c)1394 mp_sqrmod(const mp_int *a, const mp_int *m, mp_int *c)
1395 {
1396     mp_err res;
1397 
1398     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
1399 
1400     if ((res = mp_sqr(a, c)) != MP_OKAY)
1401         return res;
1402     if ((res = mp_mod(c, m, c)) != MP_OKAY)
1403         return res;
1404 
1405     return MP_OKAY;
1406 
1407 } /* end mp_sqrmod() */
1408 #endif
1409 
1410 /* }}} */
1411 
1412 /* {{{ s_mp_exptmod(a, b, m, c) */
1413 
1414 /*
1415   s_mp_exptmod(a, b, m, c)
1416 
1417   Compute c = (a ** b) mod m.  Uses a standard square-and-multiply
1418   method with modular reductions at each step. (This is basically the
1419   same code as mp_expt(), except for the addition of the reductions)
1420 
1421   The modular reductions are done using Barrett's algorithm (see
1422   s_mp_reduce() below for details)
1423  */
1424 
1425 mp_err
s_mp_exptmod(const mp_int * a,const mp_int * b,const mp_int * m,mp_int * c)1426 s_mp_exptmod(const mp_int *a, const mp_int *b, const mp_int *m, mp_int *c)
1427 {
1428     mp_int s, x, mu;
1429     mp_err res;
1430     mp_digit d;
1431     unsigned int dig, bit;
1432 
1433     ARGCHK(a != NULL && b != NULL && c != NULL && m != NULL, MP_BADARG);
1434 
1435     if (mp_cmp_z(b) < 0 || mp_cmp_z(m) <= 0)
1436         return MP_RANGE;
1437 
1438     if ((res = mp_init(&s)) != MP_OKAY)
1439         return res;
1440     if ((res = mp_init_copy(&x, a)) != MP_OKAY ||
1441         (res = mp_mod(&x, m, &x)) != MP_OKAY)
1442         goto X;
1443     if ((res = mp_init(&mu)) != MP_OKAY)
1444         goto MU;
1445 
1446     mp_set(&s, 1);
1447 
1448     /* mu = b^2k / m */
1449     if ((res = s_mp_add_d(&mu, 1)) != MP_OKAY)
1450         goto CLEANUP;
1451     if ((res = s_mp_lshd(&mu, 2 * USED(m))) != MP_OKAY)
1452         goto CLEANUP;
1453     if ((res = mp_div(&mu, m, &mu, NULL)) != MP_OKAY)
1454         goto CLEANUP;
1455 
1456     /* Loop over digits of b in ascending order, except highest order */
1457     for (dig = 0; dig < (USED(b) - 1); dig++) {
1458         d = DIGIT(b, dig);
1459 
1460         /* Loop over the bits of the lower-order digits */
1461         for (bit = 0; bit < DIGIT_BIT; bit++) {
1462             if (d & 1) {
1463                 if ((res = s_mp_mul(&s, &x)) != MP_OKAY)
1464                     goto CLEANUP;
1465                 if ((res = s_mp_reduce(&s, m, &mu)) != MP_OKAY)
1466                     goto CLEANUP;
1467             }
1468 
1469             d >>= 1;
1470 
1471             if ((res = s_mp_sqr(&x)) != MP_OKAY)
1472                 goto CLEANUP;
1473             if ((res = s_mp_reduce(&x, m, &mu)) != MP_OKAY)
1474                 goto CLEANUP;
1475         }
1476     }
1477 
1478     /* Now do the last digit... */
1479     d = DIGIT(b, dig);
1480 
1481     while (d) {
1482         if (d & 1) {
1483             if ((res = s_mp_mul(&s, &x)) != MP_OKAY)
1484                 goto CLEANUP;
1485             if ((res = s_mp_reduce(&s, m, &mu)) != MP_OKAY)
1486                 goto CLEANUP;
1487         }
1488 
1489         d >>= 1;
1490 
1491         if ((res = s_mp_sqr(&x)) != MP_OKAY)
1492             goto CLEANUP;
1493         if ((res = s_mp_reduce(&x, m, &mu)) != MP_OKAY)
1494             goto CLEANUP;
1495     }
1496 
1497     s_mp_exch(&s, c);
1498 
1499 CLEANUP:
1500     mp_clear(&mu);
1501 MU:
1502     mp_clear(&x);
1503 X:
1504     mp_clear(&s);
1505 
1506     return res;
1507 
1508 } /* end s_mp_exptmod() */
1509 
1510 /* }}} */
1511 
1512 /* {{{ mp_exptmod_d(a, d, m, c) */
1513 
1514 mp_err
mp_exptmod_d(const mp_int * a,mp_digit d,const mp_int * m,mp_int * c)1515 mp_exptmod_d(const mp_int *a, mp_digit d, const mp_int *m, mp_int *c)
1516 {
1517     mp_int s, x;
1518     mp_err res;
1519 
1520     ARGCHK(a != NULL && c != NULL && m != NULL, MP_BADARG);
1521 
1522     if ((res = mp_init(&s)) != MP_OKAY)
1523         return res;
1524     if ((res = mp_init_copy(&x, a)) != MP_OKAY)
1525         goto X;
1526 
1527     mp_set(&s, 1);
1528 
1529     while (d != 0) {
1530         if (d & 1) {
1531             if ((res = s_mp_mul(&s, &x)) != MP_OKAY ||
1532                 (res = mp_mod(&s, m, &s)) != MP_OKAY)
1533                 goto CLEANUP;
1534         }
1535 
1536         d /= 2;
1537 
1538         if ((res = s_mp_sqr(&x)) != MP_OKAY ||
1539             (res = mp_mod(&x, m, &x)) != MP_OKAY)
1540             goto CLEANUP;
1541     }
1542 
1543     s_mp_exch(&s, c);
1544 
1545 CLEANUP:
1546     mp_clear(&x);
1547 X:
1548     mp_clear(&s);
1549 
1550     return res;
1551 
1552 } /* end mp_exptmod_d() */
1553 
1554 /* }}} */
1555 #endif /* if MP_MODARITH */
1556 
1557 /* }}} */
1558 
1559 /*------------------------------------------------------------------------*/
1560 /* {{{ Comparison functions */
1561 
1562 /* {{{ mp_cmp_z(a) */
1563 
1564 /*
1565   mp_cmp_z(a)
1566 
1567   Compare a <=> 0.  Returns <0 if a<0, 0 if a=0, >0 if a>0.
1568  */
1569 
1570 int
mp_cmp_z(const mp_int * a)1571 mp_cmp_z(const mp_int *a)
1572 {
1573     ARGMPCHK(a != NULL);
1574 
1575     if (SIGN(a) == NEG)
1576         return MP_LT;
1577     else if (USED(a) == 1 && DIGIT(a, 0) == 0)
1578         return MP_EQ;
1579     else
1580         return MP_GT;
1581 
1582 } /* end mp_cmp_z() */
1583 
1584 /* }}} */
1585 
1586 /* {{{ mp_cmp_d(a, d) */
1587 
1588 /*
1589   mp_cmp_d(a, d)
1590 
1591   Compare a <=> d.  Returns <0 if a<d, 0 if a=d, >0 if a>d
1592  */
1593 
1594 int
mp_cmp_d(const mp_int * a,mp_digit d)1595 mp_cmp_d(const mp_int *a, mp_digit d)
1596 {
1597     ARGCHK(a != NULL, MP_EQ);
1598 
1599     if (SIGN(a) == NEG)
1600         return MP_LT;
1601 
1602     return s_mp_cmp_d(a, d);
1603 
1604 } /* end mp_cmp_d() */
1605 
1606 /* }}} */
1607 
1608 /* {{{ mp_cmp(a, b) */
1609 
1610 int
mp_cmp(const mp_int * a,const mp_int * b)1611 mp_cmp(const mp_int *a, const mp_int *b)
1612 {
1613     ARGCHK(a != NULL && b != NULL, MP_EQ);
1614 
1615     if (SIGN(a) == SIGN(b)) {
1616         int mag;
1617 
1618         if ((mag = s_mp_cmp(a, b)) == MP_EQ)
1619             return MP_EQ;
1620 
1621         if (SIGN(a) == ZPOS)
1622             return mag;
1623         else
1624             return -mag;
1625 
1626     } else if (SIGN(a) == ZPOS) {
1627         return MP_GT;
1628     } else {
1629         return MP_LT;
1630     }
1631 
1632 } /* end mp_cmp() */
1633 
1634 /* }}} */
1635 
1636 /* {{{ mp_cmp_mag(a, b) */
1637 
1638 /*
1639   mp_cmp_mag(a, b)
1640 
1641   Compares |a| <=> |b|, and returns an appropriate comparison result
1642  */
1643 
1644 int
mp_cmp_mag(const mp_int * a,const mp_int * b)1645 mp_cmp_mag(const mp_int *a, const mp_int *b)
1646 {
1647     ARGCHK(a != NULL && b != NULL, MP_EQ);
1648 
1649     return s_mp_cmp(a, b);
1650 
1651 } /* end mp_cmp_mag() */
1652 
1653 /* }}} */
1654 
1655 /* {{{ mp_isodd(a) */
1656 
1657 /*
1658   mp_isodd(a)
1659 
1660   Returns a true (non-zero) value if a is odd, false (zero) otherwise.
1661  */
1662 int
mp_isodd(const mp_int * a)1663 mp_isodd(const mp_int *a)
1664 {
1665     ARGMPCHK(a != NULL);
1666 
1667     return (int)(DIGIT(a, 0) & 1);
1668 
1669 } /* end mp_isodd() */
1670 
1671 /* }}} */
1672 
1673 /* {{{ mp_iseven(a) */
1674 
1675 int
mp_iseven(const mp_int * a)1676 mp_iseven(const mp_int *a)
1677 {
1678     return !mp_isodd(a);
1679 
1680 } /* end mp_iseven() */
1681 
1682 /* }}} */
1683 
1684 /* }}} */
1685 
1686 /*------------------------------------------------------------------------*/
1687 /* {{{ Number theoretic functions */
1688 
1689 /* {{{ mp_gcd(a, b, c) */
1690 
/*
  mp_gcd(a, b, c)

  Computes c = gcd(a, b) using the constant-time algorithm
  by Bernstein and Yang (https://eprint.iacr.org/2019/266)
  "Fast constant-time gcd computation and modular inversion"

  If either input is zero the other input is copied to c with its sign
  forced to ZPOS and the function returns at once.  Otherwise the
  inputs are copied into temporaries, their common power of two is
  stripped, a fixed number of swap/eliminate iterations is run, and the
  common power of two is multiplied back in.  The result is
  non-negative.

  Returns the status of the first helper that fails, or the status of
  the final mp_copy().
 */
mp_err
mp_gcd(mp_int *a, mp_int *b, mp_int *c)
{
    mp_err res;
    mp_digit cond = 0, mask = 0;
    mp_int g, temp, f;
    int i, j, m, bit = 1, delta = 1, shifts = 0, last = -1;
    mp_size top, flen, glen;
    /* Temporaries initialized so far; released in reverse at CLEANUP. */
    mp_int *clear[3];

    ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
    /*
    Early exit if either of the inputs is zero.
    Caller is responsible for the proper handling of inputs.
    */
    if (mp_cmp_z(a) == MP_EQ) {
        res = mp_copy(b, c);
        SIGN(c) = ZPOS;
        return res;
    } else if (mp_cmp_z(b) == MP_EQ) {
        res = mp_copy(a, c);
        SIGN(c) = ZPOS;
        return res;
    }

    /* Each temporary is recorded in clear[] only once it is initialized. */
    MP_CHECKOK(mp_init(&temp));
    clear[++last] = &temp;
    MP_CHECKOK(mp_init_copy(&g, a));
    clear[++last] = &g;
    MP_CHECKOK(mp_init_copy(&f, b));
    clear[++last] = &f;

    /*
    For even case compute the number of
    shared powers of 2 in f and g.

    mask has a 1 wherever both f and g have a 0 bit.  bit stays 1 only
    while every bit examined so far was set in mask, so shifts ends up
    as the number of low-order zero bits common to f and g.  Every bit
    of every shared digit is visited regardless of the values.
    */
    for (i = 0; i < USED(&f) && i < USED(&g); i++) {
        mask = ~(DIGIT(&f, i) | DIGIT(&g, i));
        for (j = 0; j < MP_DIGIT_BIT; j++) {
            bit &= mask;
            shifts += bit;
            mask >>= 1;
        }
    }
    /* Reduce to the odd case by removing the powers of 2. */
    s_mp_div_2d(&f, shifts);
    s_mp_div_2d(&g, shifts);

    /* Allocate to the size of largest mp_int (plus one spare digit). */
    top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
    MP_CHECKOK(s_mp_grow(&f, top));
    MP_CHECKOK(s_mp_grow(&g, top));
    MP_CHECKOK(s_mp_grow(&temp, top));

    /* Make sure f contains the odd value: swap when f's low bit is 0. */
    MP_CHECKOK(mp_cswap((~DIGIT(&f, 0) & 1), &f, &g, top));

    /* Upper bound for the total iterations. */
    flen = mpl_significant_bits(&f);
    glen = mpl_significant_bits(&g);
    m = 4 + 3 * ((flen >= glen) ? flen : glen);

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know that we're negating an unsigned mp_digit
#endif

    for (i = 0; i < m; i++) {
        /* Step 1: conditional swap. */
        /* Set cond if delta > 0 and g is odd. */
        cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
        /* If cond is set replace (delta,f) with (-delta,-f). */
        delta = (-cond & -delta) | ((cond - 1) & delta);
        SIGN(&f) ^= cond;
        /* If cond is set swap f with g. */
        MP_CHECKOK(mp_cswap(cond, &f, &g, top));

        /* Step 2: elimination. */
        /* Update delta. */
        delta++;
        /*
        If g is odd, right shift (g+f) else right shift g.
        The sum is always computed; the conditional swap then selects
        whether g or the sum is the value that gets halved.
        */
        MP_CHECKOK(mp_add(&g, &f, &temp));
        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
        s_mp_div_2(&g);
    }

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

    /* GCD is in f, take the absolute value. */
    SIGN(&f) = ZPOS;

    /* Add back the removed powers of 2. */
    MP_CHECKOK(s_mp_mul_2d(&f, shifts));

    MP_CHECKOK(mp_copy(&f, c));

CLEANUP:
    /* Release exactly the temporaries that were initialized. */
    while (last >= 0)
        mp_clear(clear[last--]);
    return res;
} /* end mp_gcd() */
1799 
1800 /* }}} */
1801 
1802 /* {{{ mp_lcm(a, b, c) */
1803 
1804 /* We compute the least common multiple using the rule:
1805 
1806    ab = [a, b](a, b)
1807 
1808    ... by computing the product, and dividing out the gcd.
1809  */
1810 
1811 mp_err
mp_lcm(mp_int * a,mp_int * b,mp_int * c)1812 mp_lcm(mp_int *a, mp_int *b, mp_int *c)
1813 {
1814     mp_int gcd, prod;
1815     mp_err res;
1816 
1817     ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
1818 
1819     /* Set up temporaries */
1820     if ((res = mp_init(&gcd)) != MP_OKAY)
1821         return res;
1822     if ((res = mp_init(&prod)) != MP_OKAY)
1823         goto GCD;
1824 
1825     if ((res = mp_mul(a, b, &prod)) != MP_OKAY)
1826         goto CLEANUP;
1827     if ((res = mp_gcd(a, b, &gcd)) != MP_OKAY)
1828         goto CLEANUP;
1829 
1830     res = mp_div(&prod, &gcd, c, NULL);
1831 
1832 CLEANUP:
1833     mp_clear(&prod);
1834 GCD:
1835     mp_clear(&gcd);
1836 
1837     return res;
1838 
1839 } /* end mp_lcm() */
1840 
1841 /* }}} */
1842 
1843 /* {{{ mp_xgcd(a, b, g, x, y) */
1844 
1845 /*
1846   mp_xgcd(a, b, g, x, y)
1847 
1848   Compute g = (a, b) and values x and y satisfying Bezout's identity
1849   (that is, ax + by = g).  This uses the binary extended GCD algorithm
1850   based on the Stein algorithm used for mp_gcd()
1851   See algorithm 14.61 in Handbook of Applied Cryptogrpahy.
1852  */
1853 
1854 mp_err
mp_xgcd(const mp_int * a,const mp_int * b,mp_int * g,mp_int * x,mp_int * y)1855 mp_xgcd(const mp_int *a, const mp_int *b, mp_int *g, mp_int *x, mp_int *y)
1856 {
1857     mp_int gx, xc, yc, u, v, A, B, C, D;
1858     mp_int *clean[9];
1859     mp_err res;
1860     int last = -1;
1861 
1862     if (mp_cmp_z(b) == 0)
1863         return MP_RANGE;
1864 
1865     /* Initialize all these variables we need */
1866     MP_CHECKOK(mp_init(&u));
1867     clean[++last] = &u;
1868     MP_CHECKOK(mp_init(&v));
1869     clean[++last] = &v;
1870     MP_CHECKOK(mp_init(&gx));
1871     clean[++last] = &gx;
1872     MP_CHECKOK(mp_init(&A));
1873     clean[++last] = &A;
1874     MP_CHECKOK(mp_init(&B));
1875     clean[++last] = &B;
1876     MP_CHECKOK(mp_init(&C));
1877     clean[++last] = &C;
1878     MP_CHECKOK(mp_init(&D));
1879     clean[++last] = &D;
1880     MP_CHECKOK(mp_init_copy(&xc, a));
1881     clean[++last] = &xc;
1882     mp_abs(&xc, &xc);
1883     MP_CHECKOK(mp_init_copy(&yc, b));
1884     clean[++last] = &yc;
1885     mp_abs(&yc, &yc);
1886 
1887     mp_set(&gx, 1);
1888 
1889     /* Divide by two until at least one of them is odd */
1890     while (mp_iseven(&xc) && mp_iseven(&yc)) {
1891         mp_size nx = mp_trailing_zeros(&xc);
1892         mp_size ny = mp_trailing_zeros(&yc);
1893         mp_size n = MP_MIN(nx, ny);
1894         s_mp_div_2d(&xc, n);
1895         s_mp_div_2d(&yc, n);
1896         MP_CHECKOK(s_mp_mul_2d(&gx, n));
1897     }
1898 
1899     MP_CHECKOK(mp_copy(&xc, &u));
1900     MP_CHECKOK(mp_copy(&yc, &v));
1901     mp_set(&A, 1);
1902     mp_set(&D, 1);
1903 
1904     /* Loop through binary GCD algorithm */
1905     do {
1906         while (mp_iseven(&u)) {
1907             s_mp_div_2(&u);
1908 
1909             if (mp_iseven(&A) && mp_iseven(&B)) {
1910                 s_mp_div_2(&A);
1911                 s_mp_div_2(&B);
1912             } else {
1913                 MP_CHECKOK(mp_add(&A, &yc, &A));
1914                 s_mp_div_2(&A);
1915                 MP_CHECKOK(mp_sub(&B, &xc, &B));
1916                 s_mp_div_2(&B);
1917             }
1918         }
1919 
1920         while (mp_iseven(&v)) {
1921             s_mp_div_2(&v);
1922 
1923             if (mp_iseven(&C) && mp_iseven(&D)) {
1924                 s_mp_div_2(&C);
1925                 s_mp_div_2(&D);
1926             } else {
1927                 MP_CHECKOK(mp_add(&C, &yc, &C));
1928                 s_mp_div_2(&C);
1929                 MP_CHECKOK(mp_sub(&D, &xc, &D));
1930                 s_mp_div_2(&D);
1931             }
1932         }
1933 
1934         if (mp_cmp(&u, &v) >= 0) {
1935             MP_CHECKOK(mp_sub(&u, &v, &u));
1936             MP_CHECKOK(mp_sub(&A, &C, &A));
1937             MP_CHECKOK(mp_sub(&B, &D, &B));
1938         } else {
1939             MP_CHECKOK(mp_sub(&v, &u, &v));
1940             MP_CHECKOK(mp_sub(&C, &A, &C));
1941             MP_CHECKOK(mp_sub(&D, &B, &D));
1942         }
1943     } while (mp_cmp_z(&u) != 0);
1944 
1945     /* copy results to output */
1946     if (x)
1947         MP_CHECKOK(mp_copy(&C, x));
1948 
1949     if (y)
1950         MP_CHECKOK(mp_copy(&D, y));
1951 
1952     if (g)
1953         MP_CHECKOK(mp_mul(&gx, &v, g));
1954 
1955 CLEANUP:
1956     while (last >= 0)
1957         mp_clear(clean[last--]);
1958 
1959     return res;
1960 
1961 } /* end mp_xgcd() */
1962 
1963 /* }}} */
1964 
1965 mp_size
mp_trailing_zeros(const mp_int * mp)1966 mp_trailing_zeros(const mp_int *mp)
1967 {
1968     mp_digit d;
1969     mp_size n = 0;
1970     unsigned int ix;
1971 
1972     if (!mp || !MP_DIGITS(mp) || !mp_cmp_z(mp))
1973         return n;
1974 
1975     for (ix = 0; !(d = MP_DIGIT(mp, ix)) && (ix < MP_USED(mp)); ++ix)
1976         n += MP_DIGIT_BIT;
1977     if (!d)
1978         return 0; /* shouldn't happen, but ... */
1979 #if !defined(MP_USE_UINT_DIGIT)
1980     if (!(d & 0xffffffffU)) {
1981         d >>= 32;
1982         n += 32;
1983     }
1984 #endif
1985     if (!(d & 0xffffU)) {
1986         d >>= 16;
1987         n += 16;
1988     }
1989     if (!(d & 0xffU)) {
1990         d >>= 8;
1991         n += 8;
1992     }
1993     if (!(d & 0xfU)) {
1994         d >>= 4;
1995         n += 4;
1996     }
1997     if (!(d & 0x3U)) {
1998         d >>= 2;
1999         n += 2;
2000     }
2001     if (!(d & 0x1U)) {
2002         d >>= 1;
2003         n += 1;
2004     }
2005 #if MP_ARGCHK == 2
2006     assert(0 != (d & 1));
2007 #endif
2008     return n;
2009 }
2010 
/* Given a and prime p, computes c and k such that a*c == 2**k (mod p).
** Returns k (non-negative) or error (negative).
** This technique from the paper "Fast Modular Reciprocals" (unpublished)
** by Richard Schroeppel (a.k.a. Captain Nemo).
**
** Returns MP_UNDEF when a is zero or when the loop finds that a and p
** are not relatively prime.  On success c is left non-negative and
** smaller in magnitude than p.
*/
mp_err
s_mp_almost_inverse(const mp_int *a, const mp_int *p, mp_int *c)
{
    mp_err res;
    mp_err k = 0; /* running count of factors of two removed from f */
    mp_int d, f, g;

    ARGCHK(a != NULL && p != NULL && c != NULL, MP_BADARG);

    /*
     Digit pointers are zeroed before anything can fail because CLEANUP
     clears all three temporaries unconditionally.
     NOTE(review): presumably mp_clear() ignores a value whose digits
     pointer is 0 -- confirm.
     */
    MP_DIGITS(&d) = 0;
    MP_DIGITS(&f) = 0;
    MP_DIGITS(&g) = 0;
    MP_CHECKOK(mp_init(&d));
    MP_CHECKOK(mp_init_copy(&f, a)); /* f = a */
    MP_CHECKOK(mp_init_copy(&g, p)); /* g = p */

    /* (c, d) are the coefficients that accompany (f, g). */
    mp_set(c, 1);
    mp_zero(&d);

    if (mp_cmp_z(&f) == 0) {
        res = MP_UNDEF;
    } else
        for (;;) {
            int diff_sign;
            /* Strip factors of two from f, scaling d to match. */
            while (mp_iseven(&f)) {
                mp_size n = mp_trailing_zeros(&f);
                if (!n) {
                    /* f is even yet has no trailing zeros, i.e. f == 0. */
                    res = MP_UNDEF;
                    goto CLEANUP;
                }
                s_mp_div_2d(&f, n);
                MP_CHECKOK(s_mp_mul_2d(&d, n));
                k += n;
            }
            if (mp_cmp_d(&f, 1) == MP_EQ) { /* f == 1 */
                res = k;
                break;
            }
            diff_sign = mp_cmp(&f, &g);
            if (diff_sign < 0) { /* f < g */
                /* Keep the larger value in f; coefficients follow. */
                s_mp_exch(&f, &g);
                s_mp_exch(c, &d);
            } else if (diff_sign == 0) { /* f == g */
                res = MP_UNDEF;          /* a and p are not relatively prime */
                break;
            }
            /*
             Both values are odd here.  Subtract when they agree mod 4
             and add otherwise, so the new f is divisible by 4.
             */
            if ((MP_DIGIT(&f, 0) % 4) == (MP_DIGIT(&g, 0) % 4)) {
                MP_CHECKOK(mp_sub(&f, &g, &f)); /* f = f - g */
                MP_CHECKOK(mp_sub(c, &d, c));   /* c = c - d */
            } else {
                MP_CHECKOK(mp_add(&f, &g, &f)); /* f = f + g */
                MP_CHECKOK(mp_add(c, &d, c));   /* c = c + d */
            }
        }
    if (res >= 0) {
        /* Bring c into range: take the remainder, then fix its sign. */
        if (mp_cmp_mag(c, p) >= 0) {
            MP_CHECKOK(mp_div(c, p, NULL, c));
        }
        if (MP_SIGN(c) != MP_ZPOS) {
            MP_CHECKOK(mp_add(c, p, c));
        }
        /* The calls above overwrote res; restore the exponent. */
        res = k;
    }

CLEANUP:
    mp_clear(&d);
    mp_clear(&f);
    mp_clear(&g);
    return res;
}
2086 
2087 /* Compute T = (P ** -1) mod MP_RADIX.  Also works for 16-bit mp_digits.
2088 ** This technique from the paper "Fast Modular Reciprocals" (unpublished)
2089 ** by Richard Schroeppel (a.k.a. Captain Nemo).
2090 */
2091 mp_digit
s_mp_invmod_radix(mp_digit P)2092 s_mp_invmod_radix(mp_digit P)
2093 {
2094     mp_digit T = P;
2095     T *= 2 - (P * T);
2096     T *= 2 - (P * T);
2097     T *= 2 - (P * T);
2098     T *= 2 - (P * T);
2099 #if !defined(MP_USE_UINT_DIGIT)
2100     T *= 2 - (P * T);
2101     T *= 2 - (P * T);
2102 #endif
2103     return T;
2104 }
2105 
2106 /* Given c, k, and prime p, where a*c == 2**k (mod p),
2107 ** Compute x = (a ** -1) mod p.  This is similar to Montgomery reduction.
2108 ** This technique from the paper "Fast Modular Reciprocals" (unpublished)
2109 ** by Richard Schroeppel (a.k.a. Captain Nemo).
2110 */
2111 mp_err
s_mp_fixup_reciprocal(const mp_int * c,const mp_int * p,int k,mp_int * x)2112 s_mp_fixup_reciprocal(const mp_int *c, const mp_int *p, int k, mp_int *x)
2113 {
2114     int k_orig = k;
2115     mp_digit r;
2116     mp_size ix;
2117     mp_err res;
2118 
2119     if (mp_cmp_z(c) < 0) {           /* c < 0 */
2120         MP_CHECKOK(mp_add(c, p, x)); /* x = c + p */
2121     } else {
2122         MP_CHECKOK(mp_copy(c, x)); /* x = c */
2123     }
2124 
2125     /* make sure x is large enough */
2126     ix = MP_HOWMANY(k, MP_DIGIT_BIT) + MP_USED(p) + 1;
2127     ix = MP_MAX(ix, MP_USED(x));
2128     MP_CHECKOK(s_mp_pad(x, ix));
2129 
2130     r = 0 - s_mp_invmod_radix(MP_DIGIT(p, 0));
2131 
2132     for (ix = 0; k > 0; ix++) {
2133         int j = MP_MIN(k, MP_DIGIT_BIT);
2134         mp_digit v = r * MP_DIGIT(x, ix);
2135         if (j < MP_DIGIT_BIT) {
2136             v &= ((mp_digit)1 << j) - 1; /* v = v mod (2 ** j) */
2137         }
2138         s_mp_mul_d_add_offset(p, v, x, ix); /* x += p * v * (RADIX ** ix) */
2139         k -= j;
2140     }
2141     s_mp_clamp(x);
2142     s_mp_div_2d(x, k_orig);
2143     res = MP_OKAY;
2144 
2145 CLEANUP:
2146     return res;
2147 }
2148 
/*
  Computes the modular inverse using the constant-time algorithm
  by Bernstein and Yang (https://eprint.iacr.org/2019/266)
  "Fast constant-time gcd computation and modular inversion"

  s_mp_invmod_odd_m(a, m, c) sets c = a**-1 mod m for an odd modulus m.

  Returns MP_RANGE if a is zero or m < 2, MP_UNDEF if a and m are the
  same object, m is even, or the gcd found is not 1.  Otherwise returns
  the status of the first helper that fails, or of the final mp_mod().
 */
mp_err
s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c)
{
    mp_err res;
    mp_digit cond = 0;
    mp_int g, f, v, r, temp;
    int i, its, delta = 1, last = -1;
    mp_size top, flen, glen;
    /* Temporaries initialized so far; released in reverse at CLEANUP. */
    mp_int *clear[6];

    ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
    /* Check for invalid inputs. */
    if (mp_cmp_z(a) == MP_EQ || mp_cmp_d(m, 2) == MP_LT)
        return MP_RANGE;

    if (a == m || mp_iseven(m))
        return MP_UNDEF;

    /* Each temporary is recorded in clear[] only once it is initialized. */
    MP_CHECKOK(mp_init(&temp));
    clear[++last] = &temp;
    MP_CHECKOK(mp_init(&v));
    clear[++last] = &v;
    MP_CHECKOK(mp_init(&r));
    clear[++last] = &r;
    MP_CHECKOK(mp_init_copy(&g, a));
    clear[++last] = &g;
    MP_CHECKOK(mp_init_copy(&f, m));
    clear[++last] = &f;

    /* (v, r) are the coefficients that accompany (f, g). */
    mp_set(&v, 0);
    mp_set(&r, 1);

    /* Allocate to the size of largest mp_int (plus one spare digit). */
    top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
    MP_CHECKOK(s_mp_grow(&f, top));
    MP_CHECKOK(s_mp_grow(&g, top));
    MP_CHECKOK(s_mp_grow(&temp, top));
    MP_CHECKOK(s_mp_grow(&v, top));
    MP_CHECKOK(s_mp_grow(&r, top));

    /* Upper bound for the total iterations. */
    flen = mpl_significant_bits(&f);
    glen = mpl_significant_bits(&g);
    its = 4 + 3 * ((flen >= glen) ? flen : glen);

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know that we're negating an unsigned mp_digit
#endif

    for (i = 0; i < its; i++) {
        /* Step 1: conditional swap. */
        /* Set cond if delta > 0 and g is odd. */
        cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
        /* If cond is set replace (delta,f,v) with (-delta,-f,-v). */
        delta = (-cond & -delta) | ((cond - 1) & delta);
        SIGN(&f) ^= cond;
        SIGN(&v) ^= cond;
        /* If cond is set swap (f,v) with (g,r). */
        MP_CHECKOK(mp_cswap(cond, &f, &g, top));
        MP_CHECKOK(mp_cswap(cond, &v, &r, top));

        /* Step 2: elimination. */
        /* Update delta */
        delta++;
        /*
        If g is odd replace r with (r+v).
        The sum is always computed; the conditional swap selects it.
        */
        MP_CHECKOK(mp_add(&r, &v, &temp));
        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &r, &temp, top));
        /* If g is odd, right shift (g+f) else right shift g. */
        MP_CHECKOK(mp_add(&g, &f, &temp));
        MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
        s_mp_div_2(&g);
        /*
        If r is even, right shift it.
        If r is odd, right shift (r+m) which is even because m is odd.
        We want the result modulo m, so added multiples of m vanish.
        */
        MP_CHECKOK(mp_add(&r, m, &temp));
        MP_CHECKOK(mp_cswap((DIGIT(&r, 0) & 1), &r, &temp, top));
        s_mp_div_2(&r);
    }

#if defined(_MSC_VER)
#pragma warning(pop)
#endif

    /* We have the inverse in v, propagate sign from f. */
    SIGN(&v) ^= SIGN(&f);
    /* GCD is in f, take the absolute value. */
    SIGN(&f) = ZPOS;

    /* If gcd != 1, not invertible. */
    if (mp_cmp_d(&f, 1) != MP_EQ) {
        res = MP_UNDEF;
        goto CLEANUP;
    }

    /* Return inverse modulo m. */
    MP_CHECKOK(mp_mod(&v, m, c));

CLEANUP:
    /* Release exactly the temporaries that were initialized. */
    while (last >= 0)
        mp_clear(clear[last--]);
    return res;
}
2259 
2260 /* Known good algorithm for computing modular inverse.  But slow. */
2261 mp_err
mp_invmod_xgcd(const mp_int * a,const mp_int * m,mp_int * c)2262 mp_invmod_xgcd(const mp_int *a, const mp_int *m, mp_int *c)
2263 {
2264     mp_int g, x;
2265     mp_err res;
2266 
2267     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
2268 
2269     if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
2270         return MP_RANGE;
2271 
2272     MP_DIGITS(&g) = 0;
2273     MP_DIGITS(&x) = 0;
2274     MP_CHECKOK(mp_init(&x));
2275     MP_CHECKOK(mp_init(&g));
2276 
2277     MP_CHECKOK(mp_xgcd(a, m, &g, &x, NULL));
2278 
2279     if (mp_cmp_d(&g, 1) != MP_EQ) {
2280         res = MP_UNDEF;
2281         goto CLEANUP;
2282     }
2283 
2284     res = mp_mod(&x, m, c);
2285     SIGN(c) = SIGN(a);
2286 
2287 CLEANUP:
2288     mp_clear(&x);
2289     mp_clear(&g);
2290 
2291     return res;
2292 }
2293 
/* modular inverse where modulus is 2**k. */
/* c = a**-1 mod 2**k */
/*
** Returns MP_UNDEF when a is even, or when the iteration below has not
** settled after k + 4 rounds.  When k fits in one digit the answer is
** obtained directly from s_mp_invmod_radix().
*/
mp_err
s_mp_invmod_2d(const mp_int *a, mp_size k, mp_int *c)
{
    mp_err res;
    mp_size ix = k + 4; /* iteration budget for the loop below */
    mp_int t0, t1, val, tmp, two2k;

    /* A constant mp_int holding the value 2, backed by static storage. */
    static const mp_digit d2 = 2;
    static const mp_int two = { MP_ZPOS, 1, 1, (mp_digit *)&d2 };

    if (mp_iseven(a))
        return MP_UNDEF;

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know that we're negating an unsigned mp_digit
#endif
    if (k <= MP_DIGIT_BIT) {
        /* Single-digit case: invert the low digit of a modulo the radix. */
        mp_digit i = s_mp_invmod_radix(MP_DIGIT(a, 0));
        /*
         propagate the sign from mp_int: when SIGN(a) is 1 this is
         (~i) + 1, i.e. the negation of i; when it is 0, i is unchanged.
         */
        i = (i ^ -(mp_digit)SIGN(a)) + (mp_digit)SIGN(a);
        /* Keep only the low k bits; for k == MP_DIGIT_BIT nothing is cut. */
        if (k < MP_DIGIT_BIT)
            i &= ((mp_digit)1 << k) - (mp_digit)1;
        mp_set(c, i);
        return MP_OKAY;
    }
#if defined(_MSC_VER)
#pragma warning(pop)
#endif

    /*
     Zero the digit pointers first: CLEANUP clears all five temporaries
     no matter how far initialization got.
     */
    MP_DIGITS(&t0) = 0;
    MP_DIGITS(&t1) = 0;
    MP_DIGITS(&val) = 0;
    MP_DIGITS(&tmp) = 0;
    MP_DIGITS(&two2k) = 0;
    MP_CHECKOK(mp_init_copy(&val, a));
    s_mp_mod_2d(&val, k);
    MP_CHECKOK(mp_init_copy(&t0, &val));
    MP_CHECKOK(mp_init_copy(&t1, &t0));
    MP_CHECKOK(mp_init(&tmp));
    MP_CHECKOK(mp_init(&two2k));
    MP_CHECKOK(s_mp_2expt(&two2k, k));
    /*
     Iterate t1 = t1 * (2 - val * t1) mod 2**k until two successive
     values agree (t0 holds the previous one) or the budget runs out.
     */
    do {
        MP_CHECKOK(mp_mul(&val, &t1, &tmp));
        MP_CHECKOK(mp_sub(&two, &tmp, &tmp));
        MP_CHECKOK(mp_mul(&t1, &tmp, &t1));
        s_mp_mod_2d(&t1, k);
        /* Bring a negative t1 back into the non-negative range. */
        while (MP_SIGN(&t1) != MP_ZPOS) {
            MP_CHECKOK(mp_add(&t1, &two2k, &t1));
        }
        if (mp_cmp(&t1, &t0) == MP_EQ)
            break;
        MP_CHECKOK(mp_copy(&t1, &t0));
    } while (--ix > 0);
    if (!ix) {
        /* Budget exhausted without reaching a fixed point. */
        res = MP_UNDEF;
    } else {
        /* res still holds the status of the last successful call. */
        mp_exch(c, &t1);
    }

CLEANUP:
    mp_clear(&t0);
    mp_clear(&t1);
    mp_clear(&val);
    mp_clear(&tmp);
    mp_clear(&two2k);
    return res;
}
2364 
2365 mp_err
s_mp_invmod_even_m(const mp_int * a,const mp_int * m,mp_int * c)2366 s_mp_invmod_even_m(const mp_int *a, const mp_int *m, mp_int *c)
2367 {
2368     mp_err res;
2369     mp_size k;
2370     mp_int oddFactor, evenFactor; /* factors of the modulus */
2371     mp_int oddPart, evenPart;     /* parts to combine via CRT. */
2372     mp_int C2, tmp1, tmp2;
2373 
2374     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
2375 
2376     /*static const mp_digit d1 = 1; */
2377     /*static const mp_int one = { MP_ZPOS, 1, 1, (mp_digit *)&d1 }; */
2378 
2379     if ((res = s_mp_ispow2(m)) >= 0) {
2380         k = res;
2381         return s_mp_invmod_2d(a, k, c);
2382     }
2383     MP_DIGITS(&oddFactor) = 0;
2384     MP_DIGITS(&evenFactor) = 0;
2385     MP_DIGITS(&oddPart) = 0;
2386     MP_DIGITS(&evenPart) = 0;
2387     MP_DIGITS(&C2) = 0;
2388     MP_DIGITS(&tmp1) = 0;
2389     MP_DIGITS(&tmp2) = 0;
2390 
2391     MP_CHECKOK(mp_init_copy(&oddFactor, m)); /* oddFactor = m */
2392     MP_CHECKOK(mp_init(&evenFactor));
2393     MP_CHECKOK(mp_init(&oddPart));
2394     MP_CHECKOK(mp_init(&evenPart));
2395     MP_CHECKOK(mp_init(&C2));
2396     MP_CHECKOK(mp_init(&tmp1));
2397     MP_CHECKOK(mp_init(&tmp2));
2398 
2399     k = mp_trailing_zeros(m);
2400     s_mp_div_2d(&oddFactor, k);
2401     MP_CHECKOK(s_mp_2expt(&evenFactor, k));
2402 
2403     /* compute a**-1 mod oddFactor. */
2404     MP_CHECKOK(s_mp_invmod_odd_m(a, &oddFactor, &oddPart));
2405     /* compute a**-1 mod evenFactor, where evenFactor == 2**k. */
2406     MP_CHECKOK(s_mp_invmod_2d(a, k, &evenPart));
2407 
2408     /* Use Chinese Remainer theorem to compute a**-1 mod m. */
2409     /* let m1 = oddFactor,  v1 = oddPart,
2410      * let m2 = evenFactor, v2 = evenPart.
2411      */
2412 
2413     /* Compute C2 = m1**-1 mod m2. */
2414     MP_CHECKOK(s_mp_invmod_2d(&oddFactor, k, &C2));
2415 
2416     /* compute u = (v2 - v1)*C2 mod m2 */
2417     MP_CHECKOK(mp_sub(&evenPart, &oddPart, &tmp1));
2418     MP_CHECKOK(mp_mul(&tmp1, &C2, &tmp2));
2419     s_mp_mod_2d(&tmp2, k);
2420     while (MP_SIGN(&tmp2) != MP_ZPOS) {
2421         MP_CHECKOK(mp_add(&tmp2, &evenFactor, &tmp2));
2422     }
2423 
2424     /* compute answer = v1 + u*m1 */
2425     MP_CHECKOK(mp_mul(&tmp2, &oddFactor, c));
2426     MP_CHECKOK(mp_add(&oddPart, c, c));
2427     /* not sure this is necessary, but it's low cost if not. */
2428     MP_CHECKOK(mp_mod(c, m, c));
2429 
2430 CLEANUP:
2431     mp_clear(&oddFactor);
2432     mp_clear(&evenFactor);
2433     mp_clear(&oddPart);
2434     mp_clear(&evenPart);
2435     mp_clear(&C2);
2436     mp_clear(&tmp1);
2437     mp_clear(&tmp2);
2438     return res;
2439 }
2440 
2441 /* {{{ mp_invmod(a, m, c) */
2442 
2443 /*
2444   mp_invmod(a, m, c)
2445 
2446   Compute c = a^-1 (mod m), if there is an inverse for a (mod m).
2447   This is equivalent to the question of whether (a, m) = 1.  If not,
2448   MP_UNDEF is returned, and there is no inverse.
2449  */
2450 
2451 mp_err
mp_invmod(const mp_int * a,const mp_int * m,mp_int * c)2452 mp_invmod(const mp_int *a, const mp_int *m, mp_int *c)
2453 {
2454     ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
2455 
2456     if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
2457         return MP_RANGE;
2458 
2459     if (mp_isodd(m)) {
2460         return s_mp_invmod_odd_m(a, m, c);
2461     }
2462     if (mp_iseven(a))
2463         return MP_UNDEF; /* not invertable */
2464 
2465     return s_mp_invmod_even_m(a, m, c);
2466 
2467 } /* end mp_invmod() */
2468 
2469 /* }}} */
2470 
2471 /* }}} */
2472 
2473 /*------------------------------------------------------------------------*/
2474 /* {{{ mp_print(mp, ofp) */
2475 
2476 #if MP_IOFUNC
2477 /*
2478   mp_print(mp, ofp)
2479 
2480   Print a textual representation of the given mp_int on the output
2481   stream 'ofp'.  Output is generated using the internal radix.
2482  */
2483 
2484 void
mp_print(mp_int * mp,FILE * ofp)2485 mp_print(mp_int *mp, FILE *ofp)
2486 {
2487     int ix;
2488 
2489     if (mp == NULL || ofp == NULL)
2490         return;
2491 
2492     fputc((SIGN(mp) == NEG) ? '-' : '+', ofp);
2493 
2494     for (ix = USED(mp) - 1; ix >= 0; ix--) {
2495         fprintf(ofp, DIGIT_FMT, DIGIT(mp, ix));
2496     }
2497 
2498 } /* end mp_print() */
2499 
2500 #endif /* if MP_IOFUNC */
2501 
2502 /* }}} */
2503 
2504 /*------------------------------------------------------------------------*/
2505 /* {{{ More I/O Functions */
2506 
2507 /* {{{ mp_read_raw(mp, str, len) */
2508 
2509 /*
2510    mp_read_raw(mp, str, len)
2511 
2512    Read in a raw value (base 256) into the given mp_int
2513  */
2514 
2515 mp_err
mp_read_raw(mp_int * mp,char * str,int len)2516 mp_read_raw(mp_int *mp, char *str, int len)
2517 {
2518     int ix;
2519     mp_err res;
2520     unsigned char *ustr = (unsigned char *)str;
2521 
2522     ARGCHK(mp != NULL && str != NULL && len > 0, MP_BADARG);
2523 
2524     mp_zero(mp);
2525 
2526     /* Get sign from first byte */
2527     if (ustr[0])
2528         SIGN(mp) = NEG;
2529     else
2530         SIGN(mp) = ZPOS;
2531 
2532     /* Read the rest of the digits */
2533     for (ix = 1; ix < len; ix++) {
2534         if ((res = mp_mul_d(mp, 256, mp)) != MP_OKAY)
2535             return res;
2536         if ((res = mp_add_d(mp, ustr[ix], mp)) != MP_OKAY)
2537             return res;
2538     }
2539 
2540     return MP_OKAY;
2541 
2542 } /* end mp_read_raw() */
2543 
2544 /* }}} */
2545 
2546 /* {{{ mp_raw_size(mp) */
2547 
2548 int
mp_raw_size(mp_int * mp)2549 mp_raw_size(mp_int *mp)
2550 {
2551     ARGCHK(mp != NULL, 0);
2552 
2553     return (USED(mp) * sizeof(mp_digit)) + 1;
2554 
2555 } /* end mp_raw_size() */
2556 
2557 /* }}} */
2558 
2559 /* {{{ mp_toraw(mp, str) */
2560 
2561 mp_err
mp_toraw(mp_int * mp,char * str)2562 mp_toraw(mp_int *mp, char *str)
2563 {
2564     int ix, jx, pos = 1;
2565 
2566     ARGCHK(mp != NULL && str != NULL, MP_BADARG);
2567 
2568     str[0] = (char)SIGN(mp);
2569 
2570     /* Iterate over each digit... */
2571     for (ix = USED(mp) - 1; ix >= 0; ix--) {
2572         mp_digit d = DIGIT(mp, ix);
2573 
2574         /* Unpack digit bytes, high order first */
2575         for (jx = sizeof(mp_digit) - 1; jx >= 0; jx--) {
2576             str[pos++] = (char)(d >> (jx * CHAR_BIT));
2577         }
2578     }
2579 
2580     return MP_OKAY;
2581 
2582 } /* end mp_toraw() */
2583 
2584 /* }}} */
2585 
2586 /* {{{ mp_read_radix(mp, str, radix) */
2587 
2588 /*
2589   mp_read_radix(mp, str, radix)
2590 
  Read an integer from the given string, and set mp to the resulting
  value.  The input is interpreted in the given radix, which must lie
  between 2 and MAX_RADIX.  Leading characters that are neither digits
  in that radix nor a sign are skipped; an optional '-' or '+' may
  precede the digits, and the function reads until a non-digit
  character or the end of the string.
2595  */
2596 
2597 mp_err
mp_read_radix(mp_int * mp,const char * str,int radix)2598 mp_read_radix(mp_int *mp, const char *str, int radix)
2599 {
2600     int ix = 0, val = 0;
2601     mp_err res;
2602     mp_sign sig = ZPOS;
2603 
2604     ARGCHK(mp != NULL && str != NULL && radix >= 2 && radix <= MAX_RADIX,
2605            MP_BADARG);
2606 
2607     mp_zero(mp);
2608 
2609     /* Skip leading non-digit characters until a digit or '-' or '+' */
2610     while (str[ix] &&
2611            (s_mp_tovalue(str[ix], radix) < 0) &&
2612            str[ix] != '-' &&
2613            str[ix] != '+') {
2614         ++ix;
2615     }
2616 
2617     if (str[ix] == '-') {
2618         sig = NEG;
2619         ++ix;
2620     } else if (str[ix] == '+') {
2621         sig = ZPOS; /* this is the default anyway... */
2622         ++ix;
2623     }
2624 
2625     while ((val = s_mp_tovalue(str[ix], radix)) >= 0) {
2626         if ((res = s_mp_mul_d(mp, radix)) != MP_OKAY)
2627             return res;
2628         if ((res = s_mp_add_d(mp, val)) != MP_OKAY)
2629             return res;
2630         ++ix;
2631     }
2632 
2633     if (s_mp_cmp_d(mp, 0) == MP_EQ)
2634         SIGN(mp) = ZPOS;
2635     else
2636         SIGN(mp) = sig;
2637 
2638     return MP_OKAY;
2639 
2640 } /* end mp_read_radix() */
2641 
2642 mp_err
mp_read_variable_radix(mp_int * a,const char * str,int default_radix)2643 mp_read_variable_radix(mp_int *a, const char *str, int default_radix)
2644 {
2645     int radix = default_radix;
2646     int cx;
2647     mp_sign sig = ZPOS;
2648     mp_err res;
2649 
2650     /* Skip leading non-digit characters until a digit or '-' or '+' */
2651     while ((cx = *str) != 0 &&
2652            (s_mp_tovalue(cx, radix) < 0) &&
2653            cx != '-' &&
2654            cx != '+') {
2655         ++str;
2656     }
2657 
2658     if (cx == '-') {
2659         sig = NEG;
2660         ++str;
2661     } else if (cx == '+') {
2662         sig = ZPOS; /* this is the default anyway... */
2663         ++str;
2664     }
2665 
2666     if (str[0] == '0') {
2667         if ((str[1] | 0x20) == 'x') {
2668             radix = 16;
2669             str += 2;
2670         } else {
2671             radix = 8;
2672             str++;
2673         }
2674     }
2675     res = mp_read_radix(a, str, radix);
2676     if (res == MP_OKAY) {
2677         MP_SIGN(a) = (s_mp_cmp_d(a, 0) == MP_EQ) ? ZPOS : sig;
2678     }
2679     return res;
2680 }
2681 
2682 /* }}} */
2683 
2684 /* {{{ mp_radix_size(mp, radix) */
2685 
2686 int
mp_radix_size(mp_int * mp,int radix)2687 mp_radix_size(mp_int *mp, int radix)
2688 {
2689     int bits;
2690 
2691     if (!mp || radix < 2 || radix > MAX_RADIX)
2692         return 0;
2693 
2694     bits = USED(mp) * DIGIT_BIT - 1;
2695 
2696     return SIGN(mp) + s_mp_outlen(bits, radix);
2697 
2698 } /* end mp_radix_size() */
2699 
2700 /* }}} */
2701 
2702 /* {{{ mp_toradix(mp, str, radix) */
2703 
2704 mp_err
mp_toradix(mp_int * mp,char * str,int radix)2705 mp_toradix(mp_int *mp, char *str, int radix)
2706 {
2707     int ix, pos = 0;
2708 
2709     ARGCHK(mp != NULL && str != NULL, MP_BADARG);
2710     ARGCHK(radix > 1 && radix <= MAX_RADIX, MP_RANGE);
2711 
2712     if (mp_cmp_z(mp) == MP_EQ) {
2713         str[0] = '0';
2714         str[1] = '\0';
2715     } else {
2716         mp_err res;
2717         mp_int tmp;
2718         mp_sign sgn;
2719         mp_digit rem, rdx = (mp_digit)radix;
2720         char ch;
2721 
2722         if ((res = mp_init_copy(&tmp, mp)) != MP_OKAY)
2723             return res;
2724 
2725         /* Save sign for later, and take absolute value */
2726         sgn = SIGN(&tmp);
2727         SIGN(&tmp) = ZPOS;
2728 
2729         /* Generate output digits in reverse order      */
2730         while (mp_cmp_z(&tmp) != 0) {
2731             if ((res = mp_div_d(&tmp, rdx, &tmp, &rem)) != MP_OKAY) {
2732                 mp_clear(&tmp);
2733                 return res;
2734             }
2735 
2736             /* Generate digits, use capital letters */
2737             ch = s_mp_todigit(rem, radix, 0);
2738 
2739             str[pos++] = ch;
2740         }
2741 
2742         /* Add - sign if original value was negative */
2743         if (sgn == NEG)
2744             str[pos++] = '-';
2745 
2746         /* Add trailing NUL to end the string        */
2747         str[pos--] = '\0';
2748 
2749         /* Reverse the digits and sign indicator     */
2750         ix = 0;
2751         while (ix < pos) {
2752             char tmpc = str[ix];
2753 
2754             str[ix] = str[pos];
2755             str[pos] = tmpc;
2756             ++ix;
2757             --pos;
2758         }
2759 
2760         mp_clear(&tmp);
2761     }
2762 
2763     return MP_OKAY;
2764 
2765 } /* end mp_toradix() */
2766 
2767 /* }}} */
2768 
2769 /* {{{ mp_tovalue(ch, r) */
2770 
/* Public wrapper: forwards to s_mp_tovalue(ch, r) and returns its result. */
int
mp_tovalue(char ch, int r)
{
    const int value = s_mp_tovalue(ch, r);

    return value;

} /* end mp_tovalue() */
2777 
2778 /* }}} */
2779 
2780 /* }}} */
2781 
2782 /* {{{ mp_strerror(ec) */
2783 
2784 /*
2785   mp_strerror(ec)
2786 
2787   Return a string describing the meaning of error code 'ec'.  The
2788   string returned is allocated in static memory, so the caller should
2789   not attempt to modify or free the memory associated with this
2790   string.
2791  */
2792 const char *
mp_strerror(mp_err ec)2793 mp_strerror(mp_err ec)
2794 {
2795     int aec = (ec < 0) ? -ec : ec;
2796 
2797     /* Code values are negative, so the senses of these comparisons
2798      are accurate */
2799     if (ec < MP_LAST_CODE || ec > MP_OKAY) {
2800         return mp_err_string[0]; /* unknown error code */
2801     } else {
2802         return mp_err_string[aec + 1];
2803     }
2804 
2805 } /* end mp_strerror() */
2806 
2807 /* }}} */
2808 
2809 /*========================================================================*/
2810 /*------------------------------------------------------------------------*/
2811 /* Static function definitions (internal use only)                        */
2812 
2813 /* {{{ Memory management */
2814 
2815 /* {{{ s_mp_grow(mp, min) */
2816 
2817 /* Make sure there are at least 'min' digits allocated to mp              */
2818 mp_err
s_mp_grow(mp_int * mp,mp_size min)2819 s_mp_grow(mp_int *mp, mp_size min)
2820 {
2821     ARGCHK(mp != NULL, MP_BADARG);
2822 
2823     if (min > ALLOC(mp)) {
2824         mp_digit *tmp;
2825 
2826         /* Set min to next nearest default precision block size */
2827         min = MP_ROUNDUP(min, s_mp_defprec);
2828 
2829         if ((tmp = s_mp_alloc(min, sizeof(mp_digit))) == NULL)
2830             return MP_MEM;
2831 
2832         s_mp_copy(DIGITS(mp), tmp, USED(mp));
2833 
2834         s_mp_setz(DIGITS(mp), ALLOC(mp));
2835         s_mp_free(DIGITS(mp));
2836         DIGITS(mp) = tmp;
2837         ALLOC(mp) = min;
2838     }
2839 
2840     return MP_OKAY;
2841 
2842 } /* end s_mp_grow() */
2843 
2844 /* }}} */
2845 
2846 /* {{{ s_mp_pad(mp, min) */
2847 
2848 /* Make sure the used size of mp is at least 'min', growing if needed     */
2849 mp_err
s_mp_pad(mp_int * mp,mp_size min)2850 s_mp_pad(mp_int *mp, mp_size min)
2851 {
2852     ARGCHK(mp != NULL, MP_BADARG);
2853 
2854     if (min > USED(mp)) {
2855         mp_err res;
2856 
2857         /* Make sure there is room to increase precision  */
2858         if (min > ALLOC(mp)) {
2859             if ((res = s_mp_grow(mp, min)) != MP_OKAY)
2860                 return res;
2861         } else {
2862             s_mp_setz(DIGITS(mp) + USED(mp), min - USED(mp));
2863         }
2864 
2865         /* Increase precision; should already be 0-filled */
2866         USED(mp) = min;
2867     }
2868 
2869     return MP_OKAY;
2870 
2871 } /* end s_mp_pad() */
2872 
2873 /* }}} */
2874 
2875 /* {{{ s_mp_setz(dp, count) */
2876 
2877 /* Set 'count' digits pointed to by dp to be zeroes                       */
2878 void
s_mp_setz(mp_digit * dp,mp_size count)2879 s_mp_setz(mp_digit *dp, mp_size count)
2880 {
2881     memset(dp, 0, count * sizeof(mp_digit));
2882 } /* end s_mp_setz() */
2883 
2884 /* }}} */
2885 
2886 /* {{{ s_mp_copy(sp, dp, count) */
2887 
2888 /* Copy 'count' digits from sp to dp                                      */
2889 void
s_mp_copy(const mp_digit * sp,mp_digit * dp,mp_size count)2890 s_mp_copy(const mp_digit *sp, mp_digit *dp, mp_size count)
2891 {
2892     memcpy(dp, sp, count * sizeof(mp_digit));
2893 } /* end s_mp_copy() */
2894 
2895 /* }}} */
2896 
2897 /* {{{ s_mp_alloc(nb, ni) */
2898 
/* Thin wrapper over calloc(nb, ni): returns a zero-filled region of
   nb * ni bytes, or NULL if the allocation fails                         */
void *
s_mp_alloc(size_t nb, size_t ni)
{
    void *zeroed = calloc(nb, ni);

    return zeroed;

} /* end s_mp_alloc() */
2906 
2907 /* }}} */
2908 
2909 /* {{{ s_mp_free(ptr) */
2910 
/* Free the memory pointed to by ptr.  A NULL ptr is accepted and does
   nothing, because free(NULL) is defined to be a no-op.                  */
void
s_mp_free(void *ptr)
{
    free(ptr);
} /* end s_mp_free() */
2919 
2920 /* }}} */
2921 
2922 /* {{{ s_mp_clamp(mp) */
2923 
2924 /* Remove leading zeroes from the given value                             */
2925 void
s_mp_clamp(mp_int * mp)2926 s_mp_clamp(mp_int *mp)
2927 {
2928     mp_size used = MP_USED(mp);
2929     while (used > 1 && DIGIT(mp, used - 1) == 0)
2930         --used;
2931     MP_USED(mp) = used;
2932     if (used == 1 && DIGIT(mp, 0) == 0)
2933         MP_SIGN(mp) = ZPOS;
2934 } /* end s_mp_clamp() */
2935 
2936 /* }}} */
2937 
2938 /* {{{ s_mp_exch(a, b) */
2939 
2940 /* Exchange the data for a and b; (b, a) = (a, b)                         */
2941 void
s_mp_exch(mp_int * a,mp_int * b)2942 s_mp_exch(mp_int *a, mp_int *b)
2943 {
2944     mp_int tmp;
2945     if (!a || !b) {
2946         return;
2947     }
2948 
2949     tmp = *a;
2950     *a = *b;
2951     *b = tmp;
2952 
2953 } /* end s_mp_exch() */
2954 
2955 /* }}} */
2956 
2957 /* }}} */
2958 
2959 /* {{{ Arithmetic helpers */
2960 
2961 /* {{{ s_mp_lshd(mp, p) */
2962 
2963 /*
2964    Shift mp leftward by p digits, growing if needed, and zero-filling
2965    the in-shifted digits at the right end.  This is a convenient
2966    alternative to multiplication by powers of the radix
2967  */
2968 
2969 mp_err
s_mp_lshd(mp_int * mp,mp_size p)2970 s_mp_lshd(mp_int *mp, mp_size p)
2971 {
2972     mp_err res;
2973     unsigned int ix;
2974 
2975     ARGCHK(mp != NULL, MP_BADARG);
2976 
2977     if (p == 0)
2978         return MP_OKAY;
2979 
2980     if (MP_USED(mp) == 1 && MP_DIGIT(mp, 0) == 0)
2981         return MP_OKAY;
2982 
2983     if ((res = s_mp_pad(mp, USED(mp) + p)) != MP_OKAY)
2984         return res;
2985 
2986     /* Shift all the significant figures over as needed */
2987     for (ix = USED(mp) - p; ix-- > 0;) {
2988         DIGIT(mp, ix + p) = DIGIT(mp, ix);
2989     }
2990 
2991     /* Fill the bottom digits with zeroes */
2992     for (ix = 0; (mp_size)ix < p; ix++)
2993         DIGIT(mp, ix) = 0;
2994 
2995     return MP_OKAY;
2996 
2997 } /* end s_mp_lshd() */
2998 
2999 /* }}} */
3000 
3001 /* {{{ s_mp_mul_2d(mp, d) */
3002 
3003 /*
3004   Multiply the integer by 2^d, where d is a number of bits.  This
3005   amounts to a bitwise shift of the value.
3006  */
3007 mp_err
s_mp_mul_2d(mp_int * mp,mp_digit d)3008 s_mp_mul_2d(mp_int *mp, mp_digit d)
3009 {
3010     mp_err res;
3011     mp_digit dshift, rshift, mask, x, prev = 0;
3012     mp_digit *pa = NULL;
3013     int i;
3014 
3015     ARGCHK(mp != NULL, MP_BADARG);
3016 
3017     dshift = d / MP_DIGIT_BIT;
3018     d %= MP_DIGIT_BIT;
3019     /* mp_digit >> rshift is undefined behavior for rshift >= MP_DIGIT_BIT */
3020     /* mod and corresponding mask logic avoid that when d = 0 */
3021     rshift = MP_DIGIT_BIT - d;
3022     rshift %= MP_DIGIT_BIT;
3023     /* mask = (2**d - 1) * 2**(w-d) mod 2**w */
3024     mask = (DIGIT_MAX << rshift) + 1;
3025     mask &= DIGIT_MAX - 1;
3026     /* bits to be shifted out of the top word */
3027     x = MP_DIGIT(mp, MP_USED(mp) - 1) & mask;
3028 
3029     if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (x != 0))))
3030         return res;
3031 
3032     if (dshift && MP_OKAY != (res = s_mp_lshd(mp, dshift)))
3033         return res;
3034 
3035     pa = MP_DIGITS(mp) + dshift;
3036 
3037     for (i = MP_USED(mp) - dshift; i > 0; i--) {
3038         x = *pa;
3039         *pa++ = (x << d) | prev;
3040         prev = (x & mask) >> rshift;
3041     }
3042 
3043     s_mp_clamp(mp);
3044     return MP_OKAY;
3045 } /* end s_mp_mul_2d() */
3046 
3047 /* {{{ s_mp_rshd(mp, p) */
3048 
3049 /*
3050    Shift mp rightward by p digits.  Maintains the invariant that
3051    digits above the precision are all zero.  Digits shifted off the
3052    end are lost.  Cannot fail.
3053  */
3054 
3055 void
s_mp_rshd(mp_int * mp,mp_size p)3056 s_mp_rshd(mp_int *mp, mp_size p)
3057 {
3058     mp_size ix;
3059     mp_digit *src, *dst;
3060 
3061     if (p == 0)
3062         return;
3063 
3064     /* Shortcut when all digits are to be shifted off */
3065     if (p >= USED(mp)) {
3066         s_mp_setz(DIGITS(mp), ALLOC(mp));
3067         USED(mp) = 1;
3068         SIGN(mp) = ZPOS;
3069         return;
3070     }
3071 
3072     /* Shift all the significant figures over as needed */
3073     dst = MP_DIGITS(mp);
3074     src = dst + p;
3075     for (ix = USED(mp) - p; ix > 0; ix--)
3076         *dst++ = *src++;
3077 
3078     MP_USED(mp) -= p;
3079     /* Fill the top digits with zeroes */
3080     while (p-- > 0)
3081         *dst++ = 0;
3082 
3083 } /* end s_mp_rshd() */
3084 
3085 /* }}} */
3086 
3087 /* {{{ s_mp_div_2(mp) */
3088 
3089 /* Divide by two -- take advantage of radix properties to do it fast      */
3090 void
s_mp_div_2(mp_int * mp)3091 s_mp_div_2(mp_int *mp)
3092 {
3093     s_mp_div_2d(mp, 1);
3094 
3095 } /* end s_mp_div_2() */
3096 
3097 /* }}} */
3098 
3099 /* {{{ s_mp_mul_2(mp) */
3100 
3101 mp_err
s_mp_mul_2(mp_int * mp)3102 s_mp_mul_2(mp_int *mp)
3103 {
3104     mp_digit *pd;
3105     unsigned int ix, used;
3106     mp_digit kin = 0;
3107 
3108     ARGCHK(mp != NULL, MP_BADARG);
3109 
3110     /* Shift digits leftward by 1 bit */
3111     used = MP_USED(mp);
3112     pd = MP_DIGITS(mp);
3113     for (ix = 0; ix < used; ix++) {
3114         mp_digit d = *pd;
3115         *pd++ = (d << 1) | kin;
3116         kin = (d >> (DIGIT_BIT - 1));
3117     }
3118 
3119     /* Deal with rollover from last digit */
3120     if (kin) {
3121         if (ix >= ALLOC(mp)) {
3122             mp_err res;
3123             if ((res = s_mp_grow(mp, ALLOC(mp) + 1)) != MP_OKAY)
3124                 return res;
3125         }
3126 
3127         DIGIT(mp, ix) = kin;
3128         USED(mp) += 1;
3129     }
3130 
3131     return MP_OKAY;
3132 
3133 } /* end s_mp_mul_2() */
3134 
3135 /* }}} */
3136 
3137 /* {{{ s_mp_mod_2d(mp, d) */
3138 
3139 /*
3140   Remainder the integer by 2^d, where d is a number of bits.  This
3141   amounts to a bitwise AND of the value, and does not require the full
3142   division code
3143  */
3144 void
s_mp_mod_2d(mp_int * mp,mp_digit d)3145 s_mp_mod_2d(mp_int *mp, mp_digit d)
3146 {
3147     mp_size ndig = (d / DIGIT_BIT), nbit = (d % DIGIT_BIT);
3148     mp_size ix;
3149     mp_digit dmask;
3150 
3151     if (ndig >= USED(mp))
3152         return;
3153 
3154     /* Flush all the bits above 2^d in its digit */
3155     dmask = ((mp_digit)1 << nbit) - 1;
3156     DIGIT(mp, ndig) &= dmask;
3157 
3158     /* Flush all digits above the one with 2^d in it */
3159     for (ix = ndig + 1; ix < USED(mp); ix++)
3160         DIGIT(mp, ix) = 0;
3161 
3162     s_mp_clamp(mp);
3163 
3164 } /* end s_mp_mod_2d() */
3165 
3166 /* }}} */
3167 
3168 /* {{{ s_mp_div_2d(mp, d) */
3169 
3170 /*
3171   Divide the integer by 2^d, where d is a number of bits.  This
3172   amounts to a bitwise shift of the value, and does not require the
3173   full division code (used in Barrett reduction, see below)
3174  */
3175 void
s_mp_div_2d(mp_int * mp,mp_digit d)3176 s_mp_div_2d(mp_int *mp, mp_digit d)
3177 {
3178     int ix;
3179     mp_digit save, next, mask, lshift;
3180 
3181     s_mp_rshd(mp, d / DIGIT_BIT);
3182     d %= DIGIT_BIT;
3183     /* mp_digit << lshift is undefined behavior for lshift >= MP_DIGIT_BIT */
3184     /* mod and corresponding mask logic avoid that when d = 0 */
3185     lshift = DIGIT_BIT - d;
3186     lshift %= DIGIT_BIT;
3187     mask = ((mp_digit)1 << d) - 1;
3188     save = 0;
3189     for (ix = USED(mp) - 1; ix >= 0; ix--) {
3190         next = DIGIT(mp, ix) & mask;
3191         DIGIT(mp, ix) = (save << lshift) | (DIGIT(mp, ix) >> d);
3192         save = next;
3193     }
3194     s_mp_clamp(mp);
3195 
3196 } /* end s_mp_div_2d() */
3197 
3198 /* }}} */
3199 
3200 /* {{{ s_mp_norm(a, b, *d) */
3201 
3202 /*
3203   s_mp_norm(a, b, *d)
3204 
3205   Normalize a and b for division, where b is the divisor.  In order
3206   that we might make good guesses for quotient digits, we want the
3207   leading digit of b to be at least half the radix, which we
3208   accomplish by multiplying a and b by a power of 2.  The exponent
3209   (shift count) is placed in *pd, so that the remainder can be shifted
3210   back at the end of the division process.
3211  */
3212 
3213 mp_err
s_mp_norm(mp_int * a,mp_int * b,mp_digit * pd)3214 s_mp_norm(mp_int *a, mp_int *b, mp_digit *pd)
3215 {
3216     mp_digit d;
3217     mp_digit mask;
3218     mp_digit b_msd;
3219     mp_err res = MP_OKAY;
3220 
3221     ARGCHK(a != NULL && b != NULL && pd != NULL, MP_BADARG);
3222 
3223     d = 0;
3224     mask = DIGIT_MAX & ~(DIGIT_MAX >> 1); /* mask is msb of digit */
3225     b_msd = DIGIT(b, USED(b) - 1);
3226     while (!(b_msd & mask)) {
3227         b_msd <<= 1;
3228         ++d;
3229     }
3230 
3231     if (d) {
3232         MP_CHECKOK(s_mp_mul_2d(a, d));
3233         MP_CHECKOK(s_mp_mul_2d(b, d));
3234     }
3235 
3236     *pd = d;
3237 CLEANUP:
3238     return res;
3239 
3240 } /* end s_mp_norm() */
3241 
3242 /* }}} */
3243 
3244 /* }}} */
3245 
3246 /* {{{ Primitive digit arithmetic */
3247 
3248 /* {{{ s_mp_add_d(mp, d) */
3249 
3250 /* Add d to |mp| in place                                                 */
s_mp_add_d(mp_int * mp,mp_digit d)3251 mp_err s_mp_add_d(mp_int *mp, mp_digit d) /* unsigned digit addition */
3252 {
3253 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3254     mp_word w, k = 0;
3255     mp_size ix = 1;
3256 
3257     w = (mp_word)DIGIT(mp, 0) + d;
3258     DIGIT(mp, 0) = ACCUM(w);
3259     k = CARRYOUT(w);
3260 
3261     while (ix < USED(mp) && k) {
3262         w = (mp_word)DIGIT(mp, ix) + k;
3263         DIGIT(mp, ix) = ACCUM(w);
3264         k = CARRYOUT(w);
3265         ++ix;
3266     }
3267 
3268     if (k != 0) {
3269         mp_err res;
3270 
3271         if ((res = s_mp_pad(mp, USED(mp) + 1)) != MP_OKAY)
3272             return res;
3273 
3274         DIGIT(mp, ix) = (mp_digit)k;
3275     }
3276 
3277     return MP_OKAY;
3278 #else
3279     mp_digit *pmp = MP_DIGITS(mp);
3280     mp_digit sum, mp_i, carry = 0;
3281     mp_err res = MP_OKAY;
3282     int used = (int)MP_USED(mp);
3283 
3284     mp_i = *pmp;
3285     *pmp++ = sum = d + mp_i;
3286     carry = (sum < d);
3287     while (carry && --used > 0) {
3288         mp_i = *pmp;
3289         *pmp++ = sum = carry + mp_i;
3290         carry = !sum;
3291     }
3292     if (carry && !used) {
3293         /* mp is growing */
3294         used = MP_USED(mp);
3295         MP_CHECKOK(s_mp_pad(mp, used + 1));
3296         MP_DIGIT(mp, used) = carry;
3297     }
3298 CLEANUP:
3299     return res;
3300 #endif
3301 } /* end s_mp_add_d() */
3302 
3303 /* }}} */
3304 
3305 /* {{{ s_mp_sub_d(mp, d) */
3306 
3307 /* Subtract d from |mp| in place, assumes |mp| > d                        */
s_mp_sub_d(mp_int * mp,mp_digit d)3308 mp_err s_mp_sub_d(mp_int *mp, mp_digit d) /* unsigned digit subtract */
3309 {
3310 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3311     mp_word w, b = 0;
3312     mp_size ix = 1;
3313 
3314     /* Compute initial subtraction    */
3315     w = (RADIX + (mp_word)DIGIT(mp, 0)) - d;
3316     b = CARRYOUT(w) ? 0 : 1;
3317     DIGIT(mp, 0) = ACCUM(w);
3318 
3319     /* Propagate borrows leftward     */
3320     while (b && ix < USED(mp)) {
3321         w = (RADIX + (mp_word)DIGIT(mp, ix)) - b;
3322         b = CARRYOUT(w) ? 0 : 1;
3323         DIGIT(mp, ix) = ACCUM(w);
3324         ++ix;
3325     }
3326 
3327     /* Remove leading zeroes          */
3328     s_mp_clamp(mp);
3329 
3330     /* If we have a borrow out, it's a violation of the input invariant */
3331     if (b)
3332         return MP_RANGE;
3333     else
3334         return MP_OKAY;
3335 #else
3336     mp_digit *pmp = MP_DIGITS(mp);
3337     mp_digit mp_i, diff, borrow;
3338     mp_size used = MP_USED(mp);
3339 
3340     mp_i = *pmp;
3341     *pmp++ = diff = mp_i - d;
3342     borrow = (diff > mp_i);
3343     while (borrow && --used) {
3344         mp_i = *pmp;
3345         *pmp++ = diff = mp_i - borrow;
3346         borrow = (diff > mp_i);
3347     }
3348     s_mp_clamp(mp);
3349     return (borrow && !used) ? MP_RANGE : MP_OKAY;
3350 #endif
3351 } /* end s_mp_sub_d() */
3352 
3353 /* }}} */
3354 
3355 /* {{{ s_mp_mul_d(a, d) */
3356 
3357 /* Compute a = a * d, single digit multiplication                         */
3358 mp_err
s_mp_mul_d(mp_int * a,mp_digit d)3359 s_mp_mul_d(mp_int *a, mp_digit d)
3360 {
3361     mp_err res;
3362     mp_size used;
3363     int pow;
3364 
3365     if (!d) {
3366         mp_zero(a);
3367         return MP_OKAY;
3368     }
3369     if (d == 1)
3370         return MP_OKAY;
3371     if (0 <= (pow = s_mp_ispow2d(d))) {
3372         return s_mp_mul_2d(a, (mp_digit)pow);
3373     }
3374 
3375     used = MP_USED(a);
3376     MP_CHECKOK(s_mp_pad(a, used + 1));
3377 
3378     s_mpv_mul_d(MP_DIGITS(a), used, d, MP_DIGITS(a));
3379 
3380     s_mp_clamp(a);
3381 
3382 CLEANUP:
3383     return res;
3384 
3385 } /* end s_mp_mul_d() */
3386 
3387 /* }}} */
3388 
3389 /* {{{ s_mp_div_d(mp, d, r) */
3390 
3391 /*
3392   s_mp_div_d(mp, d, r)
3393 
3394   Compute the quotient mp = mp / d and remainder r = mp mod d, for a
3395   single digit d.  If r is null, the remainder will be discarded.
3396  */
3397 
3398 mp_err
s_mp_div_d(mp_int * mp,mp_digit d,mp_digit * r)3399 s_mp_div_d(mp_int *mp, mp_digit d, mp_digit *r)
3400 {
3401 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_DIV_WORD)
3402     mp_word w = 0, q;
3403 #else
3404     mp_digit w = 0, q;
3405 #endif
3406     int ix;
3407     mp_err res;
3408     mp_int quot;
3409     mp_int rem;
3410 
3411     if (d == 0)
3412         return MP_RANGE;
3413     if (d == 1) {
3414         if (r)
3415             *r = 0;
3416         return MP_OKAY;
3417     }
3418     /* could check for power of 2 here, but mp_div_d does that. */
3419     if (MP_USED(mp) == 1) {
3420         mp_digit n = MP_DIGIT(mp, 0);
3421         mp_digit remdig;
3422 
3423         q = n / d;
3424         remdig = n % d;
3425         MP_DIGIT(mp, 0) = q;
3426         if (r) {
3427             *r = remdig;
3428         }
3429         return MP_OKAY;
3430     }
3431 
3432     MP_DIGITS(&rem) = 0;
3433     MP_DIGITS(&quot) = 0;
3434     /* Make room for the quotient */
3435     MP_CHECKOK(mp_init_size(&quot, USED(mp)));
3436 
3437 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_DIV_WORD)
3438     for (ix = USED(mp) - 1; ix >= 0; ix--) {
3439         w = (w << DIGIT_BIT) | DIGIT(mp, ix);
3440 
3441         if (w >= d) {
3442             q = w / d;
3443             w = w % d;
3444         } else {
3445             q = 0;
3446         }
3447 
3448         s_mp_lshd(&quot, 1);
3449         DIGIT(&quot, 0) = (mp_digit)q;
3450     }
3451 #else
3452     {
3453         mp_digit p;
3454 #if !defined(MP_ASSEMBLY_DIV_2DX1D)
3455         mp_digit norm;
3456 #endif
3457 
3458         MP_CHECKOK(mp_init_copy(&rem, mp));
3459 
3460 #if !defined(MP_ASSEMBLY_DIV_2DX1D)
3461         MP_DIGIT(&quot, 0) = d;
3462         MP_CHECKOK(s_mp_norm(&rem, &quot, &norm));
3463         if (norm)
3464             d <<= norm;
3465         MP_DIGIT(&quot, 0) = 0;
3466 #endif
3467 
3468         p = 0;
3469         for (ix = USED(&rem) - 1; ix >= 0; ix--) {
3470             w = DIGIT(&rem, ix);
3471 
3472             if (p) {
3473                 MP_CHECKOK(s_mpv_div_2dx1d(p, w, d, &q, &w));
3474             } else if (w >= d) {
3475                 q = w / d;
3476                 w = w % d;
3477             } else {
3478                 q = 0;
3479             }
3480 
3481             MP_CHECKOK(s_mp_lshd(&quot, 1));
3482             DIGIT(&quot, 0) = q;
3483             p = w;
3484         }
3485 #if !defined(MP_ASSEMBLY_DIV_2DX1D)
3486         if (norm)
3487             w >>= norm;
3488 #endif
3489     }
3490 #endif
3491 
3492     /* Deliver the remainder, if desired */
3493     if (r) {
3494         *r = (mp_digit)w;
3495     }
3496 
3497     s_mp_clamp(&quot);
3498     mp_exch(&quot, mp);
3499 CLEANUP:
3500     mp_clear(&quot);
3501     mp_clear(&rem);
3502 
3503     return res;
3504 } /* end s_mp_div_d() */
3505 
3506 /* }}} */
3507 
3508 /* }}} */
3509 
3510 /* {{{ Primitive full arithmetic */
3511 
3512 /* {{{ s_mp_add(a, b) */
3513 
3514 /* Compute a = |a| + |b|                                                  */
s_mp_add(mp_int * a,const mp_int * b)3515 mp_err s_mp_add(mp_int *a, const mp_int *b) /* magnitude addition      */
3516 {
3517 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3518     mp_word w = 0;
3519 #else
3520     mp_digit d, sum, carry = 0;
3521 #endif
3522     mp_digit *pa, *pb;
3523     mp_size ix;
3524     mp_size used;
3525     mp_err res;
3526 
3527     /* Make sure a has enough precision for the output value */
3528     if ((USED(b) > USED(a)) && (res = s_mp_pad(a, USED(b))) != MP_OKAY)
3529         return res;
3530 
3531     /*
3532       Add up all digits up to the precision of b.  If b had initially
3533       the same precision as a, or greater, we took care of it by the
3534       padding step above, so there is no problem.  If b had initially
3535       less precision, we'll have to make sure the carry out is duly
3536       propagated upward among the higher-order digits of the sum.
3537      */
3538     pa = MP_DIGITS(a);
3539     pb = MP_DIGITS(b);
3540     used = MP_USED(b);
3541     for (ix = 0; ix < used; ix++) {
3542 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3543         w = w + *pa + *pb++;
3544         *pa++ = ACCUM(w);
3545         w = CARRYOUT(w);
3546 #else
3547         d = *pa;
3548         sum = d + *pb++;
3549         d = (sum < d); /* detect overflow */
3550         *pa++ = sum += carry;
3551         carry = d + (sum < carry); /* detect overflow */
3552 #endif
3553     }
3554 
3555     /* If we run out of 'b' digits before we're actually done, make
3556        sure the carries get propagated upward...
3557      */
3558     used = MP_USED(a);
3559 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3560     while (w && ix < used) {
3561         w = w + *pa;
3562         *pa++ = ACCUM(w);
3563         w = CARRYOUT(w);
3564         ++ix;
3565     }
3566 #else
3567     while (carry && ix < used) {
3568         sum = carry + *pa;
3569         *pa++ = sum;
3570         carry = !sum;
3571         ++ix;
3572     }
3573 #endif
3574 
3575 /* If there's an overall carry out, increase precision and include
3576      it.  We could have done this initially, but why touch the memory
3577      allocator unless we're sure we have to?
3578    */
3579 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3580     if (w) {
3581         if ((res = s_mp_pad(a, used + 1)) != MP_OKAY)
3582             return res;
3583 
3584         DIGIT(a, ix) = (mp_digit)w;
3585     }
3586 #else
3587     if (carry) {
3588         if ((res = s_mp_pad(a, used + 1)) != MP_OKAY)
3589             return res;
3590 
3591         DIGIT(a, used) = carry;
3592     }
3593 #endif
3594 
3595     return MP_OKAY;
3596 } /* end s_mp_add() */
3597 
3598 /* }}} */
3599 
3600 /* Compute c = |a| + |b|         */ /* magnitude addition      */
3601 mp_err
s_mp_add_3arg(const mp_int * a,const mp_int * b,mp_int * c)3602 s_mp_add_3arg(const mp_int *a, const mp_int *b, mp_int *c)
3603 {
3604     mp_digit *pa, *pb, *pc;
3605 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3606     mp_word w = 0;
3607 #else
3608     mp_digit sum, carry = 0, d;
3609 #endif
3610     mp_size ix;
3611     mp_size used;
3612     mp_err res;
3613 
3614     MP_SIGN(c) = MP_SIGN(a);
3615     if (MP_USED(a) < MP_USED(b)) {
3616         const mp_int *xch = a;
3617         a = b;
3618         b = xch;
3619     }
3620 
3621     /* Make sure a has enough precision for the output value */
3622     if (MP_OKAY != (res = s_mp_pad(c, MP_USED(a))))
3623         return res;
3624 
3625     /*
3626      Add up all digits up to the precision of b.  If b had initially
3627      the same precision as a, or greater, we took care of it by the
3628      exchange step above, so there is no problem.  If b had initially
3629      less precision, we'll have to make sure the carry out is duly
3630      propagated upward among the higher-order digits of the sum.
3631     */
3632     pa = MP_DIGITS(a);
3633     pb = MP_DIGITS(b);
3634     pc = MP_DIGITS(c);
3635     used = MP_USED(b);
3636     for (ix = 0; ix < used; ix++) {
3637 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3638         w = w + *pa++ + *pb++;
3639         *pc++ = ACCUM(w);
3640         w = CARRYOUT(w);
3641 #else
3642         d = *pa++;
3643         sum = d + *pb++;
3644         d = (sum < d); /* detect overflow */
3645         *pc++ = sum += carry;
3646         carry = d + (sum < carry); /* detect overflow */
3647 #endif
3648     }
3649 
3650     /* If we run out of 'b' digits before we're actually done, make
3651      sure the carries get propagated upward...
3652    */
3653     for (used = MP_USED(a); ix < used; ++ix) {
3654 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3655         w = w + *pa++;
3656         *pc++ = ACCUM(w);
3657         w = CARRYOUT(w);
3658 #else
3659         *pc++ = sum = carry + *pa++;
3660         carry = (sum < carry);
3661 #endif
3662     }
3663 
3664 /* If there's an overall carry out, increase precision and include
3665      it.  We could have done this initially, but why touch the memory
3666      allocator unless we're sure we have to?
3667    */
3668 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3669     if (w) {
3670         if ((res = s_mp_pad(c, used + 1)) != MP_OKAY)
3671             return res;
3672 
3673         DIGIT(c, used) = (mp_digit)w;
3674         ++used;
3675     }
3676 #else
3677     if (carry) {
3678         if ((res = s_mp_pad(c, used + 1)) != MP_OKAY)
3679             return res;
3680 
3681         DIGIT(c, used) = carry;
3682         ++used;
3683     }
3684 #endif
3685     MP_USED(c) = used;
3686     return MP_OKAY;
3687 }
3688 /* {{{ s_mp_add_offset(a, b, offset) */
3689 
3690 /* Compute a = |a| + ( |b| * (RADIX ** offset) )             */
3691 mp_err
s_mp_add_offset(mp_int * a,mp_int * b,mp_size offset)3692 s_mp_add_offset(mp_int *a, mp_int *b, mp_size offset)
3693 {
3694 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3695     mp_word w, k = 0;
3696 #else
3697     mp_digit d, sum, carry = 0;
3698 #endif
3699     mp_size ib;
3700     mp_size ia;
3701     mp_size lim;
3702     mp_err res;
3703 
3704     /* Make sure a has enough precision for the output value */
3705     lim = MP_USED(b) + offset;
3706     if ((lim > USED(a)) && (res = s_mp_pad(a, lim)) != MP_OKAY)
3707         return res;
3708 
3709     /*
3710     Add up all digits up to the precision of b.  If b had initially
3711     the same precision as a, or greater, we took care of it by the
3712     padding step above, so there is no problem.  If b had initially
3713     less precision, we'll have to make sure the carry out is duly
3714     propagated upward among the higher-order digits of the sum.
3715    */
3716     lim = USED(b);
3717     for (ib = 0, ia = offset; ib < lim; ib++, ia++) {
3718 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3719         w = (mp_word)DIGIT(a, ia) + DIGIT(b, ib) + k;
3720         DIGIT(a, ia) = ACCUM(w);
3721         k = CARRYOUT(w);
3722 #else
3723         d = MP_DIGIT(a, ia);
3724         sum = d + MP_DIGIT(b, ib);
3725         d = (sum < d);
3726         MP_DIGIT(a, ia) = sum += carry;
3727         carry = d + (sum < carry);
3728 #endif
3729     }
3730 
3731 /* If we run out of 'b' digits before we're actually done, make
3732      sure the carries get propagated upward...
3733    */
3734 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3735     for (lim = MP_USED(a); k && (ia < lim); ++ia) {
3736         w = (mp_word)DIGIT(a, ia) + k;
3737         DIGIT(a, ia) = ACCUM(w);
3738         k = CARRYOUT(w);
3739     }
3740 #else
3741     for (lim = MP_USED(a); carry && (ia < lim); ++ia) {
3742         d = MP_DIGIT(a, ia);
3743         MP_DIGIT(a, ia) = sum = d + carry;
3744         carry = (sum < d);
3745     }
3746 #endif
3747 
3748 /* If there's an overall carry out, increase precision and include
3749      it.  We could have done this initially, but why touch the memory
3750      allocator unless we're sure we have to?
3751    */
3752 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_ADD_WORD)
3753     if (k) {
3754         if ((res = s_mp_pad(a, USED(a) + 1)) != MP_OKAY)
3755             return res;
3756 
3757         DIGIT(a, ia) = (mp_digit)k;
3758     }
3759 #else
3760     if (carry) {
3761         if ((res = s_mp_pad(a, lim + 1)) != MP_OKAY)
3762             return res;
3763 
3764         DIGIT(a, lim) = carry;
3765     }
3766 #endif
3767     s_mp_clamp(a);
3768 
3769     return MP_OKAY;
3770 
3771 } /* end s_mp_add_offset() */
3772 
3773 /* }}} */
3774 
3775 /* {{{ s_mp_sub(a, b) */
3776 
3777 /* Compute a = |a| - |b|, assumes |a| >= |b|                              */
s_mp_sub(mp_int * a,const mp_int * b)3778 mp_err s_mp_sub(mp_int *a, const mp_int *b) /* magnitude subtract      */
3779 {
3780     mp_digit *pa, *pb, *limit;
3781 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3782     mp_sword w = 0;
3783 #else
3784     mp_digit d, diff, borrow = 0;
3785 #endif
3786 
3787     /*
3788     Subtract and propagate borrow.  Up to the precision of b, this
3789     accounts for the digits of b; after that, we just make sure the
3790     carries get to the right place.  This saves having to pad b out to
3791     the precision of a just to make the loops work right...
3792    */
3793     pa = MP_DIGITS(a);
3794     pb = MP_DIGITS(b);
3795     limit = pb + MP_USED(b);
3796     while (pb < limit) {
3797 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3798         w = w + *pa - *pb++;
3799         *pa++ = ACCUM(w);
3800         w >>= MP_DIGIT_BIT;
3801 #else
3802         d = *pa;
3803         diff = d - *pb++;
3804         d = (diff > d); /* detect borrow */
3805         if (borrow && --diff == MP_DIGIT_MAX)
3806             ++d;
3807         *pa++ = diff;
3808         borrow = d;
3809 #endif
3810     }
3811     limit = MP_DIGITS(a) + MP_USED(a);
3812 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3813     while (w && pa < limit) {
3814         w = w + *pa;
3815         *pa++ = ACCUM(w);
3816         w >>= MP_DIGIT_BIT;
3817     }
3818 #else
3819     while (borrow && pa < limit) {
3820         d = *pa;
3821         *pa++ = diff = d - borrow;
3822         borrow = (diff > d);
3823     }
3824 #endif
3825 
3826     /* Clobber any leading zeroes we created    */
3827     s_mp_clamp(a);
3828 
3829 /*
3830      If there was a borrow out, then |b| > |a| in violation
3831      of our input invariant.  We've already done the work,
3832      but we'll at least complain about it...
3833    */
3834 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3835     return w ? MP_RANGE : MP_OKAY;
3836 #else
3837     return borrow ? MP_RANGE : MP_OKAY;
3838 #endif
3839 } /* end s_mp_sub() */
3840 
3841 /* }}} */
3842 
3843 /* Compute c = |a| - |b|, assumes |a| >= |b| */ /* magnitude subtract      */
3844 mp_err
s_mp_sub_3arg(const mp_int * a,const mp_int * b,mp_int * c)3845 s_mp_sub_3arg(const mp_int *a, const mp_int *b, mp_int *c)
3846 {
3847     mp_digit *pa, *pb, *pc;
3848 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3849     mp_sword w = 0;
3850 #else
3851     mp_digit d, diff, borrow = 0;
3852 #endif
3853     int ix, limit;
3854     mp_err res;
3855 
3856     MP_SIGN(c) = MP_SIGN(a);
3857 
3858     /* Make sure a has enough precision for the output value */
3859     if (MP_OKAY != (res = s_mp_pad(c, MP_USED(a))))
3860         return res;
3861 
3862     /*
3863     Subtract and propagate borrow.  Up to the precision of b, this
3864     accounts for the digits of b; after that, we just make sure the
3865     carries get to the right place.  This saves having to pad b out to
3866     the precision of a just to make the loops work right...
3867    */
3868     pa = MP_DIGITS(a);
3869     pb = MP_DIGITS(b);
3870     pc = MP_DIGITS(c);
3871     limit = MP_USED(b);
3872     for (ix = 0; ix < limit; ++ix) {
3873 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3874         w = w + *pa++ - *pb++;
3875         *pc++ = ACCUM(w);
3876         w >>= MP_DIGIT_BIT;
3877 #else
3878         d = *pa++;
3879         diff = d - *pb++;
3880         d = (diff > d);
3881         if (borrow && --diff == MP_DIGIT_MAX)
3882             ++d;
3883         *pc++ = diff;
3884         borrow = d;
3885 #endif
3886     }
3887     for (limit = MP_USED(a); ix < limit; ++ix) {
3888 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3889         w = w + *pa++;
3890         *pc++ = ACCUM(w);
3891         w >>= MP_DIGIT_BIT;
3892 #else
3893         d = *pa++;
3894         *pc++ = diff = d - borrow;
3895         borrow = (diff > d);
3896 #endif
3897     }
3898 
3899     /* Clobber any leading zeroes we created    */
3900     MP_USED(c) = ix;
3901     s_mp_clamp(c);
3902 
3903 /*
3904      If there was a borrow out, then |b| > |a| in violation
3905      of our input invariant.  We've already done the work,
3906      but we'll at least complain about it...
3907    */
3908 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_SUB_WORD)
3909     return w ? MP_RANGE : MP_OKAY;
3910 #else
3911     return borrow ? MP_RANGE : MP_OKAY;
3912 #endif
3913 }
3914 /* {{{ s_mp_mul(a, b) */
3915 
3916 /* Compute a = |a| * |b|                                                  */
3917 mp_err
s_mp_mul(mp_int * a,const mp_int * b)3918 s_mp_mul(mp_int *a, const mp_int *b)
3919 {
3920     return mp_mul(a, b, a);
3921 } /* end s_mp_mul() */
3922 
3923 /* }}} */
3924 
/*
  MP_MUL_DxD(a, b, Phi, Plo)

  Multiply two single digits a and b, storing the high digit of the
  double-width product in Phi and the low digit in Plo.  Phi and Plo
  must be modifiable mp_digit lvalues.

  The macro arguments are parenthesized wherever they are used, so an
  argument may be an arbitrary expression (for example "x | 1").  The
  portable variant still evaluates a and b more than once, so arguments
  must not have side effects.
 */
#if defined(MP_USE_UINT_DIGIT) && defined(MP_USE_LONG_LONG_MULTIPLY)
/* This trick works on Sparc V8 CPUs with the Workshop compilers. */
#define MP_MUL_DxD(a, b, Phi, Plo)                                  \
    {                                                               \
        unsigned long long product = (unsigned long long)(a) * (b); \
        (Plo) = (mp_digit)product;                                  \
        (Phi) = (mp_digit)(product >> MP_DIGIT_BIT);                \
    }
#elif defined(OSF1)
#define MP_MUL_DxD(a, b, Phi, Plo)              \
    {                                           \
        Plo = asm("mulq %a0, %a1, %v0", a, b);  \
        Phi = asm("umulh %a0, %a1, %v0", a, b); \
    }
#else
/*
  Portable variant: split each digit into half-digits, form the four
  partial products, and recombine them while tracking the two places
  where a carry can be generated.
 */
#define MP_MUL_DxD(a, b, Phi, Plo)                                       \
    {                                                                    \
        mp_digit a0b1, a1b0;                                             \
        (Plo) = ((a)&MP_HALF_DIGIT_MAX) * ((b)&MP_HALF_DIGIT_MAX);       \
        (Phi) = ((a) >> MP_HALF_DIGIT_BIT) * ((b) >> MP_HALF_DIGIT_BIT); \
        a0b1 = ((a)&MP_HALF_DIGIT_MAX) * ((b) >> MP_HALF_DIGIT_BIT);     \
        a1b0 = ((a) >> MP_HALF_DIGIT_BIT) * ((b)&MP_HALF_DIGIT_MAX);     \
        a1b0 += a0b1;                                                    \
        (Phi) += a1b0 >> MP_HALF_DIGIT_BIT;                              \
        if (a1b0 < a0b1) /* sum of the cross terms wrapped */            \
            (Phi) += MP_HALF_RADIX;                                      \
        a1b0 <<= MP_HALF_DIGIT_BIT;                                      \
        (Plo) += a1b0;                                                   \
        if ((Plo) < a1b0) /* low digit wrapped */                        \
            ++(Phi);                                                     \
    }
#endif
3957 
3958 #if !defined(MP_ASSEMBLY_MULTIPLY)
3959 /* c = a * b */
3960 void
s_mpv_mul_d(const mp_digit * a,mp_size a_len,mp_digit b,mp_digit * c)3961 s_mpv_mul_d(const mp_digit *a, mp_size a_len, mp_digit b, mp_digit *c)
3962 {
3963 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_MUL_WORD)
3964     mp_digit d = 0;
3965 
3966     /* Inner product:  Digits of a */
3967     while (a_len--) {
3968         mp_word w = ((mp_word)b * *a++) + d;
3969         *c++ = ACCUM(w);
3970         d = CARRYOUT(w);
3971     }
3972     *c = d;
3973 #else
3974     mp_digit carry = 0;
3975     while (a_len--) {
3976         mp_digit a_i = *a++;
3977         mp_digit a0b0, a1b1;
3978 
3979         MP_MUL_DxD(a_i, b, a1b1, a0b0);
3980 
3981         a0b0 += carry;
3982         if (a0b0 < carry)
3983             ++a1b1;
3984         *c++ = a0b0;
3985         carry = a1b1;
3986     }
3987     *c = carry;
3988 #endif
3989 }
3990 
3991 /* c += a * b */
3992 void
s_mpv_mul_d_add(const mp_digit * a,mp_size a_len,mp_digit b,mp_digit * c)3993 s_mpv_mul_d_add(const mp_digit *a, mp_size a_len, mp_digit b,
3994                 mp_digit *c)
3995 {
3996 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_MUL_WORD)
3997     mp_digit d = 0;
3998 
3999     /* Inner product:  Digits of a */
4000     while (a_len--) {
4001         mp_word w = ((mp_word)b * *a++) + *c + d;
4002         *c++ = ACCUM(w);
4003         d = CARRYOUT(w);
4004     }
4005     *c = d;
4006 #else
4007     mp_digit carry = 0;
4008     while (a_len--) {
4009         mp_digit a_i = *a++;
4010         mp_digit a0b0, a1b1;
4011 
4012         MP_MUL_DxD(a_i, b, a1b1, a0b0);
4013 
4014         a0b0 += carry;
4015         if (a0b0 < carry)
4016             ++a1b1;
4017         a0b0 += a_i = *c;
4018         if (a0b0 < a_i)
4019             ++a1b1;
4020         *c++ = a0b0;
4021         carry = a1b1;
4022     }
4023     *c = carry;
4024 #endif
4025 }
4026 
4027 /* Presently, this is only used by the Montgomery arithmetic code. */
4028 /* c += a * b */
4029 void
s_mpv_mul_d_add_prop(const mp_digit * a,mp_size a_len,mp_digit b,mp_digit * c)4030 s_mpv_mul_d_add_prop(const mp_digit *a, mp_size a_len, mp_digit b, mp_digit *c)
4031 {
4032 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_MUL_WORD)
4033     mp_digit d = 0;
4034 
4035     /* Inner product:  Digits of a */
4036     while (a_len--) {
4037         mp_word w = ((mp_word)b * *a++) + *c + d;
4038         *c++ = ACCUM(w);
4039         d = CARRYOUT(w);
4040     }
4041 
4042     while (d) {
4043         mp_word w = (mp_word)*c + d;
4044         *c++ = ACCUM(w);
4045         d = CARRYOUT(w);
4046     }
4047 #else
4048     mp_digit carry = 0;
4049     while (a_len--) {
4050         mp_digit a_i = *a++;
4051         mp_digit a0b0, a1b1;
4052 
4053         MP_MUL_DxD(a_i, b, a1b1, a0b0);
4054 
4055         a0b0 += carry;
4056         if (a0b0 < carry)
4057             ++a1b1;
4058 
4059         a0b0 += a_i = *c;
4060         if (a0b0 < a_i)
4061             ++a1b1;
4062 
4063         *c++ = a0b0;
4064         carry = a1b1;
4065     }
4066     while (carry) {
4067         mp_digit c_i = *c;
4068         carry += c_i;
4069         *c++ = carry;
4070         carry = carry < c_i;
4071     }
4072 #endif
4073 }
4074 #endif
4075 
/*
  MP_SQR_D(a, Phi, Plo)

  Square the single digit a, storing the high digit of the double-width
  result in Phi and the low digit in Plo.  Phi and Plo must be
  modifiable mp_digit lvalues.

  The macro argument is parenthesized wherever it is used, so it may be
  an arbitrary expression.  It is still evaluated more than once, so it
  must not have side effects.
 */
#if defined(MP_USE_UINT_DIGIT) && defined(MP_USE_LONG_LONG_MULTIPLY)
/* This trick works on Sparc V8 CPUs with the Workshop compilers. */
#define MP_SQR_D(a, Phi, Plo)                                      \
    {                                                              \
        unsigned long long square = (unsigned long long)(a) * (a); \
        (Plo) = (mp_digit)square;                                  \
        (Phi) = (mp_digit)(square >> MP_DIGIT_BIT);                \
    }
#elif defined(OSF1)
#define MP_SQR_D(a, Phi, Plo)                \
    {                                        \
        Plo = asm("mulq  %a0, %a0, %v0", a); \
        Phi = asm("umulh %a0, %a0, %v0", a); \
    }
#else
/*
  Portable variant: with a = a1 * HALF_RADIX + a0, the square is
  a1*a1 * RADIX + 2*a0*a1 * HALF_RADIX + a0*a0.  The doubled cross term
  is folded in by shifting Pmid one bit further than a plain half-digit
  shift in each direction.
 */
#define MP_SQR_D(a, Phi, Plo)                                            \
    {                                                                    \
        mp_digit Pmid;                                                   \
        (Plo) = ((a)&MP_HALF_DIGIT_MAX) * ((a)&MP_HALF_DIGIT_MAX);       \
        (Phi) = ((a) >> MP_HALF_DIGIT_BIT) * ((a) >> MP_HALF_DIGIT_BIT); \
        Pmid = ((a)&MP_HALF_DIGIT_MAX) * ((a) >> MP_HALF_DIGIT_BIT);     \
        (Phi) += Pmid >> (MP_HALF_DIGIT_BIT - 1);                        \
        Pmid <<= (MP_HALF_DIGIT_BIT + 1);                                \
        (Plo) += Pmid;                                                   \
        if ((Plo) < Pmid) /* low digit wrapped */                        \
            ++(Phi);                                                     \
    }
#endif
4104 
4105 #if !defined(MP_ASSEMBLY_SQUARE)
4106 /* Add the squares of the digits of a to the digits of b. */
4107 void
s_mpv_sqr_add_prop(const mp_digit * pa,mp_size a_len,mp_digit * ps)4108 s_mpv_sqr_add_prop(const mp_digit *pa, mp_size a_len, mp_digit *ps)
4109 {
4110 #if !defined(MP_NO_MP_WORD) && !defined(MP_NO_MUL_WORD)
4111     mp_word w;
4112     mp_digit d;
4113     mp_size ix;
4114 
4115     w = 0;
4116 #define ADD_SQUARE(n)                     \
4117     d = pa[n];                            \
4118     w += (d * (mp_word)d) + ps[2 * n];    \
4119     ps[2 * n] = ACCUM(w);                 \
4120     w = (w >> DIGIT_BIT) + ps[2 * n + 1]; \
4121     ps[2 * n + 1] = ACCUM(w);             \
4122     w = (w >> DIGIT_BIT)
4123 
4124     for (ix = a_len; ix >= 4; ix -= 4) {
4125         ADD_SQUARE(0);
4126         ADD_SQUARE(1);
4127         ADD_SQUARE(2);
4128         ADD_SQUARE(3);
4129         pa += 4;
4130         ps += 8;
4131     }
4132     if (ix) {
4133         ps += 2 * ix;
4134         pa += ix;
4135         switch (ix) {
4136             case 3:
4137                 ADD_SQUARE(-3); /* FALLTHRU */
4138             case 2:
4139                 ADD_SQUARE(-2); /* FALLTHRU */
4140             case 1:
4141                 ADD_SQUARE(-1); /* FALLTHRU */
4142             case 0:
4143                 break;
4144         }
4145     }
4146     while (w) {
4147         w += *ps;
4148         *ps++ = ACCUM(w);
4149         w = (w >> DIGIT_BIT);
4150     }
4151 #else
4152     mp_digit carry = 0;
4153     while (a_len--) {
4154         mp_digit a_i = *pa++;
4155         mp_digit a0a0, a1a1;
4156 
4157         MP_SQR_D(a_i, a1a1, a0a0);
4158 
4159         /* here a1a1 and a0a0 constitute a_i ** 2 */
4160         a0a0 += carry;
4161         if (a0a0 < carry)
4162             ++a1a1;
4163 
4164         /* now add to ps */
4165         a0a0 += a_i = *ps;
4166         if (a0a0 < a_i)
4167             ++a1a1;
4168         *ps++ = a0a0;
4169         a1a1 += a_i = *ps;
4170         carry = (a1a1 < a_i);
4171         *ps++ = a1a1;
4172     }
4173     while (carry) {
4174         mp_digit s_i = *ps;
4175         carry += s_i;
4176         *ps++ = carry;
4177         carry = carry < s_i;
4178     }
4179 #endif
4180 }
4181 #endif
4182 
4183 #if !defined(MP_ASSEMBLY_DIV_2DX1D)
4184 /*
4185 ** Divide 64-bit (Nhi,Nlo) by 32-bit divisor, which must be normalized
4186 ** so its high bit is 1.   This code is from NSPR.
4187 */
4188 mp_err
s_mpv_div_2dx1d(mp_digit Nhi,mp_digit Nlo,mp_digit divisor,mp_digit * qp,mp_digit * rp)4189 s_mpv_div_2dx1d(mp_digit Nhi, mp_digit Nlo, mp_digit divisor,
4190                 mp_digit *qp, mp_digit *rp)
4191 {
4192     mp_digit d1, d0, q1, q0;
4193     mp_digit r1, r0, m;
4194 
4195     d1 = divisor >> MP_HALF_DIGIT_BIT;
4196     d0 = divisor & MP_HALF_DIGIT_MAX;
4197     r1 = Nhi % d1;
4198     q1 = Nhi / d1;
4199     m = q1 * d0;
4200     r1 = (r1 << MP_HALF_DIGIT_BIT) | (Nlo >> MP_HALF_DIGIT_BIT);
4201     if (r1 < m) {
4202         q1--, r1 += divisor;
4203         if (r1 >= divisor && r1 < m) {
4204             q1--, r1 += divisor;
4205         }
4206     }
4207     r1 -= m;
4208     r0 = r1 % d1;
4209     q0 = r1 / d1;
4210     m = q0 * d0;
4211     r0 = (r0 << MP_HALF_DIGIT_BIT) | (Nlo & MP_HALF_DIGIT_MAX);
4212     if (r0 < m) {
4213         q0--, r0 += divisor;
4214         if (r0 >= divisor && r0 < m) {
4215             q0--, r0 += divisor;
4216         }
4217     }
4218     if (qp)
4219         *qp = (q1 << MP_HALF_DIGIT_BIT) | q0;
4220     if (rp)
4221         *rp = r0 - m;
4222     return MP_OKAY;
4223 }
4224 #endif
4225 
4226 #if MP_SQUARE
4227 /* {{{ s_mp_sqr(a) */
4228 
4229 mp_err
s_mp_sqr(mp_int * a)4230 s_mp_sqr(mp_int *a)
4231 {
4232     mp_err res;
4233     mp_int tmp;
4234 
4235     if ((res = mp_init_size(&tmp, 2 * USED(a))) != MP_OKAY)
4236         return res;
4237     res = mp_sqr(a, &tmp);
4238     if (res == MP_OKAY) {
4239         s_mp_exch(&tmp, a);
4240     }
4241     mp_clear(&tmp);
4242     return res;
4243 }
4244 
4245 /* }}} */
4246 #endif
4247 
4248 /* {{{ s_mp_div(a, b) */
4249 
4250 /*
4251   s_mp_div(a, b)
4252 
4253   Compute a = a / b and b = a mod b.  Assumes b > a.
4254  */
4255 
s_mp_div(mp_int * rem,mp_int * div,mp_int * quot)4256 mp_err s_mp_div(mp_int *rem,  /* i: dividend, o: remainder */
4257                 mp_int *div,  /* i: divisor                */
4258                 mp_int *quot) /* i: 0;        o: quotient  */
4259 {
4260     mp_int part, t;
4261     mp_digit q_msd;
4262     mp_err res;
4263     mp_digit d;
4264     mp_digit div_msd;
4265     int ix;
4266 
4267     if (mp_cmp_z(div) == 0)
4268         return MP_RANGE;
4269 
4270     DIGITS(&t) = 0;
4271     /* Shortcut if divisor is power of two */
4272     if ((ix = s_mp_ispow2(div)) >= 0) {
4273         MP_CHECKOK(mp_copy(rem, quot));
4274         s_mp_div_2d(quot, (mp_digit)ix);
4275         s_mp_mod_2d(rem, (mp_digit)ix);
4276 
4277         return MP_OKAY;
4278     }
4279 
4280     MP_SIGN(rem) = ZPOS;
4281     MP_SIGN(div) = ZPOS;
4282     MP_SIGN(&part) = ZPOS;
4283 
4284     /* A working temporary for division     */
4285     MP_CHECKOK(mp_init_size(&t, MP_ALLOC(rem)));
4286 
4287     /* Normalize to optimize guessing       */
4288     MP_CHECKOK(s_mp_norm(rem, div, &d));
4289 
4290     /* Perform the division itself...woo!   */
4291     MP_USED(quot) = MP_ALLOC(quot);
4292 
4293     /* Find a partial substring of rem which is at least div */
4294     /* If we didn't find one, we're finished dividing    */
4295     while (MP_USED(rem) > MP_USED(div) || s_mp_cmp(rem, div) >= 0) {
4296         int i;
4297         int unusedRem;
4298         int partExtended = 0; /* set to true if we need to extend part */
4299 
4300         unusedRem = MP_USED(rem) - MP_USED(div);
4301         MP_DIGITS(&part) = MP_DIGITS(rem) + unusedRem;
4302         MP_ALLOC(&part) = MP_ALLOC(rem) - unusedRem;
4303         MP_USED(&part) = MP_USED(div);
4304 
4305         /* We have now truncated the part of the remainder to the same length as
4306          * the divisor. If part is smaller than div, extend part by one digit. */
4307         if (s_mp_cmp(&part, div) < 0) {
4308             --unusedRem;
4309 #if MP_ARGCHK == 2
4310             assert(unusedRem >= 0);
4311 #endif
4312             --MP_DIGITS(&part);
4313             ++MP_USED(&part);
4314             ++MP_ALLOC(&part);
4315             partExtended = 1;
4316         }
4317 
4318         /* Compute a guess for the next quotient digit       */
4319         q_msd = MP_DIGIT(&part, MP_USED(&part) - 1);
4320         div_msd = MP_DIGIT(div, MP_USED(div) - 1);
4321         if (!partExtended) {
4322             /* In this case, q_msd /= div_msd is always 1. First, since div_msd is
4323              * normalized to have the high bit set, 2*div_msd > MP_DIGIT_MAX. Since
4324              * we didn't extend part, q_msd >= div_msd. Therefore we know that
4325              * div_msd <= q_msd <= MP_DIGIT_MAX < 2*div_msd. Dividing by div_msd we
4326              * get 1 <= q_msd/div_msd < 2. So q_msd /= div_msd must be 1. */
4327             q_msd = 1;
4328         } else {
4329             if (q_msd == div_msd) {
4330                 q_msd = MP_DIGIT_MAX;
4331             } else {
4332                 mp_digit r;
4333                 MP_CHECKOK(s_mpv_div_2dx1d(q_msd, MP_DIGIT(&part, MP_USED(&part) - 2),
4334                                            div_msd, &q_msd, &r));
4335             }
4336         }
4337 #if MP_ARGCHK == 2
4338         assert(q_msd > 0); /* This case should never occur any more. */
4339 #endif
4340         if (q_msd <= 0)
4341             break;
4342 
4343         /* See what that multiplies out to                   */
4344         mp_copy(div, &t);
4345         MP_CHECKOK(s_mp_mul_d(&t, q_msd));
4346 
4347         /*
4348            If it's too big, back it off.  We should not have to do this
4349            more than once, or, in rare cases, twice.  Knuth describes a
4350            method by which this could be reduced to a maximum of once, but
4351            I didn't implement that here.
4352            When using s_mpv_div_2dx1d, we may have to do this 3 times.
4353          */
4354         for (i = 4; s_mp_cmp(&t, &part) > 0 && i > 0; --i) {
4355             --q_msd;
4356             MP_CHECKOK(s_mp_sub(&t, div)); /* t -= div */
4357         }
4358         if (i < 0) {
4359             res = MP_RANGE;
4360             goto CLEANUP;
4361         }
4362 
4363         /* At this point, q_msd should be the right next digit   */
4364         MP_CHECKOK(s_mp_sub(&part, &t)); /* part -= t */
4365         s_mp_clamp(rem);
4366 
4367         /*
4368           Include the digit in the quotient.  We allocated enough memory
4369           for any quotient we could ever possibly get, so we should not
4370           have to check for failures here
4371          */
4372         MP_DIGIT(quot, unusedRem) = q_msd;
4373     }
4374 
4375     /* Denormalize remainder                */
4376     if (d) {
4377         s_mp_div_2d(rem, d);
4378     }
4379 
4380     s_mp_clamp(quot);
4381 
4382 CLEANUP:
4383     mp_clear(&t);
4384 
4385     return res;
4386 
4387 } /* end s_mp_div() */
4388 
4389 /* }}} */
4390 
4391 /* {{{ s_mp_2expt(a, k) */
4392 
4393 mp_err
s_mp_2expt(mp_int * a,mp_digit k)4394 s_mp_2expt(mp_int *a, mp_digit k)
4395 {
4396     mp_err res;
4397     mp_size dig, bit;
4398 
4399     dig = k / DIGIT_BIT;
4400     bit = k % DIGIT_BIT;
4401 
4402     mp_zero(a);
4403     if ((res = s_mp_pad(a, dig + 1)) != MP_OKAY)
4404         return res;
4405 
4406     DIGIT(a, dig) |= ((mp_digit)1 << bit);
4407 
4408     return MP_OKAY;
4409 
4410 } /* end s_mp_2expt() */
4411 
4412 /* }}} */
4413 
4414 /* {{{ s_mp_reduce(x, m, mu) */
4415 
4416 /*
4417   Compute Barrett reduction, x (mod m), given a precomputed value for
4418   mu = b^2k / m, where b = RADIX and k = #digits(m).  This should be
4419   faster than straight division, when many reductions by the same
4420   value of m are required (such as in modular exponentiation).  This
4421   can nearly halve the time required to do modular exponentiation,
4422   as compared to using the full integer divide to reduce.
4423 
4424   This algorithm was derived from the _Handbook of Applied
4425   Cryptography_ by Menezes, Oorschot and VanStone, Ch. 14,
4426   pp. 603-604.
4427  */
4428 
4429 mp_err
s_mp_reduce(mp_int * x,const mp_int * m,const mp_int * mu)4430 s_mp_reduce(mp_int *x, const mp_int *m, const mp_int *mu)
4431 {
4432     mp_int q;
4433     mp_err res;
4434 
4435     if ((res = mp_init_copy(&q, x)) != MP_OKAY)
4436         return res;
4437 
4438     s_mp_rshd(&q, USED(m) - 1); /* q1 = x / b^(k-1)  */
4439     s_mp_mul(&q, mu);           /* q2 = q1 * mu      */
4440     s_mp_rshd(&q, USED(m) + 1); /* q3 = q2 / b^(k+1) */
4441 
4442     /* x = x mod b^(k+1), quick (no division) */
4443     s_mp_mod_2d(x, DIGIT_BIT * (USED(m) + 1));
4444 
4445     /* q = q * m mod b^(k+1), quick (no division) */
4446     s_mp_mul(&q, m);
4447     s_mp_mod_2d(&q, DIGIT_BIT * (USED(m) + 1));
4448 
4449     /* x = x - q */
4450     if ((res = mp_sub(x, &q, x)) != MP_OKAY)
4451         goto CLEANUP;
4452 
4453     /* If x < 0, add b^(k+1) to it */
4454     if (mp_cmp_z(x) < 0) {
4455         mp_set(&q, 1);
4456         if ((res = s_mp_lshd(&q, USED(m) + 1)) != MP_OKAY)
4457             goto CLEANUP;
4458         if ((res = mp_add(x, &q, x)) != MP_OKAY)
4459             goto CLEANUP;
4460     }
4461 
4462     /* Back off if it's too big */
4463     while (mp_cmp(x, m) >= 0) {
4464         if ((res = s_mp_sub(x, m)) != MP_OKAY)
4465             break;
4466     }
4467 
4468 CLEANUP:
4469     mp_clear(&q);
4470 
4471     return res;
4472 
4473 } /* end s_mp_reduce() */
4474 
4475 /* }}} */
4476 
4477 /* }}} */
4478 
4479 /* {{{ Primitive comparisons */
4480 
4481 /* {{{ s_mp_cmp(a, b) */
4482 
4483 /* Compare |a| <=> |b|, return 0 if equal, <0 if a<b, >0 if a>b           */
4484 int
s_mp_cmp(const mp_int * a,const mp_int * b)4485 s_mp_cmp(const mp_int *a, const mp_int *b)
4486 {
4487     ARGMPCHK(a != NULL && b != NULL);
4488 
4489     mp_size used_a = MP_USED(a);
4490     {
4491         mp_size used_b = MP_USED(b);
4492 
4493         if (used_a > used_b)
4494             goto IS_GT;
4495         if (used_a < used_b)
4496             goto IS_LT;
4497     }
4498     {
4499         mp_digit *pa, *pb;
4500         mp_digit da = 0, db = 0;
4501 
4502 #define CMP_AB(n)                     \
4503     if ((da = pa[n]) != (db = pb[n])) \
4504     goto done
4505 
4506         pa = MP_DIGITS(a) + used_a;
4507         pb = MP_DIGITS(b) + used_a;
4508         while (used_a >= 4) {
4509             pa -= 4;
4510             pb -= 4;
4511             used_a -= 4;
4512             CMP_AB(3);
4513             CMP_AB(2);
4514             CMP_AB(1);
4515             CMP_AB(0);
4516         }
4517         while (used_a-- > 0 && ((da = *--pa) == (db = *--pb)))
4518             /* do nothing */;
4519     done:
4520         if (da > db)
4521             goto IS_GT;
4522         if (da < db)
4523             goto IS_LT;
4524     }
4525     return MP_EQ;
4526 IS_LT:
4527     return MP_LT;
4528 IS_GT:
4529     return MP_GT;
4530 } /* end s_mp_cmp() */
4531 
4532 /* }}} */
4533 
4534 /* {{{ s_mp_cmp_d(a, d) */
4535 
4536 /* Compare |a| <=> d, return 0 if equal, <0 if a<d, >0 if a>d             */
4537 int
s_mp_cmp_d(const mp_int * a,mp_digit d)4538 s_mp_cmp_d(const mp_int *a, mp_digit d)
4539 {
4540     ARGMPCHK(a != NULL);
4541 
4542     if (USED(a) > 1)
4543         return MP_GT;
4544 
4545     if (DIGIT(a, 0) < d)
4546         return MP_LT;
4547     else if (DIGIT(a, 0) > d)
4548         return MP_GT;
4549     else
4550         return MP_EQ;
4551 
4552 } /* end s_mp_cmp_d() */
4553 
4554 /* }}} */
4555 
4556 /* {{{ s_mp_ispow2(v) */
4557 
4558 /*
4559   Returns -1 if the value is not a power of two; otherwise, it returns
4560   k such that v = 2^k, i.e. lg(v).
4561  */
4562 int
s_mp_ispow2(const mp_int * v)4563 s_mp_ispow2(const mp_int *v)
4564 {
4565     mp_digit d;
4566     int extra = 0, ix;
4567 
4568     ARGMPCHK(v != NULL);
4569 
4570     ix = MP_USED(v) - 1;
4571     d = MP_DIGIT(v, ix); /* most significant digit of v */
4572 
4573     extra = s_mp_ispow2d(d);
4574     if (extra < 0 || ix == 0)
4575         return extra;
4576 
4577     while (--ix >= 0) {
4578         if (DIGIT(v, ix) != 0)
4579             return -1; /* not a power of two */
4580         extra += MP_DIGIT_BIT;
4581     }
4582 
4583     return extra;
4584 
4585 } /* end s_mp_ispow2() */
4586 
4587 /* }}} */
4588 
4589 /* {{{ s_mp_ispow2d(d) */
4590 
4591 int
s_mp_ispow2d(mp_digit d)4592 s_mp_ispow2d(mp_digit d)
4593 {
4594     if ((d != 0) && ((d & (d - 1)) == 0)) { /* d is a power of 2 */
4595         int pow = 0;
4596 #if defined(MP_USE_UINT_DIGIT)
4597         if (d & 0xffff0000U)
4598             pow += 16;
4599         if (d & 0xff00ff00U)
4600             pow += 8;
4601         if (d & 0xf0f0f0f0U)
4602             pow += 4;
4603         if (d & 0xccccccccU)
4604             pow += 2;
4605         if (d & 0xaaaaaaaaU)
4606             pow += 1;
4607 #elif defined(MP_USE_LONG_LONG_DIGIT)
4608         if (d & 0xffffffff00000000ULL)
4609             pow += 32;
4610         if (d & 0xffff0000ffff0000ULL)
4611             pow += 16;
4612         if (d & 0xff00ff00ff00ff00ULL)
4613             pow += 8;
4614         if (d & 0xf0f0f0f0f0f0f0f0ULL)
4615             pow += 4;
4616         if (d & 0xccccccccccccccccULL)
4617             pow += 2;
4618         if (d & 0xaaaaaaaaaaaaaaaaULL)
4619             pow += 1;
4620 #elif defined(MP_USE_LONG_DIGIT)
4621         if (d & 0xffffffff00000000UL)
4622             pow += 32;
4623         if (d & 0xffff0000ffff0000UL)
4624             pow += 16;
4625         if (d & 0xff00ff00ff00ff00UL)
4626             pow += 8;
4627         if (d & 0xf0f0f0f0f0f0f0f0UL)
4628             pow += 4;
4629         if (d & 0xccccccccccccccccUL)
4630             pow += 2;
4631         if (d & 0xaaaaaaaaaaaaaaaaUL)
4632             pow += 1;
4633 #else
4634 #error "unknown type for mp_digit"
4635 #endif
4636         return pow;
4637     }
4638     return -1;
4639 
4640 } /* end s_mp_ispow2d() */
4641 
4642 /* }}} */
4643 
4644 /* }}} */
4645 
4646 /* {{{ Primitive I/O helpers */
4647 
4648 /* {{{ s_mp_tovalue(ch, r) */
4649 
/*
  Convert the given character to its digit value, in the given radix.
  If the given character is not understood in the given radix, -1 is
  returned.  Otherwise the digit's numeric value is returned.

  Decimal digits map to 0-9 and upper-case letters to 10-35.  For a
  radix of 36 or less the character is upper-cased first, so letters
  are accepted in either case; for larger radices lower-case letters
  map to 36-61, '+' to 62 and '/' to 63.

  The results will be odd if you use a radix < 2 or > 64, you are
  expected to know what you're up to.
 */
int
s_mp_tovalue(char ch, int r)
{
    int val, xch;

    /* The <ctype.h> functions are only defined for EOF and for values
       representable as unsigned char.  Plain char may be signed, so
       convert before classifying to avoid undefined behavior on bytes
       with the high bit set. */
    xch = (unsigned char)ch;
    if (r <= 36)
        xch = toupper(xch);

    if (isdigit(xch))
        val = xch - '0';
    else if (isupper(xch))
        val = xch - 'A' + 10;
    else if (islower(xch))
        val = xch - 'a' + 36;
    else if (xch == '+')
        val = 62;
    else if (xch == '/')
        val = 63;
    else
        return -1;

    /* A recognized symbol is still invalid if the radix has no such digit */
    if (val < 0 || val >= r)
        return -1;

    return val;

} /* end s_mp_tovalue() */
4687 
4688 /* }}} */
4689 
4690 /* {{{ s_mp_todigit(val, r, low) */
4691 
4692 /*
4693   Convert val to a radix-r digit, if possible.  If val is out of range
4694   for r, returns zero.  Otherwise, returns an ASCII character denoting
4695   the value in the given radix.
4696 
4697   The results may be odd if you use a radix < 2 or > 64, you are
4698   expected to know what you're doing.
4699  */
4700 
4701 char
s_mp_todigit(mp_digit val,int r,int low)4702 s_mp_todigit(mp_digit val, int r, int low)
4703 {
4704     char ch;
4705 
4706     if (val >= r)
4707         return 0;
4708 
4709     ch = s_dmap_1[val];
4710 
4711     if (r <= 36 && low)
4712         ch = tolower(ch);
4713 
4714     return ch;
4715 
4716 } /* end s_mp_todigit() */
4717 
4718 /* }}} */
4719 
4720 /* {{{ s_mp_outlen(bits, radix) */
4721 
/*
   Return an estimate for how long a string is needed to hold a radix
   r representation of a number with 'bits' significant bits, plus an
   extra for a zero terminator (assuming C style strings here).

   The digit estimate is bits * LOG_V_2(r), padded by 1.5 before
   truncation (LOG_V_2(r) is presumably log_r(2) -- confirm against
   its definition).
 */
int
s_mp_outlen(int bits, int r)
{
    int ndigits = (int)((double)bits * LOG_V_2(r) + 1.5);

    return ndigits + 1; /* one more for the terminator */

} /* end s_mp_outlen() */
4733 
4734 /* }}} */
4735 
4736 /* }}} */
4737 
4738 /* {{{ mp_read_unsigned_octets(mp, str, len) */
4739 /* mp_read_unsigned_octets(mp, str, len)
4740    Read in a raw value (base 256) into the given mp_int
4741    No sign bit, number is positive.  Leading zeros ignored.
4742  */
4743 
4744 mp_err
mp_read_unsigned_octets(mp_int * mp,const unsigned char * str,mp_size len)4745 mp_read_unsigned_octets(mp_int *mp, const unsigned char *str, mp_size len)
4746 {
4747     int count;
4748     mp_err res;
4749     mp_digit d;
4750 
4751     ARGCHK(mp != NULL && str != NULL && len > 0, MP_BADARG);
4752 
4753     mp_zero(mp);
4754 
4755     count = len % sizeof(mp_digit);
4756     if (count) {
4757         for (d = 0; count-- > 0; --len) {
4758             d = (d << 8) | *str++;
4759         }
4760         MP_DIGIT(mp, 0) = d;
4761     }
4762 
4763     /* Read the rest of the digits */
4764     for (; len > 0; len -= sizeof(mp_digit)) {
4765         for (d = 0, count = sizeof(mp_digit); count > 0; --count) {
4766             d = (d << 8) | *str++;
4767         }
4768         if (MP_EQ == mp_cmp_z(mp)) {
4769             if (!d)
4770                 continue;
4771         } else {
4772             if ((res = s_mp_lshd(mp, 1)) != MP_OKAY)
4773                 return res;
4774         }
4775         MP_DIGIT(mp, 0) = d;
4776     }
4777     return MP_OKAY;
4778 } /* end mp_read_unsigned_octets() */
4779 /* }}} */
4780 
4781 /* {{{ mp_unsigned_octet_size(mp) */
4782 unsigned int
mp_unsigned_octet_size(const mp_int * mp)4783 mp_unsigned_octet_size(const mp_int *mp)
4784 {
4785     unsigned int bytes;
4786     int ix;
4787     mp_digit d = 0;
4788 
4789     ARGCHK(mp != NULL, MP_BADARG);
4790     ARGCHK(MP_ZPOS == SIGN(mp), MP_BADARG);
4791 
4792     bytes = (USED(mp) * sizeof(mp_digit));
4793 
4794     /* subtract leading zeros. */
4795     /* Iterate over each digit... */
4796     for (ix = USED(mp) - 1; ix >= 0; ix--) {
4797         d = DIGIT(mp, ix);
4798         if (d)
4799             break;
4800         bytes -= sizeof(d);
4801     }
4802     if (!bytes)
4803         return 1;
4804 
4805     /* Have MSD, check digit bytes, high order first */
4806     for (ix = sizeof(mp_digit) - 1; ix >= 0; ix--) {
4807         unsigned char x = (unsigned char)(d >> (ix * CHAR_BIT));
4808         if (x)
4809             break;
4810         --bytes;
4811     }
4812     return bytes;
4813 } /* end mp_unsigned_octet_size() */
4814 /* }}} */
4815 
4816 /* {{{ mp_to_unsigned_octets(mp, str) */
4817 /* output a buffer of big endian octets no longer than specified. */
4818 mp_err
mp_to_unsigned_octets(const mp_int * mp,unsigned char * str,mp_size maxlen)4819 mp_to_unsigned_octets(const mp_int *mp, unsigned char *str, mp_size maxlen)
4820 {
4821     int ix, pos = 0;
4822     unsigned int bytes;
4823 
4824     ARGCHK(mp != NULL && str != NULL && !SIGN(mp), MP_BADARG);
4825 
4826     bytes = mp_unsigned_octet_size(mp);
4827     ARGCHK(bytes <= maxlen, MP_BADARG);
4828 
4829     /* Iterate over each digit... */
4830     for (ix = USED(mp) - 1; ix >= 0; ix--) {
4831         mp_digit d = DIGIT(mp, ix);
4832         int jx;
4833 
4834         /* Unpack digit bytes, high order first */
4835         for (jx = sizeof(mp_digit) - 1; jx >= 0; jx--) {
4836             unsigned char x = (unsigned char)(d >> (jx * CHAR_BIT));
4837             if (!pos && !x) /* suppress leading zeros */
4838                 continue;
4839             str[pos++] = x;
4840         }
4841     }
4842     if (!pos)
4843         str[pos++] = 0;
4844     return pos;
4845 } /* end mp_to_unsigned_octets() */
4846 /* }}} */
4847 
4848 /* {{{ mp_to_signed_octets(mp, str) */
4849 /* output a buffer of big endian octets no longer than specified. */
4850 mp_err
mp_to_signed_octets(const mp_int * mp,unsigned char * str,mp_size maxlen)4851 mp_to_signed_octets(const mp_int *mp, unsigned char *str, mp_size maxlen)
4852 {
4853     int ix, pos = 0;
4854     unsigned int bytes;
4855 
4856     ARGCHK(mp != NULL && str != NULL && !SIGN(mp), MP_BADARG);
4857 
4858     bytes = mp_unsigned_octet_size(mp);
4859     ARGCHK(bytes <= maxlen, MP_BADARG);
4860 
4861     /* Iterate over each digit... */
4862     for (ix = USED(mp) - 1; ix >= 0; ix--) {
4863         mp_digit d = DIGIT(mp, ix);
4864         int jx;
4865 
4866         /* Unpack digit bytes, high order first */
4867         for (jx = sizeof(mp_digit) - 1; jx >= 0; jx--) {
4868             unsigned char x = (unsigned char)(d >> (jx * CHAR_BIT));
4869             if (!pos) {
4870                 if (!x) /* suppress leading zeros */
4871                     continue;
4872                 if (x & 0x80) { /* add one leading zero to make output positive.  */
4873                     ARGCHK(bytes + 1 <= maxlen, MP_BADARG);
4874                     if (bytes + 1 > maxlen)
4875                         return MP_BADARG;
4876                     str[pos++] = 0;
4877                 }
4878             }
4879             str[pos++] = x;
4880         }
4881     }
4882     if (!pos)
4883         str[pos++] = 0;
4884     return pos;
4885 } /* end mp_to_signed_octets() */
4886 /* }}} */
4887 
4888 /* {{{ mp_to_fixlen_octets(mp, str) */
/* output a buffer of big endian octets exactly as long as requested.
   constant time on the value of mp.

   Writes the used digits of mp to str, most significant byte first.
   If those digits occupy fewer than 'length' bytes, the output is
   left-padded with zero octets.  If they occupy more, the excess
   high-order bytes are checked with ARGCHK to be zero and are not
   written.  Returns MP_OKAY when the octets have been written. */
mp_err
mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length)
{
    int ix, jx;
    unsigned int bytes;

    ARGCHK(mp != NULL && str != NULL && !SIGN(mp) && length > 0, MP_BADARG);

    /* Constant time on the value of mp.  Don't use mp_unsigned_octet_size. */
    bytes = USED(mp) * MP_DIGIT_SIZE;

    /* If the output is shorter than the native size of mp, then check that any
     * bytes not written have zero values.  This check isn't constant time on
     * the assumption that timing-sensitive callers can guarantee that mp fits
     * in the allocated space. */
    ix = USED(mp) - 1;
    if (bytes > length) {
        /* Number of high-order bytes of mp that will not be written */
        unsigned int zeros = bytes - length;

        /* Whole digits that fall entirely outside the output */
        while (zeros >= MP_DIGIT_SIZE) {
            ARGCHK(DIGIT(mp, ix) == 0, MP_BADARG);
            zeros -= MP_DIGIT_SIZE;
            ix--;
        }

        /* A digit that is only partly outside the output */
        if (zeros > 0) {
            mp_digit d = DIGIT(mp, ix);
            /* m selects the top 'zeros' bytes of the digit */
            mp_digit m = ~0ULL << ((MP_DIGIT_SIZE - zeros) * CHAR_BIT);
            ARGCHK((d & m) == 0, MP_BADARG);
            /* Write the remaining low bytes of this digit, high order first */
            for (jx = MP_DIGIT_SIZE - zeros - 1; jx >= 0; jx--) {
                *str++ = d >> (jx * CHAR_BIT);
            }
            ix--;
        }
    } else if (bytes < length) {
        /* Place any needed leading zeros. */
        unsigned int zeros = length - bytes;
        memset(str, 0, zeros);
        str += zeros;
    }

    /* Iterate over each whole digit... */
    for (; ix >= 0; ix--) {
        mp_digit d = DIGIT(mp, ix);

        /* Unpack digit bytes, high order first */
        for (jx = MP_DIGIT_SIZE - 1; jx >= 0; jx--) {
            *str++ = d >> (jx * CHAR_BIT);
        }
    }
    return MP_OKAY;
} /* end mp_to_fixlen_octets() */
4943 /* }}} */
4944 
4945 /* {{{ mp_cswap(condition, a, b, numdigits) */
/* performs a conditional swap between mp_int.

   When 'condition' is nonzero, the used counts, the signs and the
   first 'numdigits' digits of a and b are exchanged; when it is zero
   they are left as they were.  The exchange is done with an XOR mask
   derived from 'condition' rather than by branching on it.  Both
   operands are grown to 'numdigits' digits first if either has fewer
   allocated.  Returns 0 when a and b are the same object or once the
   masked exchange has been applied. */
mp_err
mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits)
{
    mp_digit x;
    unsigned int i;
    mp_err res = 0;

    /* if pointers are equal return */
    if (a == b)
        return res;

    /* NOTE(review): MP_CHECKOK presumably stores the call's result in res
       and jumps to CLEANUP on failure -- confirm against its definition */
    if (MP_ALLOC(a) < numdigits || MP_ALLOC(b) < numdigits) {
        MP_CHECKOK(s_mp_grow(a, numdigits));
        MP_CHECKOK(s_mp_grow(b, numdigits));
    }

    /* Normalize condition to a mask: all bits clear when it was zero,
       all bits set when it was nonzero */
    condition = ((~condition & ((condition - 1))) >> (MP_DIGIT_BIT - 1)) - 1;

    /* x is the XOR difference when the mask is set and zero otherwise, so
       applying it to both sides either exchanges them or changes nothing */
    x = (USED(a) ^ USED(b)) & condition;
    USED(a) ^= x;
    USED(b) ^= x;

    x = (SIGN(a) ^ SIGN(b)) & condition;
    SIGN(a) ^= x;
    SIGN(b) ^= x;

    /* The same masked exchange, digit by digit */
    for (i = 0; i < numdigits; i++) {
        x = (DIGIT(a, i) ^ DIGIT(b, i)) & condition;
        DIGIT(a, i) ^= x;
        DIGIT(b, i) ^= x;
    }

CLEANUP:
    return res;
} /* end mp_cswap() */
4982 /* }}} */
4983 
4984 /*------------------------------------------------------------------------*/
4985 /* HERE THERE BE DRAGONS                                                  */
4986