1 /* integer.c
2  *
3  * Copyright (C) 2006-2021 wolfSSL Inc.
4  *
5  * This file is part of wolfSSL.
6  *
7  * wolfSSL is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * wolfSSL is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
20  */
21 
22 
23 
24 /*
25  * Based on public domain LibTomMath 0.38 by Tom St Denis, tomstdenis@iahu.ca,
26  * http://math.libtomcrypt.com
27  */
28 
29 
30 #ifdef HAVE_CONFIG_H
31     #include <config.h>
32 #endif
33 
34 /* in case user set USE_FAST_MATH there */
35 #include <wolfssl/wolfcrypt/settings.h>
36 
37 #ifdef NO_INLINE
38     #include <wolfssl/wolfcrypt/misc.h>
39 #else
40     #define WOLFSSL_MISC_INCLUDED
41     #include <wolfcrypt/src/misc.c>
42 #endif
43 
44 #ifndef NO_BIG_INT
45 
46 #ifndef USE_FAST_MATH
47 
48 #ifndef WOLFSSL_SP_MATH
49 
#include <wolfssl/wolfcrypt/integer.h>

#ifndef INT_MAX
    #include <limits.h> /* INT_MAX */
#endif

#if defined(FREESCALE_LTC_TFM)
    #include <wolfssl/wolfcrypt/port/nxp/ksdk_port.h>
#endif
#ifdef WOLFSSL_DEBUG_MATH
    #include <stdio.h>
#endif
58 
59 #ifdef SHOW_GEN
60     #ifndef NO_STDIO_FILESYSTEM
61         #include <stdio.h>
62     #endif
63 #endif
64 
#if defined(WOLFSSL_HAVE_SP_RSA) || defined(WOLFSSL_HAVE_SP_DH)
/* Fixed-size modular exponentiation routines (1024 to 4096 bit) declared for
 * builds with SP RSA or SP DH support. Only prototypes appear here; the
 * definitions are presumably in the SP math sources - confirm. Wrapped in
 * extern "C" so the names are not mangled when compiled as C++. */
#ifdef __cplusplus
    extern "C" {
#endif
WOLFSSL_LOCAL int sp_ModExp_1024(mp_int* base, mp_int* exp, mp_int* mod,
    mp_int* res);
WOLFSSL_LOCAL int sp_ModExp_1536(mp_int* base, mp_int* exp, mp_int* mod,
    mp_int* res);
WOLFSSL_LOCAL int sp_ModExp_2048(mp_int* base, mp_int* exp, mp_int* mod,
    mp_int* res);
WOLFSSL_LOCAL int sp_ModExp_3072(mp_int* base, mp_int* exp, mp_int* mod,
    mp_int* res);
WOLFSSL_LOCAL int sp_ModExp_4096(mp_int* base, mp_int* exp, mp_int* mod,
    mp_int* res);
#ifdef __cplusplus
    } /* extern "C" */
#endif
#endif
83 
/* Reverse the first len bytes of s in place.
 *
 * Used by the radix/binary conversion code to turn a little-endian byte
 * sequence into big-endian order. len <= 1 leaves s untouched.
 */
static void
bn_reverse (unsigned char *s, int len)
{
    int           lo, hi;
    unsigned char tmp;

    for (lo = 0, hi = len - 1; lo < hi; lo++, hi--) {
        tmp   = s[hi];
        s[hi] = s[lo];
        s[lo] = tmp;
    }
}
101 
102 /* math settings check */
CheckRunTimeSettings(void)103 word32 CheckRunTimeSettings(void)
104 {
105     return CTC_SETTINGS;
106 }
107 
108 
109 /* handle up to 6 inits */
mp_init_multi(mp_int * a,mp_int * b,mp_int * c,mp_int * d,mp_int * e,mp_int * f)110 int mp_init_multi(mp_int* a, mp_int* b, mp_int* c, mp_int* d, mp_int* e,
111                   mp_int* f)
112 {
113     int res = MP_OKAY;
114 
115     if (a) XMEMSET(a, 0, sizeof(mp_int));
116     if (b) XMEMSET(b, 0, sizeof(mp_int));
117     if (c) XMEMSET(c, 0, sizeof(mp_int));
118     if (d) XMEMSET(d, 0, sizeof(mp_int));
119     if (e) XMEMSET(e, 0, sizeof(mp_int));
120     if (f) XMEMSET(f, 0, sizeof(mp_int));
121 
122     if (a && ((res = mp_init(a)) != MP_OKAY))
123         return res;
124 
125     if (b && ((res = mp_init(b)) != MP_OKAY)) {
126         mp_clear(a);
127         return res;
128     }
129 
130     if (c && ((res = mp_init(c)) != MP_OKAY)) {
131         mp_clear(a); mp_clear(b);
132         return res;
133     }
134 
135     if (d && ((res = mp_init(d)) != MP_OKAY)) {
136         mp_clear(a); mp_clear(b); mp_clear(c);
137         return res;
138     }
139 
140     if (e && ((res = mp_init(e)) != MP_OKAY)) {
141         mp_clear(a); mp_clear(b); mp_clear(c); mp_clear(d);
142         return res;
143     }
144 
145     if (f && ((res = mp_init(f)) != MP_OKAY)) {
146         mp_clear(a); mp_clear(b); mp_clear(c); mp_clear(d); mp_clear(e);
147         return res;
148     }
149 
150     return res;
151 }
152 
153 
154 /* init a new mp_int */
mp_init(mp_int * a)155 int mp_init (mp_int * a)
156 {
157   /* Safeguard against passing in a null pointer */
158   if (a == NULL)
159     return MP_VAL;
160 
161   /* defer allocation until mp_grow */
162   a->dp = NULL;
163 
164   /* set the used to zero, allocated digits to the default precision
165    * and sign to positive */
166   a->used  = 0;
167   a->alloc = 0;
168   a->sign  = MP_ZPOS;
169 #ifdef HAVE_WOLF_BIGINT
170   wc_bigint_init(&a->raw);
171 #endif
172 
173   return MP_OKAY;
174 }
175 
176 
177 /* clear one (frees)  */
mp_clear(mp_int * a)178 void mp_clear (mp_int * a)
179 {
180   int i;
181 
182   if (a == NULL)
183       return;
184 
185   /* only do anything if a hasn't been freed previously */
186 #ifndef HAVE_WOLF_BIGINT
187   /* When HAVE_WOLF_BIGINT then mp_free -> wc_bigint_free needs to be called
188    * because a->raw->buf may be allocated even when a->dp == NULL. This is the
189    * case for when a zero is loaded into the mp_int. */
190   if (a->dp != NULL)
191 #endif
192   {
193     /* first zero the digits */
194     for (i = 0; i < a->used; i++) {
195         a->dp[i] = 0;
196     }
197 
198     /* free ram */
199     mp_free(a);
200 
201     /* reset members to make debugging easier */
202     a->alloc = a->used = 0;
203     a->sign  = MP_ZPOS;
204   }
205 }
206 
/* Release the storage owned by a.
 *
 * Frees a->dp (when non-NULL) and sets it to NULL. With HAVE_WOLF_BIGINT the
 * raw buffer is released as well. used, alloc and sign are NOT reset; callers
 * such as mp_clear() do that themselves. A NULL pointer is ignored, matching
 * mp_clear() and mp_forcezero().
 */
void mp_free (mp_int * a)
{
  if (a == NULL)
    return;

  /* only do anything if a hasn't been freed previously */
  if (a->dp != NULL) {
    /* free ram (heap hint NULL, as used by XREALLOC in mp_grow) */
    XFREE(a->dp, NULL, DYNAMIC_TYPE_BIGINT);
    a->dp = NULL;
  }

#ifdef HAVE_WOLF_BIGINT
  wc_bigint_free(&a->raw);
#endif
}
220 
/* Wipe and release an mp_int, leaving it as an empty positive zero.
 *
 * The used digits are overwritten with ForceZero() before the storage is
 * freed. With HAVE_WOLF_BIGINT the raw buffer is wiped (wc_bigint_zero) and
 * released (inside mp_free) even when a->dp is NULL, because a->raw can hold
 * an allocation in that case (see the note in mp_clear). Previously that
 * case skipped both steps, leaking the raw buffer unwiped.
 * A NULL pointer is ignored.
 */
void mp_forcezero(mp_int * a)
{
    if (a == NULL)
        return;

    if (a->dp != NULL) {
        /* force zero the used digits */
        ForceZero(a->dp, a->used * sizeof(mp_digit));
    }

#ifdef HAVE_WOLF_BIGINT
    wc_bigint_zero(&a->raw);
#endif

    /* frees dp only if non-NULL; always releases raw under HAVE_WOLF_BIGINT */
    mp_free(a);

    /* reset members to make debugging easier */
    a->alloc = 0;
    a->used  = 0;
    a->sign  = MP_ZPOS;
}
244 
245 
246 /* get the size for an unsigned equivalent */
mp_unsigned_bin_size(const mp_int * a)247 int mp_unsigned_bin_size (const mp_int * a)
248 {
249   int     size = mp_count_bits (a);
250   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
251 }
252 
253 
254 /* returns the number of bits in an int */
mp_count_bits(const mp_int * a)255 int mp_count_bits (const mp_int * a)
256 {
257   int     r;
258   mp_digit q;
259 
260   /* shortcut */
261   if (a->used == 0) {
262     return 0;
263   }
264 
265   /* get number of digits and add that */
266   r = (a->used - 1) * DIGIT_BIT;
267 
268   /* take the last digit and count the bits in it */
269   q = a->dp[a->used - 1];
270   while (q > ((mp_digit) 0)) {
271     ++r;
272     q >>= ((mp_digit) 1);
273   }
274   return r;
275 }
276 
277 
/* Returns 1 when a is non-zero and its bit length is a multiple of 8, i.e.
 * the most significant byte of its unsigned big-endian encoding has its top
 * bit set. Returns 0 otherwise (including for zero).
 */
int mp_leading_bit (mp_int * a)
{
    int bits = mp_count_bits(a);

    return (bits != 0 && (bits & 7) == 0) ? 1 : 0;
}
285 
/* Emit the magnitude of t into b, least significant byte first, starting at
 * index x. t is consumed: it is shifted right 8 bits per byte until zero.
 * b must have room for every byte written.
 *
 * Returns the index one past the last byte written, 0 if t was already zero,
 * or the error code from mp_div_2d() (callers treat a negative result as an
 * error).
 */
int mp_to_unsigned_bin_at_pos(int x, mp_int *t, unsigned char *b)
{
  int next = x;
  int err;

  /* nothing written for zero: report 0, not the start position */
  if (mp_iszero(t) != MP_NO)
    return 0;

  do {
#ifndef MP_8BIT
    b[next] = (unsigned char) (t->dp[0] & 255);
#else
    /* 7-bit digits: low 7 bits from dp[0], 8th bit from dp[1] */
    b[next] = (unsigned char) (t->dp[0] | ((t->dp[1] & 0x01) << 7));
#endif
    next++;

    err = mp_div_2d (t, 8, t, NULL);
    if (err != MP_OKAY)
      return err;
  } while (mp_iszero(t) == MP_NO);

  return next;
}
302 
303 /* store in unsigned [big endian] format */
mp_to_unsigned_bin(mp_int * a,unsigned char * b)304 int mp_to_unsigned_bin (mp_int * a, unsigned char *b)
305 {
306   int     x, res;
307   mp_int  t;
308 
309   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
310     return res;
311   }
312 
313   x = mp_to_unsigned_bin_at_pos(0, &t, b);
314   if (x < 0) {
315     mp_clear(&t);
316     return x;
317   }
318 
319   bn_reverse (b, x);
320   mp_clear (&t);
321   return res;
322 }
323 
/* Store |a| in exactly c big-endian bytes, left-padded with zero bytes.
 *
 * Returns MP_VAL if a needs more than c bytes, otherwise the result of
 * mp_to_unsigned_bin() on the unpadded tail of b.
 */
int mp_to_unsigned_bin_len(mp_int * a, unsigned char *b, int c)
{
    unsigned char *out = b;
    int            need = mp_unsigned_bin_size(a);

    if (need > c)
        return MP_VAL;

    /* pad front w/ zeros to match length */
    while (c-- > need)
        *out++ = 0x00;

    return mp_to_unsigned_bin(a, out);
}
340 
341 /* creates "a" then copies b into it */
mp_init_copy(mp_int * a,mp_int * b)342 int mp_init_copy (mp_int * a, mp_int * b)
343 {
344   int     res;
345 
346   if ((res = mp_init_size (a, b->used)) != MP_OKAY) {
347     return res;
348   }
349 
350   if((res = mp_copy (b, a)) != MP_OKAY) {
351     mp_clear(a);
352   }
353 
354   return res;
355 }
356 
357 
358 /* copy, b = a */
mp_copy(const mp_int * a,mp_int * b)359 int mp_copy (const mp_int * a, mp_int * b)
360 {
361   int     res, n;
362 
363   /* Safeguard against passing in a null pointer */
364   if (a == NULL || b == NULL)
365     return MP_VAL;
366 
367   /* if dst == src do nothing */
368   if (a == b) {
369     return MP_OKAY;
370   }
371 
372   /* grow dest */
373   if (b->alloc < a->used || b->alloc == 0) {
374      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
375         return res;
376      }
377   }
378 
379   /* zero b and copy the parameters over */
380   {
381     mp_digit *tmpa, *tmpb;
382 
383     /* pointer aliases */
384 
385     /* source */
386     tmpa = a->dp;
387 
388     /* destination */
389     tmpb = b->dp;
390 
391     /* copy all the digits */
392     for (n = 0; n < a->used; n++) {
393       *tmpb++ = *tmpa++;
394     }
395 
396     /* clear high digits */
397     for (; n < b->used && b->dp; n++) {
398       *tmpb++ = 0;
399     }
400   }
401 
402   /* copy used count and sign */
403   b->used = a->used;
404   b->sign = a->sign;
405   return MP_OKAY;
406 }
407 
408 
409 /* grow as required */
mp_grow(mp_int * a,int size)410 int mp_grow (mp_int * a, int size)
411 {
412   int     i;
413   mp_digit *tmp;
414 
415   /* if the alloc size is smaller alloc more ram */
416   if (a->alloc < size || size == 0) {
417     /* ensure there are always at least MP_PREC digits extra on top */
418     size += (MP_PREC * 2) - (size % MP_PREC);
419 
420     /* reallocate the array a->dp
421      *
422      * We store the return in a temporary variable
423      * in case the operation failed we don't want
424      * to overwrite the dp member of a.
425      */
426     tmp = OPT_CAST(mp_digit) XREALLOC (a->dp, sizeof (mp_digit) * size, NULL,
427                                                            DYNAMIC_TYPE_BIGINT);
428     if (tmp == NULL) {
429       /* reallocation failed but "a" is still valid [can be freed] */
430       return MP_MEM;
431     }
432 
433     /* reallocation succeeded so set a->dp */
434     a->dp = tmp;
435 
436     /* zero excess digits */
437     i        = a->alloc;
438     a->alloc = size;
439     for (; i < a->alloc; i++) {
440       a->dp[i] = 0;
441     }
442   }
443   return MP_OKAY;
444 }
445 
446 
447 /* shift right by a certain bit count (store quotient in c, optional
448    remainder in d) */
mp_div_2d(mp_int * a,int b,mp_int * c,mp_int * d)449 int mp_div_2d (mp_int * a, int b, mp_int * c, mp_int * d)
450 {
451   int     D, res;
452   mp_int  t;
453 
454 
455   /* if the shift count is <= 0 then we do no work */
456   if (b <= 0) {
457     res = mp_copy (a, c);
458     if (d != NULL) {
459       mp_zero (d);
460     }
461     return res;
462   }
463 
464   if ((res = mp_init (&t)) != MP_OKAY) {
465     return res;
466   }
467 
468   /* get the remainder */
469   if (d != NULL) {
470     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
471       mp_clear (&t);
472       return res;
473     }
474   }
475 
476   /* copy */
477   if ((res = mp_copy (a, c)) != MP_OKAY) {
478     mp_clear (&t);
479     return res;
480   }
481 
482   /* shift by as many digits in the bit count */
483   if (b >= (int)DIGIT_BIT) {
484     mp_rshd (c, b / DIGIT_BIT);
485   }
486 
487   /* shift any bit count < DIGIT_BIT */
488   D = (b % DIGIT_BIT);
489   if (D != 0) {
490     mp_rshb(c, D);
491   }
492   mp_clamp (c);
493   if (d != NULL) {
494     mp_exch (&t, d);
495   }
496   mp_clear (&t);
497   return MP_OKAY;
498 }
499 
500 
501 /* set to zero */
mp_zero(mp_int * a)502 void mp_zero (mp_int * a)
503 {
504   int       n;
505   mp_digit *tmp;
506 
507   if (a == NULL)
508       return;
509 
510   a->sign = MP_ZPOS;
511   a->used = 0;
512 
513   tmp = a->dp;
514   for (n = 0; n < a->alloc; n++) {
515      *tmp++ = 0;
516   }
517 }
518 
519 
520 /* trim unused digits
521  *
522  * This is used to ensure that leading zero digits are
523  * trimmed and the leading "used" digit will be non-zero
524  * Typically very fast.  Also fixes the sign if there
525  * are no more leading digits
526  */
mp_clamp(mp_int * a)527 void mp_clamp (mp_int * a)
528 {
529   /* decrease used while the most significant digit is
530    * zero.
531    */
532   while (a->used > 0 && a->dp[a->used - 1] == 0) {
533     --(a->used);
534   }
535 
536   /* reset the sign flag if used == 0 */
537   if (a->used == 0) {
538     a->sign = MP_ZPOS;
539   }
540 }
541 
542 
543 /* swap the elements of two integers, for cases where you can't simply swap the
544  * mp_int pointers around
545  */
mp_exch(mp_int * a,mp_int * b)546 int mp_exch (mp_int * a, mp_int * b)
547 {
548   mp_int  t;
549 
550   t  = *a;
551   *a = *b;
552   *b = t;
553   return MP_OKAY;
554 }
555 
/* Swap a and b when m is 1; any other m leaves them unchanged.
 *
 * c is not used by this implementation. Always returns MP_OKAY.
 *
 * NOTE(review): despite the _ct suffix, this version branches on m and swaps
 * via mp_exch(), so m may be observable through timing. Confirm whether
 * callers built against this math library rely on constant-time behavior.
 */
int mp_cond_swap_ct (mp_int * a, mp_int * b, int c, int m)
{
    (void)c;

    if (m != 1)
        return MP_OKAY;

    mp_exch(a, b);
    return MP_OKAY;
}
563 
564 
565 /* shift right a certain number of bits */
mp_rshb(mp_int * c,int x)566 void mp_rshb (mp_int *c, int x)
567 {
568     mp_digit *tmpc, mask, shift;
569     mp_digit r, rr;
570     mp_digit D = x;
571 
572     /* shifting by a negative number not supported, and shifting by
573      * zero changes nothing.
574      */
575     if (x <= 0) return;
576 
577     /* shift digits first if needed */
578     if (x >= DIGIT_BIT) {
579         mp_rshd(c, x / DIGIT_BIT);
580         /* recalculate number of bits to shift */
581         D = x % DIGIT_BIT;
582         /* check if any more shifting needed */
583         if (D == 0) return;
584     }
585 
586     /* zero shifted is always zero */
587     if (mp_iszero(c)) return;
588 
589     /* mask */
590     mask = (((mp_digit)1) << D) - 1;
591 
592     /* shift for lsb */
593     shift = DIGIT_BIT - D;
594 
595     /* alias */
596     tmpc = c->dp + (c->used - 1);
597 
598     /* carry */
599     r = 0;
600     for (x = c->used - 1; x >= 0; x--) {
601       /* get the lower  bits of this word in a temp */
602       rr = *tmpc & mask;
603 
604       /* shift the current word and mix in the carry bits from previous word */
605       *tmpc = (*tmpc >> D) | (r << shift);
606       --tmpc;
607 
608       /* set the carry to the carry bits of the current word found above */
609       r = rr;
610     }
611     mp_clamp(c);
612 }
613 
614 
615 /* shift right a certain amount of digits */
mp_rshd(mp_int * a,int b)616 void mp_rshd (mp_int * a, int b)
617 {
618   int     x;
619 
620   /* if b <= 0 then ignore it */
621   if (b <= 0) {
622     return;
623   }
624 
625   /* if b > used then simply zero it and return */
626   if (a->used <= b) {
627     mp_zero (a);
628     return;
629   }
630 
631   {
632     mp_digit *bottom, *top;
633 
634     /* shift the digits down */
635 
636     /* bottom */
637     bottom = a->dp;
638 
639     /* top [offset into digits] */
640     top = a->dp + b;
641 
642     /* this is implemented as a sliding window where
643      * the window is b-digits long and digits from
644      * the top of the window are copied to the bottom
645      *
646      * e.g.
647 
648      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
649                  /\                   |      ---->
650                   \-------------------/      ---->
651      */
652     for (x = 0; x < (a->used - b); x++) {
653       *bottom++ = *top++;
654     }
655 
656     /* zero the top digits */
657     for (; x < a->used; x++) {
658       *bottom++ = 0;
659     }
660   }
661 
662   /* remove excess digits */
663   a->used -= b;
664 }
665 
666 
667 /* calc a value mod 2**b */
mp_mod_2d(mp_int * a,int b,mp_int * c)668 int mp_mod_2d (mp_int * a, int b, mp_int * c)
669 {
670   int     x, res, bmax;
671 
672   /* if b is <= 0 then zero the int */
673   if (b <= 0) {
674     mp_zero (c);
675     return MP_OKAY;
676   }
677 
678   /* if the modulus is larger than the value than return */
679   if (a->sign == MP_ZPOS && b >= (int) (a->used * DIGIT_BIT)) {
680     res = mp_copy (a, c);
681     return res;
682   }
683 
684   /* copy */
685   if ((res = mp_copy (a, c)) != MP_OKAY) {
686     return res;
687   }
688 
689   /* calculate number of digits in mod value */
690   bmax = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1);
691   /* zero digits above the last digit of the modulus */
692   for (x = bmax; x < c->used; x++) {
693     c->dp[x] = 0;
694   }
695 
696   if (c->sign == MP_NEG) {
697      mp_digit carry = 0;
698 
699      /* grow result to size of modulus */
700      if ((res = mp_grow(c, bmax)) != MP_OKAY) {
701          return res;
702      }
703      /* negate value */
704      for (x = 0; x < c->used; x++) {
705          mp_digit next = c->dp[x] > 0;
706          c->dp[x] = ((mp_digit)0 - c->dp[x] - carry) & MP_MASK;
707          carry |= next;
708      }
709      for (; x < bmax; x++) {
710          c->dp[x] = ((mp_digit)0 - carry) & MP_MASK;
711      }
712      c->used = bmax;
713      c->sign = MP_ZPOS;
714   }
715 
716   /* clear the digit that is not completely outside/inside the modulus */
717   x = DIGIT_BIT - (b % DIGIT_BIT);
718   if (x != DIGIT_BIT) {
719     c->dp[bmax - 1] &=
720          ((mp_digit)~((mp_digit)0)) >> (x + ((sizeof(mp_digit)*8) - DIGIT_BIT));
721   }
722   mp_clamp (c);
723   return MP_OKAY;
724 }
725 
726 
727 /* reads a unsigned char array, assumes the msb is stored first [big endian] */
mp_read_unsigned_bin(mp_int * a,const unsigned char * b,int c)728 int mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
729 {
730   int     res;
731   int     digits_needed;
732 
733   while (c > 0 && b[0] == 0) {
734       c--;
735       b++;
736   }
737 
738   digits_needed = ((c * CHAR_BIT) + DIGIT_BIT - 1) / DIGIT_BIT;
739 
740   /* make sure there are enough digits available */
741   if (a->alloc < digits_needed) {
742      if ((res = mp_grow(a, digits_needed)) != MP_OKAY) {
743         return res;
744      }
745   }
746 
747   /* zero the int */
748   mp_zero (a);
749 
750   /* read the bytes in */
751   while (c-- > 0) {
752     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
753       return res;
754     }
755 
756 #ifndef MP_8BIT
757       a->dp[0] |= *b++;
758       if (a->used == 0)
759           a->used = 1;
760 #else
761       a->dp[0] = (*b & MP_MASK);
762       a->dp[1] |= ((*b++ >> 7U) & 1);
763       if (a->used == 0)
764           a->used = 2;
765 #endif
766   }
767   mp_clamp (a);
768   return MP_OKAY;
769 }
770 
771 
772 /* shift left by a certain bit count */
mp_mul_2d(mp_int * a,int b,mp_int * c)773 int mp_mul_2d (mp_int * a, int b, mp_int * c)
774 {
775   mp_digit d;
776   int      res;
777 
778   /* copy */
779   if (a != c) {
780      if ((res = mp_copy (a, c)) != MP_OKAY) {
781        return res;
782      }
783   }
784 
785   if (c->alloc < (int)(c->used + b/DIGIT_BIT + 1)) {
786      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
787        return res;
788      }
789   }
790 
791   /* shift by as many digits in the bit count */
792   if (b >= (int)DIGIT_BIT) {
793     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
794       return res;
795     }
796   }
797 
798   /* shift any bit count < DIGIT_BIT */
799   d = (mp_digit) (b % DIGIT_BIT);
800   if (d != 0) {
801     mp_digit *tmpc, shift, mask, r, rr;
802     int x;
803 
804     /* bitmask for carries */
805     mask = (((mp_digit)1) << d) - 1;
806 
807     /* shift for msbs */
808     shift = DIGIT_BIT - d;
809 
810     /* alias */
811     tmpc = c->dp;
812 
813     /* carry */
814     r    = 0;
815     for (x = 0; x < c->used; x++) {
816       /* get the higher bits of the current word */
817       rr = (*tmpc >> shift) & mask;
818 
819       /* shift the current word and OR in the carry */
820       *tmpc = (mp_digit)(((*tmpc << d) | r) & MP_MASK);
821       ++tmpc;
822 
823       /* set the carry to the carry bits of the current word */
824       r = rr;
825     }
826 
827     /* set final carry */
828     if (r != 0) {
829        c->dp[(c->used)++] = r;
830     }
831   }
832   mp_clamp (c);
833   return MP_OKAY;
834 }
835 
836 
837 /* shift left a certain amount of digits */
mp_lshd(mp_int * a,int b)838 int mp_lshd (mp_int * a, int b)
839 {
840   int     x, res;
841 
842   /* if its less than zero return */
843   if (b <= 0) {
844     return MP_OKAY;
845   }
846 
847   /* grow to fit the new digits */
848   if (a->alloc < a->used + b) {
849      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
850        return res;
851      }
852   }
853 
854   {
855     mp_digit *top, *bottom;
856 
857     /* increment the used by the shift amount then copy upwards */
858     a->used += b;
859 
860     /* top */
861     top = a->dp + a->used - 1;
862 
863     /* base */
864     bottom = a->dp + a->used - 1 - b;
865 
866     /* much like mp_rshd this is implemented using a sliding window
867      * except the window goes the other way around.  Copying from
868      * the bottom to the top.  see bn_mp_rshd.c for more info.
869      */
870     for (x = a->used - 1; x >= b; x--) {
871       *top-- = *bottom--;
872     }
873 
874     /* zero the lower digits */
875     top = a->dp;
876     for (x = 0; x < b; x++) {
877       *top++ = 0;
878     }
879   }
880   return MP_OKAY;
881 }
882 
883 
/* this is a shell function that calls either the normal or Montgomery
 * exptmod functions.  Originally the call to the montgomery code was
 * embedded in the normal function but that wasted a lot of stack space
 * for nothing (since 99% of the time the Montgomery code would be called)
 *
 * Computes Y = G**X mod P.
 *
 * Special cases handled before any exponentiation, in this order:
 *   - P zero or negative -> MP_VAL
 *   - P == 1             -> Y = 0
 *   - X == 0             -> Y = 1 (checked before G, so 0**0 gives 1)
 *   - G == 0             -> Y = 0
 *   - X negative         -> (G**-1 mod P)**|X| when BN_MP_INVMOD_C is
 *                           defined, otherwise MP_VAL
 * The remaining cases are dispatched on compile-time BN_* options and on
 * the shape of P (base-2 shortcut, 2k_l reduction, DR / 2k modulus, odd
 * modulus -> mp_exptmod_fast, else s_mp_exptmod). MP_VAL is returned when
 * no enabled routine applies.
 *
 * Under FREESCALE_LTC_TFM this software path is named wolfcrypt_mp_exptmod
 * (presumably so the NXP port layer can supply mp_exptmod - confirm in
 * ksdk_port.h).
 */
#if defined(FREESCALE_LTC_TFM)
int wolfcrypt_mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
#else
int mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y)
#endif
{
  int dr;

  /* modulus P must be positive */
  if (mp_iszero(P) || P->sign == MP_NEG) {
     return MP_VAL;
  }
  /* anything mod 1 is 0 */
  if (mp_isone(P)) {
     return mp_set(Y, 0);
  }
  /* G**0 == 1 (P > 1 here) */
  if (mp_iszero(X)) {
     return mp_set(Y, 1);
  }
  /* 0**X == 0 for X != 0 */
  if (mp_iszero(G)) {
     return mp_set(Y, 0);
  }

  /* if exponent X is negative we have to recurse */
  if (X->sign == MP_NEG) {
#ifdef BN_MP_INVMOD_C
     mp_int tmpG, tmpX;
     int err;

     /* first compute 1/G mod P */
     if ((err = mp_init(&tmpG)) != MP_OKAY) {
        return err;
     }
     if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
        mp_clear(&tmpG);
        return err;
     }

     /* now get |X| */
     if ((err = mp_init(&tmpX)) != MP_OKAY) {
        mp_clear(&tmpG);
        return err;
     }
     if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
        mp_clear(&tmpG);
        mp_clear(&tmpX);
        return err;
     }

     /* and now compute (1/G)**|X| instead of G**X [X < 0] */
     err = mp_exptmod(&tmpG, &tmpX, P, Y);
     mp_clear(&tmpG);
     mp_clear(&tmpX);
     return err;
#else
     /* no invmod */
     return MP_VAL;
#endif
  }

#ifdef BN_MP_EXPTMOD_BASE_2
  /* base is exactly 2: use the dedicated routine */
  if (G->used == 1 && G->dp[0] == 2) {
    return mp_exptmod_base_2(X, P, Y);
  }
#endif

/* modified diminished radix reduction */
#if defined(BN_MP_REDUCE_IS_2K_L_C) && defined(BN_MP_REDUCE_2K_L_C) && \
  defined(BN_S_MP_EXPTMOD_C)
  if (mp_reduce_is_2k_l(P) == MP_YES) {
     return s_mp_exptmod(G, X, P, Y, 1);
  }
#endif

#ifdef BN_MP_DR_IS_MODULUS_C
  /* is it a DR modulus? */
  dr = mp_dr_is_modulus(P);
#else
  /* default to no */
  dr = 0;
#endif

  /* dr is never read again in some configurations (e.g. without
   * BN_MP_REDUCE_IS_2K_C and BN_MP_EXPTMOD_FAST_C); silence the warning */
  (void)dr;

#ifdef BN_MP_REDUCE_IS_2K_C
  /* if not, is it a unrestricted DR modulus? */
  if (dr == 0) {
     dr = mp_reduce_is_2k(P) << 1;
  }
#endif

  /* if the modulus is odd or dr != 0 use the montgomery method */
#ifdef BN_MP_EXPTMOD_FAST_C
  if (mp_isodd (P) == MP_YES || dr !=  0) {
    return mp_exptmod_fast (G, X, P, Y, dr);
  } else {
#endif
#ifdef BN_S_MP_EXPTMOD_C
    /* otherwise use the generic Barrett reduction technique */
    return s_mp_exptmod (G, X, P, Y, 0);
#else
    /* no exptmod for evens */
    return MP_VAL;
#endif
#ifdef BN_MP_EXPTMOD_FAST_C
  }
#endif
}
996 
/* Y = G**X mod P, with a digit-count hint.
 *
 * This implementation ignores digits and simply forwards to mp_exptmod(),
 * returning its result.
 */
int mp_exptmod_ex (mp_int * G, mp_int * X, int digits, mp_int * P, mp_int * Y)
{
    int err;

    (void)digits;   /* hint unused here */

    err = mp_exptmod(G, X, P, Y);
    return err;
}
1002 
1003 /* b = |a|
1004  *
1005  * Simple function copies the input and fixes the sign to positive
1006  */
mp_abs(mp_int * a,mp_int * b)1007 int mp_abs (mp_int * a, mp_int * b)
1008 {
1009   int     res;
1010 
1011   /* copy a to b */
1012   if (a != b) {
1013      if ((res = mp_copy (a, b)) != MP_OKAY) {
1014        return res;
1015      }
1016   }
1017 
1018   /* force the sign of b to positive */
1019   b->sign = MP_ZPOS;
1020 
1021   return MP_OKAY;
1022 }
1023 
1024 
1025 /* hac 14.61, pp608 */
1026 #if defined(FREESCALE_LTC_TFM)
wolfcrypt_mp_invmod(mp_int * a,mp_int * b,mp_int * c)1027 int wolfcrypt_mp_invmod(mp_int * a, mp_int * b, mp_int * c)
1028 #else
1029 int mp_invmod (mp_int * a, mp_int * b, mp_int * c)
1030 #endif
1031 {
1032   /* b cannot be negative or zero, and can not divide by 0 (1/a mod b) */
1033   if (b->sign == MP_NEG || mp_iszero(b) == MP_YES || mp_iszero(a) == MP_YES) {
1034     return MP_VAL;
1035   }
1036 
1037 #ifdef BN_FAST_MP_INVMOD_C
1038   /* if the modulus is odd we can use a faster routine instead */
1039   if ((mp_isodd(b) == MP_YES) && (mp_cmp_d(b, 1) != MP_EQ)) {
1040     return fast_mp_invmod (a, b, c);
1041   }
1042 #endif
1043 
1044 #ifdef BN_MP_INVMOD_SLOW_C
1045   return mp_invmod_slow(a, b, c);
1046 #else
1047   return MP_VAL;
1048 #endif
1049 }
1050 
1051 
1052 /* computes the modular inverse via binary extended euclidean algorithm,
1053  * that is c = 1/a mod b
1054  *
1055  * Based on slow invmod except this is optimized for the case where b is
1056  * odd as per HAC Note 14.64 on pp. 610
1057  */
fast_mp_invmod(mp_int * a,mp_int * b,mp_int * c)1058 int fast_mp_invmod (mp_int * a, mp_int * b, mp_int * c)
1059 {
1060   mp_int  x, y, u, v, B, D;
1061   int     res, neg, loop_check = 0;
1062 
1063   /* 2. [modified] b must be odd   */
1064   if (mp_iseven (b) == MP_YES) {
1065     return MP_VAL;
1066   }
1067 
1068   /* init all our temps */
1069   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D)) != MP_OKAY) {
1070      return res;
1071   }
1072 
1073   /* x == modulus, y == value to invert */
1074   if ((res = mp_copy (b, &x)) != MP_OKAY) {
1075     goto LBL_ERR;
1076   }
1077 
1078   /* we need y = |a| */
1079   if ((res = mp_mod (a, b, &y)) != MP_OKAY) {
1080     goto LBL_ERR;
1081   }
1082 
1083   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
1084   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
1085     goto LBL_ERR;
1086   }
1087   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
1088     goto LBL_ERR;
1089   }
1090   if ((res = mp_set (&D, 1)) != MP_OKAY) {
1091     goto LBL_ERR;
1092   }
1093 
1094 top:
1095   /* 4.  while u is even do */
1096   while (mp_iseven (&u) == MP_YES) {
1097     /* 4.1 u = u/2 */
1098     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
1099       goto LBL_ERR;
1100     }
1101     /* 4.2 if B is odd then */
1102     if (mp_isodd (&B) == MP_YES) {
1103       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
1104         goto LBL_ERR;
1105       }
1106     }
1107     /* B = B/2 */
1108     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
1109       goto LBL_ERR;
1110     }
1111   }
1112 
1113   /* 5.  while v is even do */
1114   while (mp_iseven (&v) == MP_YES) {
1115     /* 5.1 v = v/2 */
1116     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
1117       goto LBL_ERR;
1118     }
1119     /* 5.2 if D is odd then */
1120     if (mp_isodd (&D) == MP_YES) {
1121       /* D = (D-x)/2 */
1122       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
1123         goto LBL_ERR;
1124       }
1125     }
1126     /* D = D/2 */
1127     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
1128       goto LBL_ERR;
1129     }
1130   }
1131 
1132   /* 6.  if u >= v then */
1133   if (mp_cmp (&u, &v) != MP_LT) {
1134     /* u = u - v, B = B - D */
1135     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
1136       goto LBL_ERR;
1137     }
1138 
1139     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
1140       goto LBL_ERR;
1141     }
1142   } else {
1143     /* v - v - u, D = D - B */
1144     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
1145       goto LBL_ERR;
1146     }
1147 
1148     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
1149       goto LBL_ERR;
1150     }
1151   }
1152 
1153   /* if not zero goto step 4 */
1154   if (mp_iszero (&u) == MP_NO) {
1155     if (++loop_check > MAX_INVMOD_SZ) {
1156         res = MP_VAL;
1157         goto LBL_ERR;
1158     }
1159     goto top;
1160   }
1161 
1162   /* now a = C, b = D, gcd == g*v */
1163 
1164   /* if v != 1 then there is no inverse */
1165   if (mp_cmp_d (&v, 1) != MP_EQ) {
1166     res = MP_VAL;
1167     goto LBL_ERR;
1168   }
1169 
1170   /* b is now the inverse */
1171   neg = a->sign;
1172   while (D.sign == MP_NEG) {
1173     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
1174       goto LBL_ERR;
1175     }
1176   }
1177   /* too big */
1178   while (mp_cmp_mag(&D, b) != MP_LT) {
1179       if ((res = mp_sub(&D, b, &D)) != MP_OKAY) {
1180          goto LBL_ERR;
1181       }
1182   }
1183   mp_exch (&D, c);
1184   c->sign = neg;
1185   res = MP_OKAY;
1186 
1187 LBL_ERR:mp_clear(&x);
1188         mp_clear(&y);
1189         mp_clear(&u);
1190         mp_clear(&v);
1191         mp_clear(&B);
1192         mp_clear(&D);
1193   return res;
1194 }
1195 
1196 
1197 /* hac 14.61, pp608 */
mp_invmod_slow(mp_int * a,mp_int * b,mp_int * c)1198 int mp_invmod_slow (mp_int * a, mp_int * b, mp_int * c)
1199 {
1200   mp_int  x, y, u, v, A, B, C, D;
1201   int     res;
1202 
1203   /* b cannot be negative */
1204   if (b->sign == MP_NEG || mp_iszero(b) == MP_YES) {
1205     return MP_VAL;
1206   }
1207 
1208   /* init temps */
1209   if ((res = mp_init_multi(&x, &y, &u, &v,
1210                            &A, &B)) != MP_OKAY) {
1211     return res;
1212   }
1213 
1214   /* init rest of tmps temps */
1215   if ((res = mp_init_multi(&C, &D, 0, 0, 0, 0)) != MP_OKAY) {
1216     mp_clear(&x);
1217     mp_clear(&y);
1218     mp_clear(&u);
1219     mp_clear(&v);
1220     mp_clear(&A);
1221     mp_clear(&B);
1222     return res;
1223   }
1224 
1225   /* x = a, y = b */
1226   if ((res = mp_mod(a, b, &x)) != MP_OKAY) {
1227     goto LBL_ERR;
1228   }
1229   if (mp_isone(&x)) {
1230     res = mp_set(c, 1);
1231     goto LBL_ERR;
1232   }
1233   if ((res = mp_copy (b, &y)) != MP_OKAY) {
1234     goto LBL_ERR;
1235   }
1236 
1237   /* 2. [modified] if x,y are both even then return an error! */
1238   if (mp_iseven (&x) == MP_YES && mp_iseven (&y) == MP_YES) {
1239     res = MP_VAL;
1240     goto LBL_ERR;
1241   }
1242 
1243   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
1244   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
1245     goto LBL_ERR;
1246   }
1247   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
1248     goto LBL_ERR;
1249   }
1250   if ((res = mp_set (&A, 1)) != MP_OKAY) {
1251     goto LBL_ERR;
1252   }
1253   if ((res = mp_set (&D, 1)) != MP_OKAY) {
1254     goto LBL_ERR;
1255   }
1256 
1257 top:
1258   /* 4.  while u is even do */
1259   while (mp_iseven (&u) == MP_YES) {
1260     /* 4.1 u = u/2 */
1261     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
1262       goto LBL_ERR;
1263     }
1264     /* 4.2 if A or B is odd then */
1265     if (mp_isodd (&A) == MP_YES || mp_isodd (&B) == MP_YES) {
1266       /* A = (A+y)/2, B = (B-x)/2 */
1267       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
1268         goto LBL_ERR;
1269       }
1270       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
1271         goto LBL_ERR;
1272       }
1273     }
1274     /* A = A/2, B = B/2 */
1275     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
1276       goto LBL_ERR;
1277     }
1278     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
1279       goto LBL_ERR;
1280     }
1281   }
1282 
1283   /* 5.  while v is even do */
1284   while (mp_iseven (&v) == MP_YES) {
1285     /* 5.1 v = v/2 */
1286     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
1287       goto LBL_ERR;
1288     }
1289     /* 5.2 if C or D is odd then */
1290     if (mp_isodd (&C) == MP_YES || mp_isodd (&D) == MP_YES) {
1291       /* C = (C+y)/2, D = (D-x)/2 */
1292       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
1293         goto LBL_ERR;
1294       }
1295       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
1296         goto LBL_ERR;
1297       }
1298     }
1299     /* C = C/2, D = D/2 */
1300     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
1301       goto LBL_ERR;
1302     }
1303     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
1304       goto LBL_ERR;
1305     }
1306   }
1307 
1308   /* 6.  if u >= v then */
1309   if (mp_cmp (&u, &v) != MP_LT) {
1310     /* u = u - v, A = A - C, B = B - D */
1311     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
1312       goto LBL_ERR;
1313     }
1314 
1315     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
1316       goto LBL_ERR;
1317     }
1318 
1319     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
1320       goto LBL_ERR;
1321     }
1322   } else {
1323     /* v - v - u, C = C - A, D = D - B */
1324     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
1325       goto LBL_ERR;
1326     }
1327 
1328     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
1329       goto LBL_ERR;
1330     }
1331 
1332     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
1333       goto LBL_ERR;
1334     }
1335   }
1336 
1337   /* if not zero goto step 4 */
1338   if (mp_iszero (&u) == MP_NO)
1339     goto top;
1340 
1341   /* now a = C, b = D, gcd == g*v */
1342 
1343   /* if v != 1 then there is no inverse */
1344   if (mp_cmp_d (&v, 1) != MP_EQ) {
1345     res = MP_VAL;
1346     goto LBL_ERR;
1347   }
1348 
1349   /* if its too low */
1350   while (mp_cmp_d(&C, 0) == MP_LT) {
1351       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
1352          goto LBL_ERR;
1353       }
1354   }
1355 
1356   /* too big */
1357   while (mp_cmp_mag(&C, b) != MP_LT) {
1358       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
1359          goto LBL_ERR;
1360       }
1361   }
1362 
1363   /* C is now the inverse */
1364   mp_exch (&C, c);
1365   res = MP_OKAY;
1366 LBL_ERR:mp_clear(&x);
1367         mp_clear(&y);
1368         mp_clear(&u);
1369         mp_clear(&v);
1370         mp_clear(&A);
1371         mp_clear(&B);
1372         mp_clear(&C);
1373         mp_clear(&D);
1374   return res;
1375 }
1376 
1377 
1378 /* compare magnitude of two ints (unsigned) */
mp_cmp_mag(mp_int * a,mp_int * b)1379 int mp_cmp_mag (mp_int * a, mp_int * b)
1380 {
1381   int     n;
1382   mp_digit *tmpa, *tmpb;
1383 
1384   /* compare based on # of non-zero digits */
1385   if (a->used > b->used) {
1386     return MP_GT;
1387   }
1388 
1389   if (a->used < b->used) {
1390     return MP_LT;
1391   }
1392 
1393   /* alias for a */
1394   tmpa = a->dp + (a->used - 1);
1395 
1396   /* alias for b */
1397   tmpb = b->dp + (a->used - 1);
1398 
1399   /* compare based on digits  */
1400   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1401     if (*tmpa > *tmpb) {
1402       return MP_GT;
1403     }
1404 
1405     if (*tmpa < *tmpb) {
1406       return MP_LT;
1407     }
1408   }
1409   return MP_EQ;
1410 }
1411 
1412 
1413 /* compare two ints (signed)*/
mp_cmp(mp_int * a,mp_int * b)1414 int mp_cmp (mp_int * a, mp_int * b)
1415 {
1416   /* compare based on sign */
1417   if (a->sign != b->sign) {
1418      if (a->sign == MP_NEG) {
1419         return MP_LT;
1420      } else {
1421         return MP_GT;
1422      }
1423   }
1424 
1425   /* compare digits */
1426   if (a->sign == MP_NEG) {
1427      /* if negative compare opposite direction */
1428      return mp_cmp_mag(b, a);
1429   } else {
1430      return mp_cmp_mag(a, b);
1431   }
1432 }
1433 
1434 
1435 /* compare a digit */
mp_cmp_d(mp_int * a,mp_digit b)1436 int mp_cmp_d(mp_int * a, mp_digit b)
1437 {
1438   /* special case for zero*/
1439   if (a->used == 0 && b == 0)
1440     return MP_EQ;
1441 
1442   /* compare based on sign */
1443   if ((b && a->used == 0) || a->sign == MP_NEG) {
1444     return MP_LT;
1445   }
1446 
1447   /* compare based on magnitude */
1448   if (a->used > 1) {
1449     return MP_GT;
1450   }
1451 
1452   /* compare the only digit of a to b */
1453   if (a->dp[0] > b) {
1454     return MP_GT;
1455   } else if (a->dp[0] < b) {
1456     return MP_LT;
1457   } else {
1458     return MP_EQ;
1459   }
1460 }
1461 
1462 
1463 /* set to a digit */
mp_set(mp_int * a,mp_digit b)1464 int mp_set (mp_int * a, mp_digit b)
1465 {
1466   int res;
1467   mp_zero (a);
1468   res = mp_grow (a, 1);
1469   if (res == MP_OKAY) {
1470     a->dp[0] = (mp_digit)(b & MP_MASK);
1471     a->used  = (a->dp[0] != 0) ? 1 : 0;
1472   }
1473   return res;
1474 }
1475 
1476 /* check if a bit is set */
mp_is_bit_set(mp_int * a,mp_digit b)1477 int mp_is_bit_set (mp_int *a, mp_digit b)
1478 {
1479     mp_digit i = b / DIGIT_BIT;  /* word index */
1480     mp_digit s = b % DIGIT_BIT;  /* bit index */
1481 
1482     if ((mp_digit)a->used <= i) {
1483         /* no words available at that bit count */
1484         return 0;
1485     }
1486 
1487     /* get word and shift bit to check down to index 0 */
1488     return (int)((a->dp[i] >> s) & (mp_digit)1);
1489 }
1490 
1491 /* c = a mod b, 0 <= c < b */
1492 #if defined(FREESCALE_LTC_TFM)
wolfcrypt_mp_mod(mp_int * a,mp_int * b,mp_int * c)1493 int wolfcrypt_mp_mod(mp_int * a, mp_int * b, mp_int * c)
1494 #else
1495 int mp_mod (mp_int * a, mp_int * b, mp_int * c)
1496 #endif
1497 {
1498   mp_int  t;
1499   int     res;
1500 
1501   if ((res = mp_init_size (&t, b->used)) != MP_OKAY) {
1502     return res;
1503   }
1504 
1505   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
1506     mp_clear (&t);
1507     return res;
1508   }
1509 
1510   if ((mp_iszero(&t) != MP_NO) || (t.sign == b->sign)) {
1511     res = MP_OKAY;
1512     mp_exch (&t, c);
1513   } else {
1514     res = mp_add (b, &t, c);
1515   }
1516 
1517   mp_clear (&t);
1518   return res;
1519 }
1520 
1521 
1522 /* slower bit-bang division... also smaller */
mp_div(mp_int * a,mp_int * b,mp_int * c,mp_int * d)1523 int mp_div(mp_int * a, mp_int * b, mp_int * c, mp_int * d)
1524 {
1525    mp_int ta, tb, tq, q;
1526    int    res, n, n2;
1527 
1528   /* is divisor zero ? */
1529   if (mp_iszero (b) == MP_YES) {
1530     return MP_VAL;
1531   }
1532 
1533   /* if a < b then q=0, r = a */
1534   if (mp_cmp_mag (a, b) == MP_LT) {
1535     if (d != NULL) {
1536       res = mp_copy (a, d);
1537     } else {
1538       res = MP_OKAY;
1539     }
1540     if (c != NULL) {
1541       mp_zero (c);
1542     }
1543     return res;
1544   }
1545 
1546   /* init our temps */
1547   if ((res = mp_init_multi(&ta, &tb, &tq, &q, 0, 0)) != MP_OKAY) {
1548      return res;
1549   }
1550 
1551   if ((res = mp_set(&tq, 1)) != MP_OKAY) {
1552      return res;
1553   }
1554   n = mp_count_bits(a) - mp_count_bits(b);
1555   if (((res = mp_abs(a, &ta)) != MP_OKAY) ||
1556       ((res = mp_abs(b, &tb)) != MP_OKAY) ||
1557       ((res = mp_mul_2d(&tb, n, &tb)) != MP_OKAY) ||
1558       ((res = mp_mul_2d(&tq, n, &tq)) != MP_OKAY)) {
1559       goto LBL_ERR;
1560   }
1561 
1562   while (n-- >= 0) {
1563      if (mp_cmp(&tb, &ta) != MP_GT) {
1564         if (((res = mp_sub(&ta, &tb, &ta)) != MP_OKAY) ||
1565             ((res = mp_add(&q, &tq, &q)) != MP_OKAY)) {
1566            goto LBL_ERR;
1567         }
1568      }
1569      if (((res = mp_div_2d(&tb, 1, &tb, NULL)) != MP_OKAY) ||
1570          ((res = mp_div_2d(&tq, 1, &tq, NULL)) != MP_OKAY)) {
1571            goto LBL_ERR;
1572      }
1573   }
1574 
1575   /* now q == quotient and ta == remainder */
1576   n  = a->sign;
1577   n2 = (a->sign == b->sign ? MP_ZPOS : MP_NEG);
1578   if (c != NULL) {
1579      mp_exch(c, &q);
1580      c->sign  = (mp_iszero(c) == MP_YES) ? MP_ZPOS : n2;
1581   }
1582   if (d != NULL) {
1583      mp_exch(d, &ta);
1584      d->sign = (mp_iszero(d) == MP_YES) ? MP_ZPOS : n;
1585   }
1586 LBL_ERR:
1587    mp_clear(&ta);
1588    mp_clear(&tb);
1589    mp_clear(&tq);
1590    mp_clear(&q);
1591    return res;
1592 }
1593 
1594 
1595 /* b = a/2 */
mp_div_2(mp_int * a,mp_int * b)1596 int mp_div_2(mp_int * a, mp_int * b)
1597 {
1598   int     x, res, oldused;
1599 
1600   /* copy */
1601   if (b->alloc < a->used) {
1602     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1603       return res;
1604     }
1605   }
1606 
1607   oldused = b->used;
1608   b->used = a->used;
1609   {
1610     mp_digit r, rr, *tmpa, *tmpb;
1611 
1612     /* source alias */
1613     tmpa = a->dp + b->used - 1;
1614 
1615     /* dest alias */
1616     tmpb = b->dp + b->used - 1;
1617 
1618     /* carry */
1619     r = 0;
1620     for (x = b->used - 1; x >= 0; x--) {
1621       /* get the carry for the next iteration */
1622       rr = *tmpa & 1;
1623 
1624       /* shift the current digit, add in carry and store */
1625       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
1626 
1627       /* forward carry to next iteration */
1628       r = rr;
1629     }
1630 
1631     /* zero excess digits */
1632     tmpb = b->dp + b->used;
1633     for (x = b->used; x < oldused; x++) {
1634       *tmpb++ = 0;
1635     }
1636   }
1637   b->sign = a->sign;
1638   mp_clamp (b);
1639   return MP_OKAY;
1640 }
1641 
1642 /* c = a / 2 (mod b) - constant time (a < b and positive) */
mp_div_2_mod_ct(mp_int * a,mp_int * b,mp_int * c)1643 int mp_div_2_mod_ct(mp_int *a, mp_int *b, mp_int *c)
1644 {
1645     int res;
1646 
1647     if (mp_isodd(a)) {
1648         res = mp_add(a, b, c);
1649         if (res == MP_OKAY) {
1650             res = mp_div_2(c, c);
1651         }
1652     }
1653     else {
1654         res = mp_div_2(a, c);
1655     }
1656 
1657     return res;
1658 }
1659 
1660 
1661 /* high level addition (handles signs) */
mp_add(mp_int * a,mp_int * b,mp_int * c)1662 int mp_add (mp_int * a, mp_int * b, mp_int * c)
1663 {
1664   int sa, sb, res;
1665 
1666   /* get sign of both inputs */
1667   sa = a->sign;
1668   sb = b->sign;
1669 
1670   /* handle two cases, not four */
1671   if (sa == sb) {
1672     /* both positive or both negative */
1673     /* add their magnitudes, copy the sign */
1674     c->sign = sa;
1675     res = s_mp_add (a, b, c);
1676   } else {
1677     /* one positive, the other negative */
1678     /* subtract the one with the greater magnitude from */
1679     /* the one of the lesser magnitude.  The result gets */
1680     /* the sign of the one with the greater magnitude. */
1681     if (mp_cmp_mag (a, b) == MP_LT) {
1682       c->sign = sb;
1683       res = s_mp_sub (b, a, c);
1684     } else {
1685       c->sign = sa;
1686       res = s_mp_sub (a, b, c);
1687     }
1688   }
1689   return res;
1690 }
1691 
1692 
1693 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
s_mp_add(mp_int * a,mp_int * b,mp_int * c)1694 int s_mp_add (mp_int * a, mp_int * b, mp_int * c)
1695 {
1696   mp_int *x;
1697   int     olduse, res, min_ab, max_ab;
1698 
1699   /* find sizes, we let |a| <= |b| which means we have to sort
1700    * them.  "x" will point to the input with the most digits
1701    */
1702   if (a->used > b->used) {
1703     min_ab = b->used;
1704     max_ab = a->used;
1705     x = a;
1706   } else {
1707     min_ab = a->used;
1708     max_ab = b->used;
1709     x = b;
1710   }
1711 
1712   /* init result */
1713   if (c->alloc < max_ab + 1) {
1714     if ((res = mp_grow (c, max_ab + 1)) != MP_OKAY) {
1715       return res;
1716     }
1717   }
1718 
1719   /* get old used digit count and set new one */
1720   olduse = c->used;
1721   c->used = max_ab + 1;
1722 
1723   {
1724     mp_digit u, *tmpa, *tmpb, *tmpc;
1725     int i;
1726 
1727     /* alias for digit pointers */
1728 
1729     /* first input */
1730     tmpa = a->dp;
1731 
1732     /* second input */
1733     tmpb = b->dp;
1734 
1735     /* destination */
1736     tmpc = c->dp;
1737 
1738     /* zero the carry */
1739     u = 0;
1740     for (i = 0; i < min_ab; i++) {
1741       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
1742       *tmpc = *tmpa++ + *tmpb++ + u;
1743 
1744       /* U = carry bit of T[i] */
1745       u = *tmpc >> ((mp_digit)DIGIT_BIT);
1746 
1747       /* take away carry bit from T[i] */
1748       *tmpc++ &= MP_MASK;
1749     }
1750 
1751     /* now copy higher words if any, that is in A+B
1752      * if A or B has more digits add those in
1753      */
1754     if (min_ab != max_ab) {
1755       for (; i < max_ab; i++) {
1756         /* T[i] = X[i] + U */
1757         *tmpc = x->dp[i] + u;
1758 
1759         /* U = carry bit of T[i] */
1760         u = *tmpc >> ((mp_digit)DIGIT_BIT);
1761 
1762         /* take away carry bit from T[i] */
1763         *tmpc++ &= MP_MASK;
1764       }
1765     }
1766 
1767     /* add carry */
1768     *tmpc++ = u;
1769 
1770     /* clear digits above olduse */
1771     for (i = c->used; i < olduse; i++) {
1772       *tmpc++ = 0;
1773     }
1774   }
1775 
1776   mp_clamp (c);
1777   return MP_OKAY;
1778 }
1779 
1780 
1781 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
s_mp_sub(mp_int * a,mp_int * b,mp_int * c)1782 int s_mp_sub (mp_int * a, mp_int * b, mp_int * c)
1783 {
1784   int     olduse, res, min_b, max_a;
1785 
1786   /* find sizes */
1787   min_b = b->used;
1788   max_a = a->used;
1789 
1790   /* init result */
1791   if (c->alloc < max_a) {
1792     if ((res = mp_grow (c, max_a)) != MP_OKAY) {
1793       return res;
1794     }
1795   }
1796 
1797   /* sanity check on destination */
1798   if (c->dp == NULL)
1799      return MP_VAL;
1800 
1801   olduse = c->used;
1802   c->used = max_a;
1803 
1804   {
1805     mp_digit u, *tmpa, *tmpb, *tmpc;
1806     int i;
1807 
1808     /* alias for digit pointers */
1809     tmpa = a->dp;
1810     tmpb = b->dp;
1811     tmpc = c->dp;
1812 
1813     /* set carry to zero */
1814     u = 0;
1815     for (i = 0; i < min_b; i++) {
1816       /* T[i] = A[i] - B[i] - U */
1817       *tmpc = *tmpa++ - *tmpb++ - u;
1818 
1819       /* U = carry bit of T[i]
1820        * Note this saves performing an AND operation since
1821        * if a carry does occur it will propagate all the way to the
1822        * MSB.  As a result a single shift is enough to get the carry
1823        */
1824       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
1825 
1826       /* Clear carry from T[i] */
1827       *tmpc++ &= MP_MASK;
1828     }
1829 
1830     /* now copy higher words if any, e.g. if A has more digits than B  */
1831     for (; i < max_a; i++) {
1832       /* T[i] = A[i] - U */
1833       *tmpc = *tmpa++ - u;
1834 
1835       /* U = carry bit of T[i] */
1836       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
1837 
1838       /* Clear carry from T[i] */
1839       *tmpc++ &= MP_MASK;
1840     }
1841 
1842     /* clear digits above used (since we may not have grown result above) */
1843     for (i = c->used; i < olduse; i++) {
1844       *tmpc++ = 0;
1845     }
1846   }
1847 
1848   mp_clamp (c);
1849   return MP_OKAY;
1850 }
1851 
1852 
1853 /* high level subtraction (handles signs) */
mp_sub(mp_int * a,mp_int * b,mp_int * c)1854 int mp_sub (mp_int * a, mp_int * b, mp_int * c)
1855 {
1856   int     sa, sb, res;
1857 
1858   sa = a->sign;
1859   sb = b->sign;
1860 
1861   if (sa != sb) {
1862     /* subtract a negative from a positive, OR */
1863     /* subtract a positive from a negative. */
1864     /* In either case, ADD their magnitudes, */
1865     /* and use the sign of the first number. */
1866     c->sign = sa;
1867     res = s_mp_add (a, b, c);
1868   } else {
1869     /* subtract a positive from a positive, OR */
1870     /* subtract a negative from a negative. */
1871     /* First, take the difference between their */
1872     /* magnitudes, then... */
1873     if (mp_cmp_mag (a, b) != MP_LT) {
1874       /* Copy the sign from the first */
1875       c->sign = sa;
1876       /* The first has a larger or equal magnitude */
1877       res = s_mp_sub (a, b, c);
1878     } else {
1879       /* The result has the *opposite* sign from */
1880       /* the first number. */
1881       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
1882       /* The second has a larger magnitude */
1883       res = s_mp_sub (b, a, c);
1884     }
1885   }
1886   return res;
1887 }
1888 
1889 
1890 /* determines if reduce_2k_l can be used */
mp_reduce_is_2k_l(mp_int * a)1891 int mp_reduce_is_2k_l(mp_int *a)
1892 {
1893    int ix, iy;
1894 
1895    if (a->used == 0) {
1896       return MP_NO;
1897    } else if (a->used == 1) {
1898       return MP_YES;
1899    } else if (a->used > 1) {
1900       /* if more than half of the digits are -1 we're sold */
1901       for (iy = ix = 0; ix < a->used; ix++) {
1902           if (a->dp[ix] == MP_MASK) {
1903               ++iy;
1904           }
1905       }
1906       return (iy >= (a->used/2)) ? MP_YES : MP_NO;
1907 
1908    }
1909    return MP_NO;
1910 }
1911 
1912 
1913 /* determines if mp_reduce_2k can be used */
mp_reduce_is_2k(mp_int * a)1914 int mp_reduce_is_2k(mp_int *a)
1915 {
1916    int ix, iy, iw;
1917    mp_digit iz;
1918 
1919    if (a->used == 0) {
1920       return MP_NO;
1921    } else if (a->used == 1) {
1922       return MP_YES;
1923    } else if (a->used > 1) {
1924       iy = mp_count_bits(a);
1925       iz = 1;
1926       iw = 1;
1927 
1928       /* Test every bit from the second digit up, must be 1 */
1929       for (ix = DIGIT_BIT; ix < iy; ix++) {
1930           if ((a->dp[iw] & iz) == 0) {
1931              return MP_NO;
1932           }
1933           iz <<= 1;
1934           if (iz > (mp_digit)MP_MASK) {
1935              ++iw;
1936              iz = 1;
1937           }
1938       }
1939    }
1940    return MP_YES;
1941 }
1942 
1943 
1944 /* determines if a number is a valid DR modulus */
mp_dr_is_modulus(mp_int * a)1945 int mp_dr_is_modulus(mp_int *a)
1946 {
1947    int ix;
1948 
1949    /* must be at least two digits */
1950    if (a->used < 2) {
1951       return 0;
1952    }
1953 
1954    /* must be of the form b**k - a [a <= b] so all
1955     * but the first digit must be equal to -1 (mod b).
1956     */
1957    for (ix = 1; ix < a->used; ix++) {
1958        if (a->dp[ix] != MP_MASK) {
1959           return 0;
1960        }
1961    }
1962    return 1;
1963 }
1964 
1965 
/* Computes Y = G**X mod P, HAC pp.616, Algorithm 14.85.
 *
 * Uses a left-to-right k-ary sliding window to compute the modular
 * exponentiation.  The window size k is picked from the bit length of
 * the exponent X (2..8, capped at 5 when MP_LOW_MEM is defined).
 *
 * redmode selects the reduction applied after every square/multiply:
 *   0     - Montgomery reduction (needs BN_MP_MONTGOMERY_SETUP_C,
 *           BN_MP_MONTGOMERY_CALC_NORMALIZATION_C and a Montgomery reduce
 *           routine); table entries and the accumulator are kept in
 *           Montgomery form and converted back before returning.
 *   1     - Diminished Radix reduction for moduli of the form B**k - b
 *           (needs BN_MP_DR_SETUP_C and BN_MP_DR_REDUCE_C).
 *   other - mp_reduce_2k for moduli of the form 2**k - b
 *           (needs BN_MP_REDUCE_2K_SETUP_C and BN_MP_REDUCE_2K_C).
 * If the selected reduction is not compiled in, MP_VAL is returned.
 *
 * Returns MP_OKAY with the result in Y, MP_MEM if the table cannot be
 * allocated (WOLFSSL_SMALL_STACK builds), or another MP_* error.
 */

/* Capacity of the precomputed table M[].  Only M[1] and
 * M[1 << (winsize - 1)] .. M[(1 << winsize) - 1] are used, so TAB_SIZE must
 * be at least 1 << (largest winsize): 32 covers winsize <= 5 (the
 * MP_LOW_MEM cap) and 256 covers winsize <= 8.
 */
#ifdef MP_LOW_MEM
   #define TAB_SIZE 32
#else
   #define TAB_SIZE 256
#endif

int mp_exptmod_fast (mp_int * G, mp_int * X, mp_int * P, mp_int * Y,
                     int redmode)
{
  mp_int res;
  mp_digit buf, mp;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
#ifdef WOLFSSL_SMALL_STACK
  mp_int* M;
#else
  mp_int M[TAB_SIZE];
#endif
  /* use a pointer to the reduction algorithm.  This allows us to use
   * one of many reduction algorithms without modding the guts of
   * the code with if statements everywhere.
   */
  int     (*redux)(mp_int*,mp_int*,mp_digit) = NULL;

#ifdef WOLFSSL_SMALL_STACK
  M = (mp_int*) XMALLOC(sizeof(mp_int) * TAB_SIZE, NULL,
                                                       DYNAMIC_TYPE_BIGINT);
  if (M == NULL)
    return MP_MEM;
#endif

  /* find window size from the bit length of the exponent */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

#ifdef MP_LOW_MEM
  /* keep the window within the smaller TAB_SIZE */
  if (winsize > 5) {
     winsize = 5;
  }
#endif

  /* init M array: only M[1] and the upper half of the table are used */
  /* init first cell */
  if ((err = mp_init_size(&M[1], P->alloc)) != MP_OKAY) {
#ifdef WOLFSSL_SMALL_STACK
     XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif

     return err;
  }

  /* now init the second half of the array */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init_size(&M[x], P->alloc)) != MP_OKAY) {
      /* undo the entries initialised so far before bailing out */
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);

#ifdef WOLFSSL_SMALL_STACK
      XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif

      return err;
    }
  }

  /* determine and setup reduction code */
  if (redmode == 0) {
#ifdef BN_MP_MONTGOMERY_SETUP_C
     /* now setup montgomery  */
     if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
        goto LBL_M;
     }
#else
     err = MP_VAL;
     goto LBL_M;
#endif

     /* automatically pick the comba one if available (saves quite a few
        calls/ifs) */
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
     if (((P->used * 2 + 1) < (int)MP_WARRAY) &&
          P->used < (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
        redux = fast_mp_montgomery_reduce;
     } else
#endif
     {
#ifdef BN_MP_MONTGOMERY_REDUCE_C
        /* use slower baseline Montgomery method */
        redux = mp_montgomery_reduce;
#endif
     }
  } else if (redmode == 1) {
#if defined(BN_MP_DR_SETUP_C) && defined(BN_MP_DR_REDUCE_C)
     /* setup DR reduction for moduli of the form B**k - b */
     mp_dr_setup(P, &mp);
     redux = mp_dr_reduce;
#endif
  } else {
#if defined(BN_MP_REDUCE_2K_SETUP_C) && defined(BN_MP_REDUCE_2K_C)
     /* setup 2k reduction for moduli of the form 2**k - b */
     if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
        goto LBL_M;
     }
     redux = mp_reduce_2k;
#endif
  }

  /* the requested reduction is not compiled in */
  if (redux == NULL) {
     err = MP_VAL;
     goto LBL_M;
  }

  /* setup result */
  if ((err = mp_init_size (&res, P->alloc)) != MP_OKAY) {
    goto LBL_M;
  }

  /* create M table
   *
   * Only M[1] and the upper half of the table, M[1 << (winsize - 1)] ..
   * M[(1 << winsize) - 1], are computed.  The lower half is never indexed:
   * a completed window always has its top bit set, and a trailing partial
   * window is handled bit by bit using M[1].
   */

  if (redmode == 0) {
#ifdef BN_MP_MONTGOMERY_CALC_NORMALIZATION_C
     /* now we need R mod m */
     if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
       goto LBL_RES;
     }

     /* now set M[1] to G * R mod m */
     if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
       goto LBL_RES;
     }
#else
     err = MP_VAL;
     goto LBL_RES;
#endif
  } else {
     /* no Montgomery form: accumulator starts at 1, M[1] = G mod P */
     if ((err = mp_set(&res, 1)) != MP_OKAY) {
        goto LBL_RES;
     }
     if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
        goto LBL_RES;
     }
  }

  /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times*/
  if ((err = mp_copy (&M[1], &M[(mp_digit)(1 << (winsize - 1))])) != MP_OKAY) {
    goto LBL_RES;
  }

  for (x = 0; x < (winsize - 1); x++) {
    if ((err = mp_sqr (&M[(mp_digit)(1 << (winsize - 1))],
                       &M[(mp_digit)(1 << (winsize - 1))])) != MP_OKAY) {
      goto LBL_RES;
    }
    if ((err = redux (&M[(mp_digit)(1 << (winsize - 1))], P, mp)) != MP_OKAY) {
      goto LBL_RES;
    }
  }

  /* create upper table: M[x] = M[x-1] * M[1] */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto LBL_RES;
    }
    if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
      goto LBL_RES;
    }
  }

  /* set initial mode and bit cnt
   *   mode 0: still skipping the exponent's leading zero bits
   *   mode 1: between windows; each 0 bit costs one squaring
   *   mode 2: collecting bits into the current window (bitbuf/bitcpy)
   */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits so break */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = (int)DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (int)(buf >> (DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }
      continue;
    }

    /* else we add it to the window */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto LBL_RES;
        }
      }

      /* then multiply by the precomputed power for this window */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply (partial window, one bit at a time
   * using M[1]) */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* get next bit of the window */
      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto LBL_RES;
        }
      }
    }
  }

  if (redmode == 0) {
     /* fixup result if Montgomery reduction is used
      * recall that any value in a Montgomery system is
      * actually multiplied by R mod n.  So we have
      * to reduce one more time to cancel out the factor
      * of R.
      */
     if ((err = redux(&res, P, mp)) != MP_OKAY) {
       goto LBL_RES;
     }
  }

  /* swap res with Y */
  mp_exch (&res, Y);
  err = MP_OKAY;
LBL_RES:mp_clear (&res);
LBL_M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }

#ifdef WOLFSSL_SMALL_STACK
  XFREE(M, NULL, DYNAMIC_TYPE_BIGINT);
#endif

  return err;
}
2291 
2292 #ifdef BN_MP_EXPTMOD_BASE_2
2293 #if DIGIT_BIT < 16
2294     #define WINSIZE    3
2295 #elif DIGIT_BIT < 32
2296     #define WINSIZE    4
2297 #elif DIGIT_BIT < 64
2298     #define WINSIZE    5
2299 #elif DIGIT_BIT < 128
2300     #define WINSIZE    6
2301 #endif
/* Computes Y = 2**X mod P.
 *
 * Specialised modular exponentiation for base 2.  The exponent X is scanned
 * from its most significant digit downwards in fixed windows of WINSIZE bits.
 * For each full window the Montgomery-form accumulator is squared WINSIZE
 * times and then multiplied by 2**window with a left shift (mp_mul_2d)
 * followed by mp_mod, so no table of precomputed powers is needed.
 *
 * X  exponent; only the digits X->dp[0 .. X->used-1] are read
 *    (NOTE(review): X->sign is not examined - presumably callers pass X >= 0)
 * P  modulus; must be odd (mp_montgomery_setup() rejects even moduli)
 * Y  receives the result (written with mp_copy())
 *
 * Returns MP_OKAY on success; MP_VAL when no Montgomery reduction routine is
 * usable or P is rejected by mp_montgomery_setup(); MP_MEM when the
 * WOLFSSL_SMALL_STACK allocation fails; otherwise the error of the first
 * failing helper.
 *
 * NOTE(review): the work per window depends on the window value (mp_mul_2d
 * by bitbuf, then mp_mod), so do not assume constant-time behaviour with
 * respect to X.
 */
int mp_exptmod_base_2(mp_int * X, mp_int * P, mp_int * Y)
{
  mp_digit buf, mp;
  int      err = MP_OKAY, bitbuf, bitcpy, bitcnt, digidx, x, y;
#ifdef WOLFSSL_SMALL_STACK
  mp_int  *res = NULL;
#else
  mp_int   res[1];
#endif
  int     (*redux)(mp_int*,mp_int*,mp_digit) = NULL;

  /* automatically pick the comba one if available (saves quite a few
     calls/ifs) */
#ifdef BN_FAST_MP_MONTGOMERY_REDUCE_C
  if (((P->used * 2 + 1) < (int)MP_WARRAY) &&
       P->used < (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
     redux = fast_mp_montgomery_reduce;
  } else
#endif
#ifdef BN_MP_MONTGOMERY_REDUCE_C
  {
     /* use slower baseline Montgomery method */
     redux = mp_montgomery_reduce;
  }
#endif

  /* redux is still NULL when no reduction routine is compiled in, or when
   * only the comba routine is compiled in and P is too large for it (in that
   * configuration the dangling "else" above binds to this if statement). */
  if (redux == NULL) {
      return MP_VAL;
  }

#ifdef WOLFSSL_SMALL_STACK
  res = (mp_int*)XMALLOC(sizeof(mp_int), NULL, DYNAMIC_TYPE_TMP_BUFFER);
  if (res == NULL) {
     return MP_MEM;
  }
#endif

  /* now setup montgomery: mp = -1/P mod b */
  if ((err = mp_montgomery_setup(P, &mp)) != MP_OKAY) {
     goto LBL_M;
  }

  /* setup result */
  if ((err = mp_init(res)) != MP_OKAY) {
     goto LBL_M;
  }

  /* now we need R mod m: the Montgomery representation of 1 */
  if ((err = mp_montgomery_calc_normalization(res, P)) != MP_OKAY) {
     goto LBL_RES;
  }

  /* Get the top bits left over after taking WINSIZE bits starting at the
   * least-significant.
   * X spans X->used * DIGIT_BIT bits; the leftover bitcpy bits are taken
   * from the top digit first so that the remaining bits split evenly into
   * WINSIZE-bit windows.
   */
  digidx = X->used - 1;
  bitcpy = (X->used * DIGIT_BIT) % WINSIZE;
  if (bitcpy > 0) {
     bitcnt = (int)DIGIT_BIT - bitcpy;
     buf    = X->dp[digidx--];
     bitbuf = (int)(buf >> bitcnt);
     /* Multiply montgomery representation of 1 by 2 ^ top */
     err = mp_mul_2d(res, bitbuf, res);
     if (err != MP_OKAY) {
        goto LBL_RES;
     }
     err = mp_mod(res, P, res);
     if (err != MP_OKAY) {
        goto LBL_RES;
     }
     /* Move out bits used */
     buf  <<= bitcpy;
     /* +1 because the main loop pre-decrements bitcnt before each bit */
     bitcnt++;
  }
  else {
     /* no leftover bits: make the main loop load the top digit first */
     bitcnt = 1;
     buf    = 0;
  }

  /* empty window and reset  */
  bitbuf = 0;
  bitcpy = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits so break */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = (int)DIGIT_BIT;
    }

    /* grab the next msb from the exponent (bit DIGIT_BIT-1 of buf) */
    y       = (int)(buf >> (DIGIT_BIT - 1)) & 1;
    buf   <<= (mp_digit)1;
    /* add bit to the window, most significant window bit first */
    bitbuf |= (y << (WINSIZE - ++bitcpy));

    if (bitcpy == WINSIZE) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < WINSIZE; x++) {
        err = mp_sqr(res, res);
        if (err != MP_OKAY) {
          goto LBL_RES;
        }
        err = (*redux)(res, P, mp);
        if (err != MP_OKAY) {
          goto LBL_RES;
        }
      }

      /* then multiply by 2^bitbuf (shift, then reduce with mp_mod) */
      err = mp_mul_2d(res, bitbuf, res);
      if (err != MP_OKAY) {
         goto LBL_RES;
      }
      err = mp_mod(res, P, res);
      if (err != MP_OKAY) {
         goto LBL_RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
    }
  }

  /* fixup result if Montgomery reduction is used
   * recall that any value in a Montgomery system is
   * actually multiplied by R mod n.  So we have
   * to reduce one more time to cancel out the factor
   * of R.
   */
  err = (*redux)(res, P, mp);
  if (err != MP_OKAY) {
     goto LBL_RES;
  }

  /* copy the result into Y */
  err = mp_copy(res, Y);

  /* cleanup: LBL_RES clears res, LBL_M releases its storage
   * (WOLFSSL_SMALL_STACK only) */
LBL_RES:mp_clear (res);
LBL_M:
#ifdef WOLFSSL_SMALL_STACK
  XFREE(res, NULL, DYNAMIC_TYPE_TMP_BUFFER);
#endif
  return err;
}
2454 
2455 #undef WINSIZE
2456 #endif /* BN_MP_EXPTMOD_BASE_2 */
2457 
2458 
2459 /* setups the montgomery reduction stuff */
mp_montgomery_setup(mp_int * n,mp_digit * rho)2460 int mp_montgomery_setup (mp_int * n, mp_digit * rho)
2461 {
2462   mp_digit x, b;
2463 
2464 /* fast inversion mod 2**k
2465  *
2466  * Based on the fact that
2467  *
2468  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
2469  *                    =>  2*X*A - X*X*A*A = 1
2470  *                    =>  2*(1) - (1)     = 1
2471  */
2472   b = n->dp[0];
2473 
2474   if ((b & 1) == 0) {
2475     return MP_VAL;
2476   }
2477 
2478   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
2479   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
2480 #if !defined(MP_8BIT)
2481   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
2482 #endif
2483 #if defined(MP_64BIT) || !(defined(MP_8BIT) || defined(MP_16BIT))
2484   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
2485 #endif
2486 #ifdef MP_64BIT
2487   x *= 2 - b * x;               /* here x*a==1 mod 2**64 */
2488 #endif
2489 
2490   /* rho = -1/m mod b */
2491   /* TAO, switched mp_word casts to mp_digit to shut up compiler */
2492   *rho = (mp_digit)((((mp_digit)1 << ((mp_digit) DIGIT_BIT)) - x) & MP_MASK);
2493 
2494   return MP_OKAY;
2495 }
2496 
2497 
2498 /* computes xR**-1 == x (mod N) via Montgomery Reduction
2499  *
2500  * This is an optimized implementation of montgomery_reduce
2501  * which uses the comba method to quickly calculate the columns of the
2502  * reduction.
2503  *
2504  * Based on Algorithm 14.32 on pp.601 of HAC.
2505 */
fast_mp_montgomery_reduce(mp_int * x,mp_int * n,mp_digit rho)2506 int fast_mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
2507 {
2508   int     ix, res, olduse;
2509 #ifdef WOLFSSL_SMALL_STACK
2510   mp_word* W;    /* uses dynamic memory and slower */
2511 #else
2512   mp_word W[MP_WARRAY];
2513 #endif
2514 
2515   /* get old used count */
2516   olduse = x->used;
2517 
2518   /* grow a as required */
2519   if (x->alloc < n->used + 1) {
2520     if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
2521       return res;
2522     }
2523   }
2524 
2525 #ifdef WOLFSSL_SMALL_STACK
2526   W = (mp_word*)XMALLOC(sizeof(mp_word) * MP_WARRAY, NULL, DYNAMIC_TYPE_BIGINT);
2527   if (W == NULL)
2528     return MP_MEM;
2529 #endif
2530 
2531   XMEMSET(W, 0, (n->used * 2 + 1) * sizeof(mp_word));
2532 
2533   /* first we have to get the digits of the input into
2534    * an array of double precision words W[...]
2535    */
2536   {
2537     mp_word *_W;
2538     mp_digit *tmpx;
2539 
2540     /* alias for the W[] array */
2541     _W   = W;
2542 
2543     /* alias for the digits of  x*/
2544     tmpx = x->dp;
2545 
2546     /* copy the digits of a into W[0..a->used-1] */
2547     for (ix = 0; ix < x->used; ix++) {
2548       *_W++ = *tmpx++;
2549     }
2550   }
2551 
2552   /* now we proceed to zero successive digits
2553    * from the least significant upwards
2554    */
2555   for (ix = 0; ix < n->used; ix++) {
2556     /* mu = ai * m' mod b
2557      *
2558      * We avoid a double precision multiplication (which isn't required)
2559      * by casting the value down to a mp_digit.  Note this requires
2560      * that W[ix-1] have  the carry cleared (see after the inner loop)
2561      */
2562     mp_digit mu;
2563     mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
2564 
2565     /* a = a + mu * m * b**i
2566      *
2567      * This is computed in place and on the fly.  The multiplication
2568      * by b**i is handled by offsetting which columns the results
2569      * are added to.
2570      *
2571      * Note the comba method normally doesn't handle carries in the
2572      * inner loop In this case we fix the carry from the previous
2573      * column since the Montgomery reduction requires digits of the
2574      * result (so far) [see above] to work.  This is
2575      * handled by fixing up one carry after the inner loop.  The
2576      * carry fixups are done in order so after these loops the
2577      * first m->used words of W[] have the carries fixed
2578      */
2579     {
2580       int iy;
2581       mp_digit *tmpn;
2582       mp_word *_W;
2583 
2584       /* alias for the digits of the modulus */
2585       tmpn = n->dp;
2586 
2587       /* Alias for the columns set by an offset of ix */
2588       _W = W + ix;
2589 
2590       /* inner loop */
2591       for (iy = 0; iy < n->used; iy++) {
2592           *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
2593       }
2594     }
2595 
2596     /* now fix carry for next digit, W[ix+1] */
2597     W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
2598   }
2599 
2600   /* now we have to propagate the carries and
2601    * shift the words downward [all those least
2602    * significant digits we zeroed].
2603    */
2604   {
2605     mp_digit *tmpx;
2606     mp_word *_W, *_W1;
2607 
2608     /* nox fix rest of carries */
2609 
2610     /* alias for current word */
2611     _W1 = W + ix;
2612 
2613     /* alias for next word, where the carry goes */
2614     _W = W + ++ix;
2615 
2616     for (; ix <= n->used * 2 + 1; ix++) {
2617       *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
2618     }
2619 
2620     /* copy out, A = A/b**n
2621      *
2622      * The result is A/b**n but instead of converting from an
2623      * array of mp_word to mp_digit than calling mp_rshd
2624      * we just copy them in the right order
2625      */
2626 
2627     /* alias for destination word */
2628     tmpx = x->dp;
2629 
2630     /* alias for shifted double precision result */
2631     _W = W + n->used;
2632 
2633     for (ix = 0; ix < n->used + 1; ix++) {
2634       *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
2635     }
2636 
2637     /* zero olduse digits, if the input a was larger than
2638      * m->used+1 we'll have to clear the digits
2639      */
2640     for (; ix < olduse; ix++) {
2641       *tmpx++ = 0;
2642     }
2643   }
2644 
2645   /* set the max used and clamp */
2646   x->used = n->used + 1;
2647   mp_clamp (x);
2648 
2649 #ifdef WOLFSSL_SMALL_STACK
2650   XFREE(W, NULL, DYNAMIC_TYPE_BIGINT);
2651 #endif
2652 
2653   /* if A >= m then A = A - m */
2654   if (mp_cmp_mag (x, n) != MP_LT) {
2655     return s_mp_sub (x, n, x);
2656   }
2657   return MP_OKAY;
2658 }
2659 
2660 
/* computes xR**-1 == x (mod N) via Montgomery Reduction
 *
 * Baseline (digit-by-digit) Montgomery reduction, in place on x:
 *   x = x * R**-1 mod n,  R = b**n->used,  b = 2**DIGIT_BIT
 *
 * rho must be the value produced by mp_montgomery_setup() for n
 * (-1/n mod b).  When the comba column array is large enough and the column
 * sums cannot overflow an mp_word, the work is delegated to
 * fast_mp_montgomery_reduce().
 *
 * Returns MP_OKAY, or the error from fast_mp_montgomery_reduce(),
 * mp_grow() or s_mp_sub().
 *
 * NOTE(review): x->used is overwritten with 2*n->used+1 below; callers
 * presumably pass 0 <= x < n*R (e.g. a product of two residues) so that no
 * digits are dropped - TODO confirm.
 */
int mp_montgomery_reduce (mp_int * x, mp_int * n, mp_digit rho)
{
  int     ix, res, digs;
  mp_digit mu;

  /* can the fast reduction [comba] method be used?
   *
   * Note that unlike in mul you're safely allowed *less*
   * than the available columns [255 per default] since carries
   * are fixed up in the inner loop.
   */
  digs = n->used * 2 + 1;
  if ((digs < (int)MP_WARRAY) &&
      n->used <
      (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
    return fast_mp_montgomery_reduce (x, n, rho);
  }

  /* grow the input as required */
  if (x->alloc < digs) {
    if ((res = mp_grow (x, digs)) != MP_OKAY) {
      return res;
    }
  }
  /* treat x as a full 2*n->used+1 digit value while reducing */
  x->used = digs;

  for (ix = 0; ix < n->used; ix++) {
    /* mu = ai * rho mod b
     *
     * The value of rho must be precalculated via
     * montgomery_setup() such that
     * it equals -1/n0 mod b this allows the
     * following inner loop to reduce the
     * input one digit at a time
     */
    mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);

    /* a = a + mu * m * b**i */
    {
      int iy;
      mp_digit *tmpn, *tmpx, u;
      mp_word r;

      /* alias for digits of the modulus */
      tmpn = n->dp;

      /* alias for the digits of x [the input], starting at digit ix */
      tmpx = x->dp + ix;

      /* set the carry to zero */
      u = 0;

      /* Multiply and add in place */
      for (iy = 0; iy < n->used; iy++) {
        /* compute product and sum */
        r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
                  ((mp_word) u) + ((mp_word) * tmpx);

        /* get carry */
        u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));

        /* fix digit */
        *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
      }
      /* At this point the ix'th digit of x should be zero */


      /* propagate carries upwards as required*/
      while (u) {
        *tmpx   += u;
        u        = *tmpx >> DIGIT_BIT;
        *tmpx++ &= MP_MASK;
      }
    }
  }

  /* at this point the n.used'th least
   * significant digits of x are all zero
   * which means we can shift x to the
   * right by n.used digits and the
   * residue is unchanged.
   */

  /* x = x/b**n.used */
  mp_clamp(x);
  mp_rshd (x, n->used);

  /* if x >= n then x = x - n */
  if (mp_cmp_mag (x, n) != MP_LT) {
    return s_mp_sub (x, n, x);
  }

  return MP_OKAY;
}
2756 
2757 
2758 /* determines the setup value */
mp_dr_setup(mp_int * a,mp_digit * d)2759 void mp_dr_setup(mp_int *a, mp_digit *d)
2760 {
2761    /* the casts are required if DIGIT_BIT is one less than
2762     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
2763     */
2764    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
2765         ((mp_word)a->dp[0]));
2766 }
2767 
2768 
2769 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
2770  *
2771  * Based on algorithm from the paper
2772  *
2773  * "Generating Efficient Primes for Discrete Log Cryptosystems"
2774  *                 Chae Hoon Lim, Pil Joong Lee,
2775  *          POSTECH Information Research Laboratories
2776  *
2777  * The modulus must be of a special format [see manual]
2778  *
2779  * Has been modified to use algorithm 7.10 from the LTM book instead
2780  *
2781  * Input x must be in the range 0 <= x <= (n-1)**2
2782  */
mp_dr_reduce(mp_int * x,mp_int * n,mp_digit k)2783 int mp_dr_reduce (mp_int * x, mp_int * n, mp_digit k)
2784 {
2785   int      err, i, m;
2786   mp_word  r;
2787   mp_digit mu, *tmpx1, *tmpx2;
2788 
2789   /* m = digits in modulus */
2790   m = n->used;
2791 
2792   /* ensure that "x" has at least 2m digits */
2793   if (x->alloc < m + m) {
2794     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
2795       return err;
2796     }
2797   }
2798 
2799 /* top of loop, this is where the code resumes if
2800  * another reduction pass is required.
2801  */
2802 top:
2803   /* aliases for digits */
2804   /* alias for lower half of x */
2805   tmpx1 = x->dp;
2806 
2807   /* alias for upper half of x, or x/B**m */
2808   tmpx2 = x->dp + m;
2809 
2810   /* set carry to zero */
2811   mu = 0;
2812 
2813   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
2814   for (i = 0; i < m; i++) {
2815       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
2816       *tmpx1++  = (mp_digit)(r & MP_MASK);
2817       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
2818   }
2819 
2820   /* set final carry */
2821   *tmpx1++ = mu;
2822 
2823   /* zero words above m */
2824   for (i = m + 1; i < x->used; i++) {
2825       *tmpx1++ = 0;
2826   }
2827 
2828   /* clamp, sub and return */
2829   mp_clamp (x);
2830 
2831   /* if x >= n then subtract and reduce again
2832    * Each successive "recursion" makes the input smaller and smaller.
2833    */
2834   if (mp_cmp_mag (x, n) != MP_LT) {
2835     if ((err = s_mp_sub(x, n, x)) != MP_OKAY) {
2836         return err;
2837     }
2838     goto top;
2839   }
2840   return MP_OKAY;
2841 }
2842 
2843 
2844 /* reduces a modulo n where n is of the form 2**p - d */
mp_reduce_2k(mp_int * a,mp_int * n,mp_digit d)2845 int mp_reduce_2k(mp_int *a, mp_int *n, mp_digit d)
2846 {
2847    mp_int q;
2848    int    p, res;
2849 
2850    if ((res = mp_init(&q)) != MP_OKAY) {
2851       return res;
2852    }
2853 
2854    p = mp_count_bits(n);
2855 top:
2856    /* q = a/2**p, a = a mod 2**p */
2857    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
2858       goto ERR;
2859    }
2860 
2861    if (d != 1) {
2862       /* q = q * d */
2863       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) {
2864          goto ERR;
2865       }
2866    }
2867 
2868    /* a = a + q */
2869    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
2870       goto ERR;
2871    }
2872 
2873    if (mp_cmp_mag(a, n) != MP_LT) {
2874       if ((res = s_mp_sub(a, n, a)) != MP_OKAY) {
2875          goto ERR;
2876       }
2877       goto top;
2878    }
2879 
2880 ERR:
2881    mp_clear(&q);
2882    return res;
2883 }
2884 
2885 
2886 /* determines the setup value */
mp_reduce_2k_setup(mp_int * a,mp_digit * d)2887 int mp_reduce_2k_setup(mp_int *a, mp_digit *d)
2888 {
2889    int res, p;
2890    mp_int tmp;
2891 
2892    if ((res = mp_init(&tmp)) != MP_OKAY) {
2893       return res;
2894    }
2895 
2896    p = mp_count_bits(a);
2897    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
2898       mp_clear(&tmp);
2899       return res;
2900    }
2901 
2902    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
2903       mp_clear(&tmp);
2904       return res;
2905    }
2906 
2907    *d = tmp.dp[0];
2908    mp_clear(&tmp);
2909    return MP_OKAY;
2910 }
2911 
2912 
2913 /* set the b bit of a */
mp_set_bit(mp_int * a,int b)2914 int mp_set_bit (mp_int * a, int b)
2915 {
2916     int i = b / DIGIT_BIT, res;
2917 
2918     /*
2919      * Require:
2920      *  bit index b >= 0
2921      *  a->alloc == a->used == 0 if a->dp == NULL
2922      */
2923     if (b < 0 || (a->dp == NULL && (a->alloc != 0 || a->used != 0)))
2924         return MP_VAL;
2925 
2926     if (a->dp == NULL || a->used < (int)(i + 1)) {
2927         /* grow a to accommodate the single bit */
2928         if ((res = mp_grow (a, i + 1)) != MP_OKAY) {
2929             return res;
2930         }
2931 
2932         /* set the used count of where the bit will go */
2933         a->used = (int)(i + 1);
2934     }
2935 
2936     /* put the single bit in its place */
2937     a->dp[i] |= ((mp_digit)1) << (b % DIGIT_BIT);
2938 
2939     return MP_OKAY;
2940 }
2941 
2942 /* computes a = 2**b
2943  *
2944  * Simple algorithm which zeros the int, set the required bit
2945  */
mp_2expt(mp_int * a,int b)2946 int mp_2expt (mp_int * a, int b)
2947 {
2948     /* zero a as per default */
2949     mp_zero (a);
2950 
2951     return mp_set_bit(a, b);
2952 }
2953 
2954 /* multiply by a digit */
mp_mul_d(mp_int * a,mp_digit b,mp_int * c)2955 int mp_mul_d (mp_int * a, mp_digit b, mp_int * c)
2956 {
2957   mp_digit u, *tmpa, *tmpc;
2958   mp_word  r;
2959   int      ix, res, olduse;
2960 
2961   /* make sure c is big enough to hold a*b */
2962   if (c->alloc < a->used + 1) {
2963     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
2964       return res;
2965     }
2966   }
2967 
2968   /* get the original destinations used count */
2969   olduse = c->used;
2970 
2971   /* set the sign */
2972   c->sign = a->sign;
2973 
2974   /* alias for a->dp [source] */
2975   tmpa = a->dp;
2976 
2977   /* alias for c->dp [dest] */
2978   tmpc = c->dp;
2979 
2980   /* zero carry */
2981   u = 0;
2982 
2983   /* compute columns */
2984   for (ix = 0; ix < a->used; ix++) {
2985     /* compute product and carry sum for this term */
2986     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
2987 
2988     /* mask off higher bits to get a single digit */
2989     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
2990 
2991     /* send carry into next iteration */
2992     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
2993   }
2994 
2995   /* store final carry [if any] and increment ix offset  */
2996   *tmpc++ = u;
2997   ++ix;
2998 
2999   /* now zero digits above the top */
3000   while (ix++ < olduse) {
3001      *tmpc++ = 0;
3002   }
3003 
3004   /* set used count */
3005   c->used = a->used + 1;
3006   mp_clamp(c);
3007 
3008   return MP_OKAY;
3009 }
3010 
3011 
3012 /* d = a * b (mod c) */
3013 #if defined(FREESCALE_LTC_TFM)
wolfcrypt_mp_mulmod(mp_int * a,mp_int * b,mp_int * c,mp_int * d)3014 int wolfcrypt_mp_mulmod(mp_int *a, mp_int *b, mp_int *c, mp_int *d)
3015 #else
3016 int mp_mulmod (mp_int * a, mp_int * b, mp_int * c, mp_int * d)
3017 #endif
3018 {
3019   int     res;
3020   mp_int  t;
3021 
3022   if ((res = mp_init_size (&t, c->used)) != MP_OKAY) {
3023     return res;
3024   }
3025 
3026   res = mp_mul (a, b, &t);
3027   if (res == MP_OKAY) {
3028       res = mp_mod (&t, c, d);
3029   }
3030 
3031   mp_clear (&t);
3032   return res;
3033 }
3034 
3035 
3036 /* d = a - b (mod c) */
mp_submod(mp_int * a,mp_int * b,mp_int * c,mp_int * d)3037 int mp_submod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
3038 {
3039   int     res;
3040   mp_int  t;
3041 
3042   if ((res = mp_init (&t)) != MP_OKAY) {
3043     return res;
3044   }
3045 
3046   res = mp_sub (a, b, &t);
3047   if (res == MP_OKAY) {
3048       res = mp_mod (&t, c, d);
3049   }
3050 
3051   mp_clear (&t);
3052 
3053   return res;
3054 }
3055 
3056 /* d = a + b (mod c) */
mp_addmod(mp_int * a,mp_int * b,mp_int * c,mp_int * d)3057 int mp_addmod(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
3058 {
3059    int     res;
3060    mp_int  t;
3061 
3062    if ((res = mp_init (&t)) != MP_OKAY) {
3063      return res;
3064    }
3065 
3066    res = mp_add (a, b, &t);
3067    if (res == MP_OKAY) {
3068        res = mp_mod (&t, c, d);
3069    }
3070 
3071    mp_clear (&t);
3072 
3073    return res;
3074 }
3075 
3076 /* d = a - b (mod c) - a < c and b < c and positive */
mp_submod_ct(mp_int * a,mp_int * b,mp_int * c,mp_int * d)3077 int mp_submod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
3078 {
3079     int res;
3080 
3081     res = mp_sub(a, b, d);
3082     if (res == MP_OKAY && mp_isneg(d)) {
3083         res = mp_add(d, c, d);
3084     }
3085 
3086     return res;
3087 }
3088 
3089 /* d = a + b (mod c) - a < c and b < c and positive */
mp_addmod_ct(mp_int * a,mp_int * b,mp_int * c,mp_int * d)3090 int mp_addmod_ct(mp_int* a, mp_int* b, mp_int* c, mp_int* d)
3091 {
3092     int res;
3093 
3094     res = mp_add(a, b, d);
3095     if (res == MP_OKAY && mp_cmp(d, c) != MP_LT) {
3096         res = mp_sub(d, c, d);
3097     }
3098 
3099     return res;
3100 }
3101 
3102 /* computes b = a*a */
mp_sqr(mp_int * a,mp_int * b)3103 int mp_sqr (mp_int * a, mp_int * b)
3104 {
3105   int     res;
3106 
3107   {
3108 #ifdef BN_FAST_S_MP_SQR_C
3109     /* can we use the fast comba multiplier? */
3110     if ((a->used * 2 + 1) < (int)MP_WARRAY &&
3111          a->used <
3112          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3113       res = fast_s_mp_sqr (a, b);
3114     } else
3115 #endif
3116 #ifdef BN_S_MP_SQR_C
3117       res = s_mp_sqr (a, b);
3118 #else
3119       res = MP_VAL;
3120 #endif
3121   }
3122   b->sign = MP_ZPOS;
3123   return res;
3124 }
3125 
3126 
3127 /* high level multiplication (handles sign) */
3128 #if defined(FREESCALE_LTC_TFM)
wolfcrypt_mp_mul(mp_int * a,mp_int * b,mp_int * c)3129 int wolfcrypt_mp_mul(mp_int *a, mp_int *b, mp_int *c)
3130 #else
3131 int mp_mul (mp_int * a, mp_int * b, mp_int * c)
3132 #endif
3133 {
3134   int     res, neg;
3135   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3136 
3137   {
3138 #ifdef BN_FAST_S_MP_MUL_DIGS_C
3139     /* can we use the fast multiplier?
3140      *
3141      * The fast multiplier can be used if the output will
3142      * have less than MP_WARRAY digits and the number of
3143      * digits won't affect carry propagation
3144      */
3145     int     digs = a->used + b->used + 1;
3146 
3147     if ((digs < (int)MP_WARRAY) &&
3148         MIN(a->used, b->used) <=
3149         (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3150       res = fast_s_mp_mul_digs (a, b, c, digs);
3151     } else
3152 #endif
3153 #ifdef BN_S_MP_MUL_DIGS_C
3154       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3155 #else
3156       res = MP_VAL;
3157 #endif
3158 
3159   }
3160   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3161   return res;
3162 }
3163 
3164 
3165 /* b = a*2 */
mp_mul_2(mp_int * a,mp_int * b)3166 int mp_mul_2(mp_int * a, mp_int * b)
3167 {
3168   int     x, res, oldused;
3169 
3170   /* grow to accommodate result */
3171   if (b->alloc < a->used + 1) {
3172     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
3173       return res;
3174     }
3175   }
3176 
3177   oldused = b->used;
3178   b->used = a->used;
3179 
3180   {
3181     mp_digit r, rr, *tmpa, *tmpb;
3182 
3183     /* alias for source */
3184     tmpa = a->dp;
3185 
3186     /* alias for dest */
3187     tmpb = b->dp;
3188 
3189     /* carry */
3190     r = 0;
3191     for (x = 0; x < a->used; x++) {
3192 
3193       /* get what will be the *next* carry bit from the
3194        * MSB of the current digit
3195        */
3196       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
3197 
3198       /* now shift up this digit, add in the carry [from the previous] */
3199       *tmpb++ = (mp_digit)(((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK);
3200 
3201       /* copy the carry that would be from the source
3202        * digit into the next iteration
3203        */
3204       r = rr;
3205     }
3206 
3207     /* new leading digit? */
3208     if (r != 0) {
3209       /* add a MSB which is always 1 at this point */
3210       *tmpb = 1;
3211       ++(b->used);
3212     }
3213 
3214     /* now zero any excess digits on the destination
3215      * that we didn't write to
3216      */
3217     tmpb = b->dp + b->used;
3218     for (x = b->used; x < oldused; x++) {
3219       *tmpb++ = 0;
3220     }
3221   }
3222   b->sign = a->sign;
3223   return MP_OKAY;
3224 }
3225 
3226 
3227 /* divide by three (based on routine from MPI and the GMP manual) */
mp_div_3(mp_int * a,mp_int * c,mp_digit * d)3228 int mp_div_3 (mp_int * a, mp_int *c, mp_digit * d)
3229 {
3230   mp_int   q;
3231   mp_word  w, t;
3232   mp_digit b;
3233   int      res, ix;
3234 
3235   /* b = 2**DIGIT_BIT / 3 */
3236   b = (mp_digit) ( (((mp_word)1) << ((mp_word)DIGIT_BIT)) / ((mp_word)3) );
3237 
3238   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
3239      return res;
3240   }
3241 
3242   q.used = a->used;
3243   q.sign = a->sign;
3244   w = 0;
3245   for (ix = a->used - 1; ix >= 0; ix--) {
3246      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
3247 
3248      if (w >= 3) {
3249         /* multiply w by [1/3] */
3250         t = (w * ((mp_word)b)) >> ((mp_word)DIGIT_BIT);
3251 
3252         /* now subtract 3 * [w/3] from w, to get the remainder */
3253         w -= t+t+t;
3254 
3255         /* fixup the remainder as required since
3256          * the optimization is not exact.
3257          */
3258         while (w >= 3) {
3259            t += 1;
3260            w -= 3;
3261         }
3262       } else {
3263         t = 0;
3264       }
3265       q.dp[ix] = (mp_digit)t;
3266   }
3267 
3268   /* [optional] store the remainder */
3269   if (d != NULL) {
3270      *d = (mp_digit)w;
3271   }
3272 
3273   /* [optional] store the quotient */
3274   if (c != NULL) {
3275      mp_clamp(&q);
3276      mp_exch(&q, c);
3277   }
3278   mp_clear(&q);
3279 
3280   return res;
3281 }
3282 
3283 
3284 /* init an mp_init for a given size */
mp_init_size(mp_int * a,int size)3285 int mp_init_size (mp_int * a, int size)
3286 {
3287   int x;
3288 
3289   /* pad size so there are always extra digits */
3290   size += (MP_PREC * 2) - (size % MP_PREC);
3291 
3292   /* alloc mem */
3293   a->dp = OPT_CAST(mp_digit) XMALLOC (sizeof (mp_digit) * size, NULL,
3294                                       DYNAMIC_TYPE_BIGINT);
3295   if (a->dp == NULL) {
3296     return MP_MEM;
3297   }
3298 
3299   /* set the members */
3300   a->used  = 0;
3301   a->alloc = size;
3302   a->sign  = MP_ZPOS;
3303 #ifdef HAVE_WOLF_BIGINT
3304   wc_bigint_init(&a->raw);
3305 #endif
3306 
3307   /* zero the digits */
3308   for (x = 0; x < size; x++) {
3309       a->dp[x] = 0;
3310   }
3311 
3312   return MP_OKAY;
3313 }
3314 
3315 
/* Fast (comba) squaring: b = a*a.
 *
 * The gist of squaring:
 * it works like comba multiplication, except that the offset of tmpx [the
 * one that starts closer to zero] can never equal the offset of tmpy.  So
 * iy is set up as in multiplication and then capped at (ty-tx+1)>>1 so that
 * the two offsets never meet; every product added in the inner loop is
 * then doubled, because a[tx]*a[ty] and a[ty]*a[tx] are the same term.
 *
 * After that loop, the square term a[ix/2]**2 is added in for even columns.
 *
 * Returns MP_OKAY; MP_RANGE when 2*a->used exceeds MP_WARRAY; MP_MEM when
 * the WOLFSSL_SMALL_STACK allocation fails; or the error from mp_grow().
 * The sign of b is not set here.
 * NOTE(review): the column sums are not range-checked here; this appears to
 * rely on the caller (mp_sqr checks a->used against the mp_word width).
 */

int fast_s_mp_sqr (mp_int * a, mp_int * b)
{
  int       olduse, res, pa, ix, iz;
#ifdef WOLFSSL_SMALL_STACK
  mp_digit* W;    /* uses dynamic memory and slower */
#else
  mp_digit W[MP_WARRAY];
#endif
  mp_digit  *tmpx;
  mp_word   W1;

  /* grow the destination as required */
  pa = a->used + a->used;
  if (b->alloc < pa) {
    if ((res = mp_grow (b, pa)) != MP_OKAY) {
      return res;
    }
  }

  if (pa > (int)MP_WARRAY)
    return MP_RANGE;  /* TAO range check */

#ifdef WOLFSSL_SMALL_STACK
  W = (mp_digit*)XMALLOC(sizeof(mp_digit) * MP_WARRAY, NULL, DYNAMIC_TYPE_BIGINT);
  if (W == NULL)
    return MP_MEM;
#endif

  /* number of output digits to produce */
  /* W1 carries the high part of each column into the next one */
  W1 = 0;
  for (ix = 0; ix < pa; ix++) {
      int      tx, ty, iy;
      mp_word  _W;
      mp_digit *tmpy;

      /* clear counter */
      _W = 0;

      /* get offsets into the two bignums */
      ty = MIN(a->used-1, ix);
      tx = ix - ty;

      /* setup temp aliases */
      tmpx = a->dp + tx;
      tmpy = a->dp + ty;

      /* this is the number of times the loop will iterate, essentially
         while (tx++ < a->used && ty-- >= 0) { ... }
       */
      iy = MIN(a->used-tx, ty+1);

      /* now for squaring tx can never equal ty
       * we halve the distance since they approach at a rate of 2x
       * and we have to round because odd cases need to be executed
       */
      iy = MIN(iy, (ty-tx+1)>>1);

      /* execute loop */
      for (iz = 0; iz < iy; iz++) {
         _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
      }

      /* double the inner product and add carry */
      _W = _W + _W + W1;

      /* even columns have the square term in them */
      if ((ix&1) == 0) {
         _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
      }

      /* store it */
      W[ix] = (mp_digit)(_W & MP_MASK);

      /* make next carry */
      W1 = _W >> ((mp_word)DIGIT_BIT);
  }

  /* setup dest */
  olduse  = b->used;
  b->used = a->used+a->used;

  {
    mp_digit *tmpb;
    tmpb = b->dp;
    for (ix = 0; ix < pa; ix++) {
      *tmpb++ = (mp_digit)(W[ix] & MP_MASK);
    }

    /* clear unused digits [that existed in the old copy of c] */
    for (; ix < olduse; ix++) {
      *tmpb++ = 0;
    }
  }
  mp_clamp (b);

#ifdef WOLFSSL_SMALL_STACK
  XFREE(W, NULL, DYNAMIC_TYPE_BIGINT);
#endif

  return MP_OKAY;
}
3427 
3428 
3429 /* Fast (comba) multiplier
3430  *
3431  * This is the fast column-array [comba] multiplier.  It is
3432  * designed to compute the columns of the product first
3433  * then handle the carries afterwards.  This has the effect
3434  * of making the nested loops that compute the columns very
3435  * simple and schedulable on super-scalar processors.
3436  *
3437  * This has been modified to produce a variable number of
3438  * digits of output so if say only a half-product is required
3439  * you don't have to compute the upper half (a feature
3440  * required for fast Barrett reduction).
3441  *
3442  * Based on Algorithm 14.12 on pp.595 of HAC.
3443  *
3444  */
int fast_s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
{
  int      err, prev_used, ndigits, col, k;
  mp_word  acc;
#ifdef WOLFSSL_SMALL_STACK
  mp_digit* W;    /* uses dynamic memory and slower */
#else
  mp_digit W[MP_WARRAY];
#endif

  /* c must be able to hold at least digs digits */
  if (c->alloc < digs) {
    err = mp_grow (c, digs);
    if (err != MP_OKAY) {
      return err;
    }
  }

  /* only produce as many columns as were requested (or exist) */
  ndigits = MIN(digs, a->used + b->used);
  if (ndigits > (int)MP_WARRAY) {
    return MP_RANGE;  /* TAO range check */
  }

#ifdef WOLFSSL_SMALL_STACK
  W = (mp_digit*)XMALLOC(sizeof(mp_digit) * MP_WARRAY, NULL, DYNAMIC_TYPE_BIGINT);
  if (W == NULL)
    return MP_MEM;
#endif

  /* Pass 1: sum every partial product of a column into acc, keep the low
   * DIGIT_BIT bits as that column's digit and carry the rest onward. */
  acc = 0;
  for (col = 0; col < ndigits; col++) {
    int       hi_b  = MIN(b->used - 1, col);       /* largest b index used */
    int       lo_a  = col - hi_b;                  /* paired a index */
    int       terms = MIN(a->used - lo_a, hi_b + 1);
    mp_digit *pa    = a->dp + lo_a;
    mp_digit *pb    = b->dp + hi_b;

    /* walk a upwards and b downwards along the column */
    for (k = 0; k < terms; k++) {
      acc += ((mp_word)*pa++) * ((mp_word)*pb--);
    }

    W[col] = (mp_digit)(((mp_digit)acc) & MP_MASK);
    acc    = acc >> ((mp_word)DIGIT_BIT);
  }

  /* Pass 2: copy the columns into c, then wipe digits the previous value
   * of c occupied above the new length. */
  prev_used = c->used;
  c->used   = ndigits;
  for (col = 0; col < ndigits; col++) {
    c->dp[col] = W[col];
  }
  for (; col < prev_used; col++) {
    c->dp[col] = 0;
  }
  mp_clamp (c);

#ifdef WOLFSSL_SMALL_STACK
  XFREE(W, NULL, DYNAMIC_TYPE_BIGINT);
#endif

  return MP_OKAY;
}
3531 
3532 
3533 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
s_mp_sqr(mp_int * a,mp_int * b)3534 int s_mp_sqr (mp_int * a, mp_int * b)
3535 {
3536   mp_int  t;
3537   int     res, ix, iy, pa;
3538   mp_word r;
3539   mp_digit u, tmpx, *tmpt;
3540 
3541   pa = a->used;
3542   if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
3543     return res;
3544   }
3545 
3546   /* default used is maximum possible size */
3547   t.used = 2*pa + 1;
3548 
3549   for (ix = 0; ix < pa; ix++) {
3550     /* first calculate the digit at 2*ix */
3551     /* calculate double precision result */
3552     r = ((mp_word) t.dp[2*ix]) +
3553         ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
3554 
3555     /* store lower part in result */
3556     t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
3557 
3558     /* get the carry */
3559     u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3560 
3561     /* left hand side of A[ix] * A[iy] */
3562     tmpx        = a->dp[ix];
3563 
3564     /* alias for where to store the results */
3565     tmpt        = t.dp + (2*ix + 1);
3566 
3567     for (iy = ix + 1; iy < pa; iy++) {
3568       /* first calculate the product */
3569       r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
3570 
3571       /* now calculate the double precision result, note we use
3572        * addition instead of *2 since it's easier to optimize
3573        */
3574       r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);
3575 
3576       /* store lower part */
3577       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3578 
3579       /* get carry */
3580       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3581     }
3582     /* propagate upwards */
3583     while (u != ((mp_digit) 0)) {
3584       r       = ((mp_word) *tmpt) + ((mp_word) u);
3585       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3586       u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3587     }
3588   }
3589 
3590   mp_clamp (&t);
3591   mp_exch (&t, b);
3592   mp_clear (&t);
3593   return MP_OKAY;
3594 }
3595 
3596 
3597 /* multiplies |a| * |b| and only computes up to digs digits of result
3598  * HAC pp. 595, Algorithm 14.12  Modified so you can control how
3599  * many digits of output are created.
3600  */
s_mp_mul_digs(mp_int * a,mp_int * b,mp_int * c,int digs)3601 int s_mp_mul_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
3602 {
3603   mp_int  t;
3604   int     res, pa, pb, ix, iy;
3605   mp_digit u;
3606   mp_word r;
3607   mp_digit tmpx, *tmpt, *tmpy;
3608 
3609   /* can we use the fast multiplier? */
3610   if ((digs < (int)MP_WARRAY) &&
3611       MIN (a->used, b->used) <
3612           (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3613     return fast_s_mp_mul_digs (a, b, c, digs);
3614   }
3615 
3616   if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
3617     return res;
3618   }
3619   t.used = digs;
3620 
3621   /* compute the digits of the product directly */
3622   pa = a->used;
3623   for (ix = 0; ix < pa; ix++) {
3624     /* set the carry to zero */
3625     u = 0;
3626 
3627     /* limit ourselves to making digs digits of output */
3628     pb = MIN (b->used, digs - ix);
3629 
3630     /* setup some aliases */
3631     /* copy of the digit from a used within the nested loop */
3632     tmpx = a->dp[ix];
3633 
3634     /* an alias for the destination shifted ix places */
3635     tmpt = t.dp + ix;
3636 
3637     /* an alias for the digits of b */
3638     tmpy = b->dp;
3639 
3640     /* compute the columns of the output and propagate the carry */
3641     for (iy = 0; iy < pb; iy++) {
3642       /* compute the column as a mp_word */
3643       r       = ((mp_word)*tmpt) +
3644                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
3645                 ((mp_word) u);
3646 
3647       /* the new column is the lower part of the result */
3648       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
3649 
3650       /* get the carry word from the result */
3651       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
3652     }
3653     /* set carry if it is placed below digs */
3654     if (ix + iy < digs) {
3655       *tmpt = u;
3656     }
3657   }
3658 
3659   mp_clamp (&t);
3660   mp_exch (&t, c);
3661 
3662   mp_clear (&t);
3663   return MP_OKAY;
3664 }
3665 
3666 
3667 /*
3668  * shifts with subtractions when the result is greater than b.
3669  *
3670  * The method is slightly modified to shift B unconditionally up to just under
3671  * the leading bit of b.  This saves a lot of multiple precision shifting.
3672  */
mp_montgomery_calc_normalization(mp_int * a,mp_int * b)3673 int mp_montgomery_calc_normalization (mp_int * a, mp_int * b)
3674 {
3675   int     x, bits, res;
3676 
3677   /* how many bits of last digit does b use */
3678   bits = mp_count_bits (b) % DIGIT_BIT;
3679 
3680   if (b->used > 1) {
3681      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1))
3682          != MP_OKAY) {
3683         return res;
3684      }
3685   } else {
3686      if ((res = mp_set(a, 1)) != MP_OKAY) {
3687         return res;
3688      }
3689      bits = 1;
3690   }
3691 
3692   /* now compute C = A * B mod b */
3693   for (x = bits - 1; x < (int)DIGIT_BIT; x++) {
3694     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
3695       return res;
3696     }
3697     if (mp_cmp_mag (a, b) != MP_LT) {
3698       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
3699         return res;
3700       }
3701     }
3702   }
3703 
3704   return MP_OKAY;
3705 }
3706 
3707 
/* Capacity of the sliding-window table M[] in s_mp_exptmod below.
 * MP_LOW_MEM builds cap the window at 5 bits, so 32 entries suffice. */
#ifndef MP_LOW_MEM
   #define TAB_SIZE 256
#else
   #define TAB_SIZE 32
#endif
3713 
/* Computes Y = G**X mod P using a left-to-right sliding window.
 *
 * redmode == 0 selects Barrett reduction (mp_reduce, with mu from
 * mp_reduce_setup).  Any other value selects mp_reduce_2k_l, with the
 * value from mp_reduce_2k_setup_l.
 * Only M[1] and the upper half of the window table,
 * M[2**(winsize-1)] .. M[2**winsize - 1], are initialised and used.
 * The sign of X is not examined.
 * Returns MP_OKAY or the first error reported by a helper; every mp_int
 * initialised here is cleared on all paths.
 */
int s_mp_exptmod (mp_int * G, mp_int * X, mp_int * P, mp_int * Y, int redmode)
{
  mp_int  M[TAB_SIZE], res, mu;
  mp_digit buf;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
  int (*redux)(mp_int*,mp_int*,mp_int*);

  /* find window size from the bit length of the exponent */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

#ifdef MP_LOW_MEM
    if (winsize > 5) {
       winsize = 5;
    }
#endif

  /* init M array */
  /* init first cell */
  if ((err = mp_init(&M[1])) != MP_OKAY) {
     return err;
  }

  /* now init the second half of the array */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init(&M[x])) != MP_OKAY) {
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);
      return err;
    }
  }

  /* create mu, used for the selected reduction */
  if ((err = mp_init (&mu)) != MP_OKAY) {
    goto LBL_M;
  }

  if (redmode == 0) {
     if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
        goto LBL_MU;
     }
     redux = mp_reduce;
  } else {
     if ((err = mp_reduce_2k_setup_l (P, &mu)) != MP_OKAY) {
        goto LBL_MU;
     }
     redux = mp_reduce_2k_l;
  }

  /* create M table
   *
   * The M table contains powers of the base,
   * e.g. M[x] = G**x mod P
   *
   * The lower half of the table is not computed, except for M[1];
   * M[0] is never used.
   */
  if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
    goto LBL_MU;
  }

  /* compute the value at M[1<<(winsize-1)] by squaring
   * M[1] (winsize-1) times
   */
  if ((err = mp_copy (&M[1], &M[(mp_digit)(1 << (winsize - 1))])) != MP_OKAY) {
    goto LBL_MU;
  }

  for (x = 0; x < (winsize - 1); x++) {
    /* square it */
    if ((err = mp_sqr (&M[(mp_digit)(1 << (winsize - 1))],
                       &M[(mp_digit)(1 << (winsize - 1))])) != MP_OKAY) {
      goto LBL_MU;
    }

    /* reduce modulo P */
    if ((err = redux (&M[(mp_digit)(1 << (winsize - 1))], P, &mu)) != MP_OKAY) {
      goto LBL_MU;
    }
  }

  /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
   * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
   */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto LBL_MU;
    }
    if ((err = redux (&M[x], P, &mu)) != MP_OKAY) {
      goto LBL_MU;
    }
  }

  /* setup result */
  if ((err = mp_init (&res)) != MP_OKAY) {
    goto LBL_MU;
  }
  /* res is initialised from here on, so failures must clear it as well */
  if ((err = mp_set (&res, 1)) != MP_OKAY) {
    goto LBL_RES;
  }

  /* set initial mode and bit cnt */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset the bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = (int) DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (int)(buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }
      continue;
    }

    /* else we add it to the window */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, &mu)) != MP_OKAY) {
          goto LBL_RES;
        }
      }

      /* then multiply */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto LBL_RES;
      }
      if ((err = redux (&res, P, &mu)) != MP_OKAY) {
        goto LBL_RES;
      }

      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto LBL_RES;
        }
        if ((err = redux (&res, P, &mu)) != MP_OKAY) {
          goto LBL_RES;
        }
      }
    }
  }

  mp_exch (&res, Y);
  err = MP_OKAY;
LBL_RES:mp_clear (&res);
LBL_MU:mp_clear (&mu);
LBL_M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }
  return err;
}
3941 
3942 
3943 /* pre-calculate the value required for Barrett reduction
3944  * For a given modulus "b" it calculates the value required in "a"
3945  */
mp_reduce_setup(mp_int * a,mp_int * b)3946 int mp_reduce_setup (mp_int * a, mp_int * b)
3947 {
3948   int     res;
3949 
3950   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3951     return res;
3952   }
3953   return mp_div (a, b, a, NULL);
3954 }
3955 
3956 
3957 /* reduces x mod m, assumes 0 < x < m**2, mu is
3958  * precomputed via mp_reduce_setup.
3959  * From HAC pp.604 Algorithm 14.42
3960  */
mp_reduce(mp_int * x,mp_int * m,mp_int * mu)3961 int mp_reduce (mp_int * x, mp_int * m, mp_int * mu)
3962 {
3963   mp_int  q;
3964   int     res, um = m->used;
3965 
3966   /* q = x */
3967   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3968     return res;
3969   }
3970 
3971   /* q1 = x / b**(k-1)  */
3972   mp_rshd (&q, um - 1);
3973 
3974   /* according to HAC this optimization is ok */
3975   if (((mp_word) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3976     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3977       goto CLEANUP;
3978     }
3979   } else {
3980 #ifdef BN_S_MP_MUL_HIGH_DIGS_C
3981     if ((res = s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
3982       goto CLEANUP;
3983     }
3984 #elif defined(BN_FAST_S_MP_MUL_HIGH_DIGS_C)
3985     if ((res = fast_s_mp_mul_high_digs (&q, mu, &q, um)) != MP_OKAY) {
3986       goto CLEANUP;
3987     }
3988 #else
3989     {
3990       res = MP_VAL;
3991       goto CLEANUP;
3992     }
3993 #endif
3994   }
3995 
3996   /* q3 = q2 / b**(k+1) */
3997   mp_rshd (&q, um + 1);
3998 
3999   /* x = x mod b**(k+1), quick (no division) */
4000   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
4001     goto CLEANUP;
4002   }
4003 
4004   /* q = q * m mod b**(k+1), quick (no division) */
4005   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
4006     goto CLEANUP;
4007   }
4008 
4009   /* x = x - q */
4010   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
4011     goto CLEANUP;
4012   }
4013 
4014   /* If x < 0, add b**(k+1) to it */
4015   if (mp_cmp_d (x, 0) == MP_LT) {
4016     if ((res = mp_set (&q, 1)) != MP_OKAY)
4017         goto CLEANUP;
4018     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
4019       goto CLEANUP;
4020     if ((res = mp_add (x, &q, x)) != MP_OKAY)
4021       goto CLEANUP;
4022   }
4023 
4024   /* Back off if it's too big */
4025   while (mp_cmp (x, m) != MP_LT) {
4026     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
4027       goto CLEANUP;
4028     }
4029   }
4030 
4031 CLEANUP:
4032   mp_clear (&q);
4033 
4034   return res;
4035 }
4036 
4037 
4038 /* reduces a modulo n where n is of the form 2**p - d
4039    This differs from reduce_2k since "d" can be larger
4040    than a single digit.
4041 */
mp_reduce_2k_l(mp_int * a,mp_int * n,mp_int * d)4042 int mp_reduce_2k_l(mp_int *a, mp_int *n, mp_int *d)
4043 {
4044    mp_int q;
4045    int    p, res;
4046 
4047    if ((res = mp_init(&q)) != MP_OKAY) {
4048       return res;
4049    }
4050 
4051    p = mp_count_bits(n);
4052 top:
4053    /* q = a/2**p, a = a mod 2**p */
4054    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
4055       goto ERR;
4056    }
4057 
4058    /* q = q * d */
4059    if ((res = mp_mul(&q, d, &q)) != MP_OKAY) {
4060       goto ERR;
4061    }
4062 
4063    /* a = a + q */
4064    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
4065       goto ERR;
4066    }
4067 
4068    if (mp_cmp_mag(a, n) != MP_LT) {
4069       if ((res = s_mp_sub(a, n, a)) != MP_OKAY) {
4070          goto ERR;
4071       }
4072       goto top;
4073    }
4074 
4075 ERR:
4076    mp_clear(&q);
4077    return res;
4078 }
4079 
4080 
4081 /* determines the setup value */
mp_reduce_2k_setup_l(mp_int * a,mp_int * d)4082 int mp_reduce_2k_setup_l(mp_int *a, mp_int *d)
4083 {
4084    int    res;
4085    mp_int tmp;
4086 
4087    if ((res = mp_init(&tmp)) != MP_OKAY) {
4088       return res;
4089    }
4090 
4091    if ((res = mp_2expt(&tmp, mp_count_bits(a))) != MP_OKAY) {
4092       goto ERR;
4093    }
4094 
4095    if ((res = s_mp_sub(&tmp, a, d)) != MP_OKAY) {
4096       goto ERR;
4097    }
4098 
4099 ERR:
4100    mp_clear(&tmp);
4101    return res;
4102 }
4103 
4104 
4105 /* multiplies |a| * |b| and does not compute the lower digs digits
4106  * [meant to get the higher part of the product]
4107  */
s_mp_mul_high_digs(mp_int * a,mp_int * b,mp_int * c,int digs)4108 int s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
4109 {
4110   mp_int  t;
4111   int     res, pa, pb, ix, iy;
4112   mp_digit u;
4113   mp_word r;
4114   mp_digit tmpx, *tmpt, *tmpy;
4115 
4116   /* can we use the fast multiplier? */
4117 #ifdef BN_FAST_S_MP_MUL_HIGH_DIGS_C
4118   if (((a->used + b->used + 1) < (int)MP_WARRAY)
4119       && MIN (a->used, b->used) <
4120       (1L << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4121     return fast_s_mp_mul_high_digs (a, b, c, digs);
4122   }
4123 #endif
4124 
4125   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4126     return res;
4127   }
4128   t.used = a->used + b->used + 1;
4129 
4130   pa = a->used;
4131   pb = b->used;
4132   for (ix = 0; ix < pa && a->dp; ix++) {
4133     /* clear the carry */
4134     u = 0;
4135 
4136     /* left hand side of A[ix] * B[iy] */
4137     tmpx = a->dp[ix];
4138 
4139     /* alias to the address of where the digits will be stored */
4140     tmpt = &(t.dp[digs]);
4141 
4142     /* alias for where to read the right hand side from */
4143     tmpy = b->dp + (digs - ix);
4144 
4145     for (iy = digs - ix; iy < pb; iy++) {
4146       /* calculate the double precision result */
4147       r       = ((mp_word)*tmpt) +
4148                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4149                 ((mp_word) u);
4150 
4151       /* get the lower part */
4152       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4153 
4154       /* carry the carry */
4155       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4156     }
4157     *tmpt = u;
4158   }
4159   mp_clamp (&t);
4160   mp_exch (&t, c);
4161   mp_clear (&t);
4162   return MP_OKAY;
4163 }
4164 
4165 
4166 /* this is a modified version of fast_s_mul_digs that only produces
4167  * output digits *above* digs.  See the comments for fast_s_mul_digs
4168  * to see how it works.
4169  *
4170  * This is used in the Barrett reduction since for one of the multiplications
4171  * only the higher digits were needed.  This essentially halves the work.
4172  *
4173  * Based on Algorithm 14.12 on pp.595 of HAC.
4174  */
fast_s_mp_mul_high_digs(mp_int * a,mp_int * b,mp_int * c,int digs)4175 int fast_s_mp_mul_high_digs (mp_int * a, mp_int * b, mp_int * c, int digs)
4176 {
4177   int     olduse, res, pa, ix, iz;
4178 #ifdef WOLFSSL_SMALL_STACK
4179   mp_digit* W;    /* uses dynamic memory and slower */
4180 #else
4181   mp_digit W[MP_WARRAY];
4182 #endif
4183   mp_word  _W;
4184 
4185   if (a->dp == NULL) { /* JRB, avoid reading uninitialized values */
4186       return MP_VAL;
4187   }
4188 
4189   /* grow the destination as required */
4190   pa = a->used + b->used;
4191   if (c->alloc < pa) {
4192     if ((res = mp_grow (c, pa)) != MP_OKAY) {
4193       return res;
4194     }
4195   }
4196 
4197   if (pa > (int)MP_WARRAY)
4198     return MP_RANGE;  /* TAO range check */
4199 
4200 #ifdef WOLFSSL_SMALL_STACK
4201   W = (mp_digit*)XMALLOC(sizeof(mp_digit) * MP_WARRAY, NULL, DYNAMIC_TYPE_BIGINT);
4202   if (W == NULL)
4203     return MP_MEM;
4204 #endif
4205 
4206   /* number of output digits to produce */
4207   pa = a->used + b->used;
4208   _W = 0;
4209   for (ix = digs; ix < pa; ix++) { /* JRB, have a->dp check at top of function*/
4210       int      tx, ty, iy;
4211       mp_digit *tmpx, *tmpy;
4212 
4213       /* get offsets into the two bignums */
4214       ty = MIN(b->used-1, ix);
4215       tx = ix - ty;
4216 
4217       /* setup temp aliases */
4218       tmpx = a->dp + tx;
4219       tmpy = b->dp + ty;
4220 
4221       /* this is the number of times the loop will iterate, essentially its
4222          while (tx++ < a->used && ty-- >= 0) { ... }
4223        */
4224       iy = MIN(a->used-tx, ty+1);
4225 
4226       /* execute loop */
4227       for (iz = 0; iz < iy; iz++) {
4228          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
4229       }
4230 
4231       /* store term */
4232       W[ix] = (mp_digit)(((mp_digit)_W) & MP_MASK);
4233 
4234       /* make next carry */
4235       _W = _W >> ((mp_word)DIGIT_BIT);
4236   }
4237 
4238   /* setup dest */
4239   olduse  = c->used;
4240   c->used = pa;
4241 
4242   {
4243     mp_digit *tmpc;
4244 
4245     tmpc = c->dp + digs;
4246     for (ix = digs; ix < pa; ix++) {   /* TAO, <= could potentially overwrite */
4247       /* now extract the previous digit [below the carry] */
4248       *tmpc++ = W[ix];
4249     }
4250 
4251     /* clear unused digits [that existed in the old copy of c] */
4252     for (; ix < olduse; ix++) {
4253       *tmpc++ = 0;
4254     }
4255   }
4256   mp_clamp (c);
4257 
4258 #ifdef WOLFSSL_SMALL_STACK
4259   XFREE(W, NULL, DYNAMIC_TYPE_BIGINT);
4260 #endif
4261 
4262   return MP_OKAY;
4263 }
4264 
4265 
4266 #ifndef MP_SET_CHUNK_BITS
4267     #define MP_SET_CHUNK_BITS 4
4268 #endif
mp_set_int(mp_int * a,unsigned long b)4269 int mp_set_int (mp_int * a, unsigned long b)
4270 {
4271   int x, res;
4272 
4273   /* use direct mp_set if b is less than mp_digit max */
4274   if (b < MP_DIGIT_MAX) {
4275     return mp_set (a, (mp_digit)b);
4276   }
4277 
4278   mp_zero (a);
4279 
4280   /* set chunk bits at a time */
4281   for (x = 0; x < (int)(sizeof(b) * 8) / MP_SET_CHUNK_BITS; x++) {
4282     /* shift the number up chunk bits */
4283     if ((res = mp_mul_2d (a, MP_SET_CHUNK_BITS, a)) != MP_OKAY) {
4284       return res;
4285     }
4286 
4287     /* OR in the top bits of the source */
4288     a->dp[0] |= (b >> ((sizeof(b) * 8) - MP_SET_CHUNK_BITS)) &
4289                                   ((1 << MP_SET_CHUNK_BITS) - 1);
4290 
4291     /* shift the source up to the next chunk bits */
4292     b <<= MP_SET_CHUNK_BITS;
4293 
4294     /* ensure that digits are not clamped off */
4295     a->used += 1;
4296   }
4297   mp_clamp (a);
4298   return MP_OKAY;
4299 }
4300 
4301 
4302 #if defined(WOLFSSL_KEY_GEN) || defined(HAVE_ECC) || !defined(NO_RSA) || \
4303     !defined(NO_DSA) | !defined(NO_DH)
4304 
4305 /* c = a * a (mod b) */
mp_sqrmod(mp_int * a,mp_int * b,mp_int * c)4306 int mp_sqrmod (mp_int * a, mp_int * b, mp_int * c)
4307 {
4308   int     res;
4309   mp_int  t;
4310 
4311   if ((res = mp_init (&t)) != MP_OKAY) {
4312     return res;
4313   }
4314 
4315   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
4316     mp_clear (&t);
4317     return res;
4318   }
4319   res = mp_mod (&t, b, c);
4320   mp_clear (&t);
4321   return res;
4322 }
4323 
4324 #endif
4325 
4326 
4327 #if defined(HAVE_ECC) || !defined(NO_PWDBASED) || defined(WOLFSSL_SNIFFER) || \
4328     defined(WOLFSSL_HAVE_WOLFSCEP) || defined(WOLFSSL_KEY_GEN) || \
4329     defined(OPENSSL_EXTRA) || defined(WC_RSA_BLINDING) || \
4330     (!defined(NO_RSA) && !defined(NO_RSA_BOUNDS_CHECK))
4331 
4332 /* single digit addition */
mp_add_d(mp_int * a,mp_digit b,mp_int * c)4333 int mp_add_d (mp_int* a, mp_digit b, mp_int* c)
4334 {
4335   int     res, ix, oldused;
4336   mp_digit *tmpa, *tmpc, mu;
4337 
4338   if (b > MP_DIGIT_MAX) return MP_VAL;
4339 
4340   /* grow c as required */
4341   if (c->alloc < a->used + 1) {
4342      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4343         return res;
4344      }
4345   }
4346 
4347   /* if a is negative and |a| >= b, call c = |a| - b */
4348   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
4349      /* temporarily fix sign of a */
4350      a->sign = MP_ZPOS;
4351 
4352      /* c = |a| - b */
4353      res = mp_sub_d(a, b, c);
4354 
4355      /* fix sign  */
4356      a->sign = c->sign = MP_NEG;
4357 
4358      /* clamp */
4359      mp_clamp(c);
4360 
4361      return res;
4362   }
4363 
4364   /* old number of used digits in c */
4365   oldused = c->used;
4366 
4367   /* sign always positive */
4368   c->sign = MP_ZPOS;
4369 
4370   /* source alias */
4371   tmpa    = a->dp;
4372 
4373   /* destination alias */
4374   tmpc    = c->dp;
4375 
4376   /* if a is positive */
4377   if (a->sign == MP_ZPOS) {
4378      /* add digit, after this we're propagating
4379       * the carry.
4380       */
4381      *tmpc   = *tmpa++ + b;
4382      mu      = *tmpc >> DIGIT_BIT;
4383      *tmpc++ &= MP_MASK;
4384 
4385      /* now handle rest of the digits */
4386      for (ix = 1; ix < a->used; ix++) {
4387         *tmpc   = *tmpa++ + mu;
4388         mu      = *tmpc >> DIGIT_BIT;
4389         *tmpc++ &= MP_MASK;
4390      }
4391      /* set final carry */
4392      if (ix < c->alloc) {
4393         ix++;
4394         *tmpc++  = mu;
4395      }
4396 
4397      /* setup size */
4398      c->used = a->used + 1;
4399   } else {
4400      /* a was negative and |a| < b */
4401      c->used  = 1;
4402 
4403      /* the result is a single digit */
4404      if (a->used == 1) {
4405         *tmpc++  =  b - a->dp[0];
4406      } else {
4407         *tmpc++  =  b;
4408      }
4409 
4410      /* setup count so the clearing of oldused
4411       * can fall through correctly
4412       */
4413      ix       = 1;
4414   }
4415 
4416   /* now zero to oldused */
4417   while (ix++ < oldused) {
4418      *tmpc++ = 0;
4419   }
4420   mp_clamp(c);
4421 
4422   return MP_OKAY;
4423 }
4424 
4425 
4426 /* single digit subtraction */
mp_sub_d(mp_int * a,mp_digit b,mp_int * c)4427 int mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
4428 {
4429   mp_digit *tmpa, *tmpc, mu;
4430   int       res, ix, oldused;
4431 
4432   if (b > MP_MASK) return MP_VAL;
4433 
4434   /* grow c as required */
4435   if (c->alloc < a->used + 1) {
4436      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
4437         return res;
4438      }
4439   }
4440 
4441   /* if a is negative just do an unsigned
4442    * addition [with fudged signs]
4443    */
4444   if (a->sign == MP_NEG) {
4445      a->sign = MP_ZPOS;
4446      res     = mp_add_d(a, b, c);
4447      a->sign = c->sign = MP_NEG;
4448 
4449      /* clamp */
4450      mp_clamp(c);
4451 
4452      return res;
4453   }
4454 
4455   /* setup regs */
4456   oldused = c->used;
4457   tmpa    = a->dp;
4458   tmpc    = c->dp;
4459 
4460   /* if a <= b simply fix the single digit */
4461   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
4462      if (a->used == 1) {
4463         *tmpc++ = b - *tmpa;
4464      } else {
4465         *tmpc++ = b;
4466      }
4467      ix      = 1;
4468 
4469      /* negative/1digit */
4470      c->sign = MP_NEG;
4471      c->used = 1;
4472   } else {
4473      /* positive/size */
4474      c->sign = MP_ZPOS;
4475      c->used = a->used;
4476 
4477      /* subtract first digit */
4478      *tmpc    = *tmpa++ - b;
4479      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4480      *tmpc++ &= MP_MASK;
4481 
4482      /* handle rest of the digits */
4483      for (ix = 1; ix < a->used; ix++) {
4484         *tmpc    = *tmpa++ - mu;
4485         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
4486         *tmpc++ &= MP_MASK;
4487      }
4488   }
4489 
4490   /* zero excess digits */
4491   while (ix++ < oldused) {
4492      *tmpc++ = 0;
4493   }
4494   mp_clamp(c);
4495   return MP_OKAY;
4496 }
4497 
4498 #endif /* defined(HAVE_ECC) || !defined(NO_PWDBASED) */
4499 
4500 
4501 #if defined(WOLFSSL_KEY_GEN) || defined(HAVE_COMP_KEY) || defined(HAVE_ECC) || \
4502     defined(DEBUG_WOLFSSL) || !defined(NO_RSA) || !defined(NO_DSA) || \
4503     !defined(NO_DH) || defined(WC_MP_TO_RADIX)
4504 
4505 static const int lnz[16] = {
4506    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
4507 };
4508 
4509 /* Counts the number of lsbs which are zero before the first zero bit */
mp_cnt_lsb(mp_int * a)4510 int mp_cnt_lsb(mp_int *a)
4511 {
4512     int x;
4513     mp_digit q = 0, qq;
4514 
4515     /* easy out */
4516     if (mp_iszero(a) == MP_YES) {
4517         return 0;
4518     }
4519 
4520     /* scan lower digits until non-zero */
4521     for (x = 0; x < a->used && a->dp[x] == 0; x++) {}
4522     if (a->dp)
4523         q = a->dp[x];
4524     x *= DIGIT_BIT;
4525 
4526     /* now scan this digit until a 1 is found */
4527     if ((q & 1) == 0) {
4528         do {
4529             qq  = q & 15;
4530             x  += lnz[qq];
4531             q >>= 4;
4532         } while (qq == 0);
4533     }
4534     return x;
4535 }
4536 
4537 
4538 
4539 
/* Returns 1 and stores log2(b) in *p when b is a power of two whose set
 * bit lies below DIGIT_BIT.  Otherwise returns 0 and leaves *p untouched. */
static int s_is_power_of_two(mp_digit b, int *p)
{
   int bit;

   /* zero, or more than one bit set: not a power of two */
   if (b == 0 || (b & (b - 1)) != 0) {
      return 0;
   }

   /* exactly one bit is set; find its position */
   bit = 0;
   while (bit < DIGIT_BIT && (b >> bit) != 1) {
      bit++;
   }
   if (bit == DIGIT_BIT) {
      return 0;
   }

   *p = bit;
   return 1;
}
4557 
4558 /* single digit division (based on routine from MPI) */
mp_div_d(mp_int * a,mp_digit b,mp_int * c,mp_digit * d)4559 static int mp_div_d (mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
4560 {
4561   mp_int  q;
4562   mp_word w;
4563   mp_digit t;
4564   int     res = MP_OKAY, ix;
4565 
4566   /* cannot divide by zero */
4567   if (b == 0) {
4568      return MP_VAL;
4569   }
4570 
4571   /* quick outs */
4572   if (b == 1 || mp_iszero(a) == MP_YES) {
4573      if (d != NULL) {
4574         *d = 0;
4575      }
4576      if (c != NULL) {
4577         return mp_copy(a, c);
4578      }
4579      return MP_OKAY;
4580   }
4581 
4582   /* power of two ? */
4583   if (s_is_power_of_two(b, &ix) == 1) {
4584      if (d != NULL) {
4585         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
4586      }
4587      if (c != NULL) {
4588         return mp_div_2d(a, ix, c, NULL);
4589      }
4590      return MP_OKAY;
4591   }
4592 
4593 #ifdef BN_MP_DIV_3_C
4594   /* three? */
4595   if (b == 3) {
4596      return mp_div_3(a, c, d);
4597   }
4598 #endif
4599 
4600   /* no easy answer [c'est la vie].  Just division */
4601   if (c != NULL) {
4602       if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
4603          return res;
4604       }
4605 
4606       q.used = a->used;
4607       q.sign = a->sign;
4608   }
4609   else {
4610       if ((res = mp_init(&q)) != MP_OKAY) {
4611          return res;
4612       }
4613   }
4614 
4615 
4616   w = 0;
4617   for (ix = a->used - 1; ix >= 0; ix--) {
4618      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
4619 
4620      if (w >= b) {
4621 #ifdef WOLFSSL_LINUXKM
4622         t = (mp_digit)w;
4623         /* Linux kernel macro for in-place 64 bit integer division. */
4624         do_div(t, b);
4625 #else
4626         t = (mp_digit)(w / b);
4627 #endif
4628         w -= ((mp_word)t) * ((mp_word)b);
4629       } else {
4630         t = 0;
4631       }
4632       if (c != NULL)
4633         q.dp[ix] = (mp_digit)t;
4634   }
4635 
4636   if (d != NULL) {
4637      *d = (mp_digit)w;
4638   }
4639 
4640   if (c != NULL) {
4641      mp_clamp(&q);
4642      mp_exch(&q, c);
4643   }
4644   mp_clear(&q);
4645 
4646   return res;
4647 }
4648 
4649 
/* Compute c = a mod b for a single digit modulus b.
 * Returns the status of mp_div_d (MP_VAL when b is zero). */
int mp_mod_d (mp_int * a, mp_digit b, mp_digit * c)
{
  int err;

  /* no quotient destination: mp_div_d then only produces the remainder */
  err = mp_div_d(a, b, NULL, c);

  return err;
}
4654 
4655 #endif /* WOLFSSL_KEY_GEN || HAVE_COMP_KEY || HAVE_ECC || DEBUG_WOLFSSL */
4656 
4657 #if defined(WOLFSSL_KEY_GEN) || !defined(NO_DH) || !defined(NO_DSA) || !defined(NO_RSA)
4658 
/* Table of small primes in ascending order.  mp_prime_is_divisible() uses
 * every entry for trial division, and mp_prime_is_prime() uses the first t
 * entries as its Miller-Rabin bases.
 *
 * With MP_8BIT only the first 31 primes (up to 0x007F) are present; otherwise
 * the table holds 256 primes up to 0x0653 (1619).
 * NOTE(review): any slots past the initializers would be zero, which
 * mp_div_d rejects with MP_VAL.  This assumes PRIME_SIZE equals the number of
 * initializers in each configuration - confirm in integer.h. */
const FLASH_QUALIFIER mp_digit ltm_prime_tab[PRIME_SIZE] = {
  0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
  0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
  0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
  0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F,
#ifndef MP_8BIT
  0x0083,
  0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
  0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
  0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
  0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,

  0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
  0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
  0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
  0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
  0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
  0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
  0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
  0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,

  0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
  0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
  0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
  0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
  0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
  0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
  0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
  0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,

  0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
  0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
  0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
  0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
  0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
  0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
  0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
  0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
#endif
};
4699 
4700 
4701 /* Miller-Rabin test of "a" to the base of "b" as described in
4702  * HAC pp. 139 Algorithm 4.24
4703  *
4704  * Sets result to 0 if definitely composite or 1 if probably prime.
4705  * Randomly the chance of error is no more than 1/4 and often
4706  * very much lower.
4707  */
mp_prime_miller_rabin(mp_int * a,mp_int * b,int * result)4708 static int mp_prime_miller_rabin (mp_int * a, mp_int * b, int *result)
4709 {
4710   mp_int  n1, y, r;
4711   int     s, j, err;
4712 
4713   /* default */
4714   *result = MP_NO;
4715 
4716   /* ensure b > 1 */
4717   if (mp_cmp_d(b, 1) != MP_GT) {
4718      return MP_VAL;
4719   }
4720 
4721   /* get n1 = a - 1 */
4722   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
4723     return err;
4724   }
4725   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
4726     goto LBL_N1;
4727   }
4728 
4729   /* set 2**s * r = n1 */
4730   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
4731     goto LBL_N1;
4732   }
4733 
4734   /* count the number of least significant bits
4735    * which are zero
4736    */
4737   s = mp_cnt_lsb(&r);
4738 
4739   /* now divide n - 1 by 2**s */
4740   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
4741     goto LBL_R;
4742   }
4743 
4744   /* compute y = b**r mod a */
4745   if ((err = mp_init (&y)) != MP_OKAY) {
4746     goto LBL_R;
4747   }
4748 #if defined(WOLFSSL_HAVE_SP_RSA) || defined(WOLFSSL_HAVE_SP_DH)
4749 #ifndef WOLFSSL_SP_NO_2048
4750   if (mp_count_bits(a) == 1024 && mp_isodd(a))
4751       err = sp_ModExp_1024(b, &r, a, &y);
4752   else if (mp_count_bits(a) == 2048 && mp_isodd(a))
4753       err = sp_ModExp_2048(b, &r, a, &y);
4754   else
4755 #endif
4756 #ifndef WOLFSSL_SP_NO_3072
4757   if (mp_count_bits(a) == 1536 && mp_isodd(a))
4758       err = sp_ModExp_1536(b, &r, a, &y);
4759   else if (mp_count_bits(a) == 3072 && mp_isodd(a))
4760       err = sp_ModExp_3072(b, &r, a, &y);
4761   else
4762 #endif
4763 #ifdef WOLFSSL_SP_4096
4764   if (mp_count_bits(a) == 4096 && mp_isodd(a))
4765       err = sp_ModExp_4096(b, &r, a, &y);
4766   else
4767 #endif
4768 #endif
4769       err = mp_exptmod (b, &r, a, &y);
4770   if (err != MP_OKAY)
4771       goto LBL_Y;
4772 
4773   /* if y != 1 and y != n1 do */
4774   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
4775     j = 1;
4776     /* while j <= s-1 and y != n1 */
4777     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
4778       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
4779          goto LBL_Y;
4780       }
4781 
4782       /* if y == 1 then composite */
4783       if (mp_cmp_d (&y, 1) == MP_EQ) {
4784          goto LBL_Y;
4785       }
4786 
4787       ++j;
4788     }
4789 
4790     /* if y != n1 then composite */
4791     if (mp_cmp (&y, &n1) != MP_EQ) {
4792       goto LBL_Y;
4793     }
4794   }
4795 
4796   /* probably prime now */
4797   *result = MP_YES;
4798 LBL_Y:mp_clear (&y);
4799 LBL_R:mp_clear (&r);
4800 LBL_N1:mp_clear (&n1);
4801   return err;
4802 }
4803 
4804 
4805 /* determines if an integers is divisible by one
4806  * of the first PRIME_SIZE primes or not
4807  *
4808  * sets result to 0 if not, 1 if yes
4809  */
mp_prime_is_divisible(mp_int * a,int * result)4810 static int mp_prime_is_divisible (mp_int * a, int *result)
4811 {
4812   int     err, ix;
4813   mp_digit res;
4814 
4815   /* default to not */
4816   *result = MP_NO;
4817 
4818   for (ix = 0; ix < PRIME_SIZE; ix++) {
4819     /* what is a mod LBL_prime_tab[ix] */
4820     if ((err = mp_mod_d (a, ltm_prime_tab[ix], &res)) != MP_OKAY) {
4821       return err;
4822     }
4823 
4824     /* is the residue zero? */
4825     if (res == 0) {
4826       *result = MP_YES;
4827       return MP_OKAY;
4828     }
4829   }
4830 
4831   return MP_OKAY;
4832 }
4833 
4834 /*
4835  * Sets result to 1 if probably prime, 0 otherwise
4836  */
mp_prime_is_prime(mp_int * a,int t,int * result)4837 int mp_prime_is_prime (mp_int * a, int t, int *result)
4838 {
4839   mp_int  b;
4840   int     ix, err, res;
4841 
4842   /* default to no */
4843   *result = MP_NO;
4844 
4845   /* valid value of t? */
4846   if (t <= 0 || t > PRIME_SIZE) {
4847     return MP_VAL;
4848   }
4849 
4850   if (mp_isone(a)) {
4851       *result = MP_NO;
4852       return MP_OKAY;
4853   }
4854 
4855   /* is the input equal to one of the primes in the table? */
4856   for (ix = 0; ix < PRIME_SIZE; ix++) {
4857       if (mp_cmp_d(a, ltm_prime_tab[ix]) == MP_EQ) {
4858          *result = MP_YES;
4859          return MP_OKAY;
4860       }
4861   }
4862 
4863   /* first perform trial division */
4864   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
4865     return err;
4866   }
4867 
4868   /* return if it was trivially divisible */
4869   if (res == MP_YES) {
4870     return MP_OKAY;
4871   }
4872 
4873   /* now perform the miller-rabin rounds */
4874   if ((err = mp_init (&b)) != MP_OKAY) {
4875     return err;
4876   }
4877 
4878   for (ix = 0; ix < t; ix++) {
4879     /* set the prime */
4880     if ((err = mp_set (&b, ltm_prime_tab[ix])) != MP_OKAY) {
4881         goto LBL_B;
4882     }
4883 
4884     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
4885       goto LBL_B;
4886     }
4887 
4888     if (res == MP_NO) {
4889       goto LBL_B;
4890     }
4891   }
4892 
4893   /* passed the test */
4894   *result = MP_YES;
4895 LBL_B:mp_clear (&b);
4896   return err;
4897 }
4898 
4899 
4900 /*
4901  * Sets result to 1 if probably prime, 0 otherwise
4902  */
mp_prime_is_prime_ex(mp_int * a,int t,int * result,WC_RNG * rng)4903 int mp_prime_is_prime_ex (mp_int * a, int t, int *result, WC_RNG *rng)
4904 {
4905   mp_int  b, c;
4906   int     ix, err, res;
4907   byte*   base = NULL;
4908   word32  baseSz = 0;
4909 
4910   /* default to no */
4911   *result = MP_NO;
4912 
4913   /* valid value of t? */
4914   if (t <= 0 || t > PRIME_SIZE) {
4915     return MP_VAL;
4916   }
4917 
4918   if (mp_isone(a)) {
4919     *result = MP_NO;
4920     return MP_OKAY;
4921   }
4922 
4923   /* is the input equal to one of the primes in the table? */
4924   for (ix = 0; ix < PRIME_SIZE; ix++) {
4925       if (mp_cmp_d(a, ltm_prime_tab[ix]) == MP_EQ) {
4926          *result = MP_YES;
4927          return MP_OKAY;
4928       }
4929   }
4930 
4931   /* first perform trial division */
4932   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
4933     return err;
4934   }
4935 
4936   /* return if it was trivially divisible */
4937   if (res == MP_YES) {
4938     return MP_OKAY;
4939   }
4940 
4941   /* now perform the miller-rabin rounds */
4942   if ((err = mp_init (&b)) != MP_OKAY) {
4943     return err;
4944   }
4945   if ((err = mp_init (&c)) != MP_OKAY) {
4946       mp_clear(&b);
4947     return err;
4948   }
4949 
4950   baseSz = mp_count_bits(a);
4951   baseSz = (baseSz / 8) + ((baseSz % 8) ? 1 : 0);
4952 
4953   base = (byte*)XMALLOC(baseSz, NULL, DYNAMIC_TYPE_TMP_BUFFER);
4954   if (base == NULL) {
4955       err = MP_MEM;
4956       goto LBL_B;
4957   }
4958 
4959   if ((err = mp_sub_d(a, 2, &c)) != MP_OKAY) {
4960       goto LBL_B;
4961   }
4962 
4963  /* now do a miller rabin with up to t random numbers, this should
4964   * give a (1/4)^t chance of a false prime. */
4965   for (ix = 0; ix < t; ix++) {
4966     /* Set a test candidate. */
4967     if ((err = wc_RNG_GenerateBlock(rng, base, baseSz)) != 0) {
4968         goto LBL_B;
4969     }
4970 
4971     if ((err = mp_read_unsigned_bin(&b, base, baseSz)) != MP_OKAY) {
4972         goto LBL_B;
4973     }
4974 
4975     if (mp_cmp_d(&b, 2) != MP_GT || mp_cmp(&b, &c) != MP_LT) {
4976         ix--;
4977         continue;
4978     }
4979 
4980     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
4981       goto LBL_B;
4982     }
4983 
4984     if (res == MP_NO) {
4985       goto LBL_B;
4986     }
4987   }
4988 
4989   /* passed the test */
4990   *result = MP_YES;
4991 LBL_B:mp_clear (&b);
4992       mp_clear (&c);
4993       XFREE(base, NULL, DYNAMIC_TYPE_TMP_BUFFER);
4994   return err;
4995 }
4996 
4997 #endif /* WOLFSSL_KEY_GEN NO_DH NO_DSA NO_RSA */
4998 
4999 #ifdef WOLFSSL_KEY_GEN
5000 
5001 static const int USE_BBS = 1;
5002 
mp_rand_prime(mp_int * N,int len,WC_RNG * rng,void * heap)5003 int mp_rand_prime(mp_int* N, int len, WC_RNG* rng, void* heap)
5004 {
5005     int   err, res, type;
5006     byte* buf;
5007 
5008     if (N == NULL || rng == NULL)
5009         return MP_VAL;
5010 
5011     /* get type */
5012     if (len < 0) {
5013         type = USE_BBS;
5014         len = -len;
5015     } else {
5016         type = 0;
5017     }
5018 
5019     /* allow sizes between 2 and 512 bytes for a prime size */
5020     if (len < 2 || len > 512) {
5021         return MP_VAL;
5022     }
5023 
5024     /* allocate buffer to work with */
5025     buf = (byte*)XMALLOC(len, heap, DYNAMIC_TYPE_RSA);
5026     if (buf == NULL) {
5027         return MP_MEM;
5028     }
5029     XMEMSET(buf, 0, len);
5030 
5031     do {
5032 #ifdef SHOW_GEN
5033         printf(".");
5034         fflush(stdout);
5035 #endif
5036         /* generate value */
5037         err = wc_RNG_GenerateBlock(rng, buf, len);
5038         if (err != 0) {
5039             XFREE(buf, heap, DYNAMIC_TYPE_RSA);
5040             return err;
5041         }
5042 
5043         /* munge bits */
5044         buf[0]     |= 0x80 | 0x40;
5045         buf[len-1] |= 0x01 | ((type & USE_BBS) ? 0x02 : 0x00);
5046 
5047         /* load value */
5048         if ((err = mp_read_unsigned_bin(N, buf, len)) != MP_OKAY) {
5049             XFREE(buf, heap, DYNAMIC_TYPE_RSA);
5050             return err;
5051         }
5052 
5053         /* test */
5054         /* Running Miller-Rabin up to 3 times gives us a 2^{-80} chance
5055          * of a 1024-bit candidate being a false positive, when it is our
5056          * prime candidate. (Note 4.49 of Handbook of Applied Cryptography.)
5057          * Using 8 because we've always used 8. */
5058         if ((err = mp_prime_is_prime_ex(N, 8, &res, rng)) != MP_OKAY) {
5059             XFREE(buf, heap, DYNAMIC_TYPE_RSA);
5060             return err;
5061         }
5062     } while (res == MP_NO);
5063 
5064     XMEMSET(buf, 0, len);
5065     XFREE(buf, heap, DYNAMIC_TYPE_RSA);
5066 
5067     return MP_OKAY;
5068 }
5069 
5070 
5071 /* computes least common multiple as |a*b|/(a, b) */
mp_lcm(mp_int * a,mp_int * b,mp_int * c)5072 int mp_lcm (mp_int * a, mp_int * b, mp_int * c)
5073 {
5074   int     res;
5075   mp_int  t1, t2;
5076 
5077   /* LCM of 0 and any number is undefined as 0 is not in the set of values
5078    * being used. */
5079   if (mp_iszero (a) == MP_YES || mp_iszero (b) == MP_YES) {
5080     return MP_VAL;
5081   }
5082 
5083   if ((res = mp_init_multi (&t1, &t2, NULL, NULL, NULL, NULL)) != MP_OKAY) {
5084     return res;
5085   }
5086 
5087   /* t1 = get the GCD of the two inputs */
5088   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
5089     goto LBL_T;
5090   }
5091 
5092   /* divide the smallest by the GCD */
5093   if (mp_cmp_mag(a, b) == MP_LT) {
5094      /* store quotient in t2 such that t2 * b is the LCM */
5095      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
5096         goto LBL_T;
5097      }
5098      res = mp_mul(b, &t2, c);
5099   } else {
5100      /* store quotient in t2 such that t2 * a is the LCM */
5101      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
5102         goto LBL_T;
5103      }
5104      res = mp_mul(a, &t2, c);
5105   }
5106 
5107   /* fix the sign to positive */
5108   c->sign = MP_ZPOS;
5109 
5110 LBL_T:
5111   mp_clear(&t1);
5112   mp_clear(&t2);
5113   return res;
5114 }
5115 
5116 
5117 
5118 /* Greatest Common Divisor using the binary method */
mp_gcd(mp_int * a,mp_int * b,mp_int * c)5119 int mp_gcd (mp_int * a, mp_int * b, mp_int * c)
5120 {
5121     mp_int  u, v;
5122     int     k, u_lsb, v_lsb, res;
5123 
5124     /* either zero than gcd is the largest */
5125     if (mp_iszero (a) == MP_YES) {
5126         /* GCD of 0 and 0 is undefined as all integers divide 0. */
5127         if (mp_iszero (b) == MP_YES) {
5128            return MP_VAL;
5129         }
5130         return mp_abs (b, c);
5131     }
5132     if (mp_iszero (b) == MP_YES) {
5133         return mp_abs (a, c);
5134     }
5135 
5136     /* get copies of a and b we can modify */
5137     if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
5138         return res;
5139     }
5140 
5141     if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
5142         goto LBL_U;
5143     }
5144 
5145     /* must be positive for the remainder of the algorithm */
5146     u.sign = v.sign = MP_ZPOS;
5147 
5148     /* B1.  Find the common power of two for u and v */
5149     u_lsb = mp_cnt_lsb(&u);
5150     v_lsb = mp_cnt_lsb(&v);
5151     k     = MIN(u_lsb, v_lsb);
5152 
5153     if (k > 0) {
5154         /* divide the power of two out */
5155         if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
5156             goto LBL_V;
5157         }
5158 
5159         if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
5160             goto LBL_V;
5161         }
5162     }
5163 
5164     /* divide any remaining factors of two out */
5165     if (u_lsb != k) {
5166         if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
5167             goto LBL_V;
5168         }
5169     }
5170 
5171     if (v_lsb != k) {
5172         if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
5173             goto LBL_V;
5174         }
5175     }
5176 
5177     while (mp_iszero(&v) == MP_NO) {
5178         /* make sure v is the largest */
5179         if (mp_cmp_mag(&u, &v) == MP_GT) {
5180             /* swap u and v to make sure v is >= u */
5181             mp_exch(&u, &v);
5182         }
5183 
5184         /* subtract smallest from largest */
5185         if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
5186             goto LBL_V;
5187         }
5188 
5189         /* Divide out all factors of two */
5190         if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
5191             goto LBL_V;
5192         }
5193     }
5194 
5195     /* multiply by 2**k which we divided out at the beginning */
5196     if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
5197         goto LBL_V;
5198     }
5199     c->sign = MP_ZPOS;
5200     res = MP_OKAY;
5201 LBL_V:mp_clear (&v);
5202 LBL_U:mp_clear (&u);
5203     return res;
5204 }
5205 
5206 #endif /* WOLFSSL_KEY_GEN */
5207 
5208 
5209 #if !defined(NO_DSA) || defined(HAVE_ECC) || defined(WOLFSSL_KEY_GEN) || \
5210     defined(HAVE_COMP_KEY) || defined(WOLFSSL_DEBUG_MATH) || \
5211     defined(DEBUG_WOLFSSL) || defined(OPENSSL_EXTRA) || defined(WC_MP_TO_RADIX)
5212 
/* Digit alphabet for radix conversions (radix 2..64): the character at
 * index i is the symbol for digit value i. */
const char *mp_s_rmap =
    "0123456789"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "+/";
5216 #endif
5217 
5218 #if !defined(NO_DSA) || defined(HAVE_ECC)
5219 /* read a string [ASCII] in a given radix */
mp_read_radix(mp_int * a,const char * str,int radix)5220 int mp_read_radix (mp_int * a, const char *str, int radix)
5221 {
5222   int     y, res, neg;
5223   char    ch;
5224 
5225   /* zero the digit bignum */
5226   mp_zero(a);
5227 
5228   /* make sure the radix is ok */
5229   if (radix < MP_RADIX_BIN || radix > MP_RADIX_MAX) {
5230     return MP_VAL;
5231   }
5232 
5233   /* if the leading digit is a
5234    * minus set the sign to negative.
5235    */
5236   if (*str == '-') {
5237     ++str;
5238     neg = MP_NEG;
5239   } else {
5240     neg = MP_ZPOS;
5241   }
5242 
5243   /* set the integer to the default of zero */
5244   mp_zero (a);
5245 
5246   /* process each digit of the string */
5247   while (*str != '\0') {
5248     /* if the radix <= 36 the conversion is case insensitive
5249      * this allows numbers like 1AB and 1ab to represent the same  value
5250      * [e.g. in hex]
5251      */
5252     ch = (radix <= 36) ? (char)XTOUPPER((unsigned char)*str) : *str;
5253     for (y = 0; y < 64; y++) {
5254       if (ch == mp_s_rmap[y]) {
5255          break;
5256       }
5257     }
5258 
5259     /* if the char was found in the map
5260      * and is less than the given radix add it
5261      * to the number, otherwise exit the loop.
5262      */
5263     if (y < radix) {
5264       if ((res = mp_mul_d (a, (mp_digit) radix, a)) != MP_OKAY) {
5265          mp_zero(a);
5266          return res;
5267       }
5268       if ((res = mp_add_d (a, (mp_digit) y, a)) != MP_OKAY) {
5269          mp_zero(a);
5270          return res;
5271       }
5272     } else {
5273       break;
5274     }
5275     ++str;
5276   }
5277 
5278   /* if digit in isn't null term, then invalid character was found */
5279   if (*str != '\0') {
5280      mp_zero (a);
5281      return MP_VAL;
5282   }
5283 
5284   /* set the sign only if a != 0 */
5285   if (mp_iszero(a) != MP_YES) {
5286      a->sign = neg;
5287   }
5288   return MP_OKAY;
5289 }
5290 #endif /* !defined(NO_DSA) || defined(HAVE_ECC) */
5291 
5292 #ifdef WC_MP_TO_RADIX
5293 
5294 /* returns size of ASCII representation */
mp_radix_size(mp_int * a,int radix,int * size)5295 int mp_radix_size (mp_int *a, int radix, int *size)
5296 {
5297     int     res, digs;
5298     mp_int  t;
5299     mp_digit d;
5300 
5301     *size = 0;
5302 
5303     /* special case for binary */
5304     if (radix == MP_RADIX_BIN) {
5305         *size = mp_count_bits(a);
5306         if (*size == 0)
5307           *size = 1;
5308         *size += (a->sign == MP_NEG ? 1 : 0) + 1; /* "-" sign + null term */
5309         return MP_OKAY;
5310     }
5311 
5312     /* make sure the radix is in range */
5313     if (radix < MP_RADIX_BIN || radix > MP_RADIX_MAX) {
5314         return MP_VAL;
5315     }
5316 
5317     if (mp_iszero(a) == MP_YES) {
5318 #ifndef WC_DISABLE_RADIX_ZERO_PAD
5319         if (radix == 16)
5320             *size = 3;
5321         else
5322 #endif
5323             *size = 2;
5324         return MP_OKAY;
5325     }
5326 
5327     /* digs is the digit count */
5328     digs = 0;
5329 
5330     /* init a copy of the input */
5331     if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
5332         return res;
5333     }
5334 
5335     /* force temp to positive */
5336     t.sign = MP_ZPOS;
5337 
5338     /* fetch out all of the digits */
5339     while (mp_iszero (&t) == MP_NO) {
5340         if ((res = mp_div_d (&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
5341             mp_clear (&t);
5342             return res;
5343         }
5344         ++digs;
5345     }
5346     mp_clear (&t);
5347 
5348 #ifndef WC_DISABLE_RADIX_ZERO_PAD
5349     /* For hexadecimal output, add zero padding when number of digits is odd */
5350     if ((digs & 1) && (radix == 16)) {
5351         ++digs;
5352     }
5353 #endif
5354 
5355     /* if it's negative add one for the sign */
5356     if (a->sign == MP_NEG) {
5357         ++digs;
5358     }
5359 
5360     /* return digs + 1, the 1 is for the NULL byte that would be required. */
5361     *size = digs + 1;
5362     return MP_OKAY;
5363 }
5364 
5365 /* stores a bignum as a ASCII string in a given radix (2..64) */
mp_toradix(mp_int * a,char * str,int radix)5366 int mp_toradix (mp_int *a, char *str, int radix)
5367 {
5368     int     res, digs;
5369     mp_int  t;
5370     mp_digit d;
5371     char   *_s = str;
5372 
5373     /* check range of the radix */
5374     if (radix < MP_RADIX_BIN || radix > MP_RADIX_MAX) {
5375         return MP_VAL;
5376     }
5377 
5378     /* quick out if its zero */
5379     if (mp_iszero(a) == MP_YES) {
5380 #ifndef WC_DISABLE_RADIX_ZERO_PAD
5381         if (radix == 16) {
5382             *str++ = '0';
5383         }
5384 #endif
5385         *str++ = '0';
5386         *str = '\0';
5387         return MP_OKAY;
5388     }
5389 
5390     if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
5391         return res;
5392     }
5393 
5394     /* if it is negative output a - */
5395     if (t.sign == MP_NEG) {
5396         ++_s;
5397         *str++ = '-';
5398         t.sign = MP_ZPOS;
5399     }
5400 
5401     digs = 0;
5402     while (mp_iszero (&t) == MP_NO) {
5403         if ((res = mp_div_d (&t, (mp_digit) radix, &t, &d)) != MP_OKAY) {
5404             mp_clear (&t);
5405             return res;
5406         }
5407         *str++ = mp_s_rmap[d];
5408         ++digs;
5409     }
5410 #ifndef WC_DISABLE_RADIX_ZERO_PAD
5411     /* For hexadecimal output, add zero padding when number of digits is odd */
5412     if ((digs & 1) && (radix == 16)) {
5413         *str++ = mp_s_rmap[0];
5414         ++digs;
5415     }
5416 #endif
5417     /* reverse the digits of the string.  In this case _s points
5418      * to the first digit [excluding the sign] of the number]
5419      */
5420     bn_reverse ((unsigned char *)_s, digs);
5421 
5422     /* append a NULL so the string is properly terminated */
5423     *str = '\0';
5424 
5425     mp_clear (&t);
5426     return MP_OKAY;
5427 }
5428 
5429 #ifdef WOLFSSL_DEBUG_MATH
/* Print a debug dump of a to stdout.
 *
 * The first line shows the address, used count, sign, allocated digit count
 * and digit size.  The next line shows the value in hexadecimal.  When
 * verbose is non-zero, the raw bytes of all allocated digits follow.
 *
 * desc     Label printed at the start of the first line.
 * a        Big integer to dump.
 * verbose  Non-zero to also print the raw digit bytes.
 */
void mp_dump(const char* desc, mp_int* a, byte verbose)
{
  char *buffer;
  int   size = a->alloc;
  int   hexSz = 0;

  /* Size the text buffer from the hex representation itself.  A size
   * estimated from alloc * sizeof(mp_digit) * 2 reserves nothing for the
   * NUL terminator, a '-' sign or the hex zero-pad digit, and is 0 bytes
   * when alloc is 0. */
  if (mp_radix_size(a, 16, &hexSz) != MP_OKAY || hexSz <= 0) {
    return;
  }

  buffer = (char*)XMALLOC(hexSz, NULL, DYNAMIC_TYPE_TMP_BUFFER);
  if (buffer == NULL) {
    return;
  }

  /* %p requires a void pointer argument */
  printf("%s: ptr=%p, used=%d, sign=%d, size=%d, mpd=%d\n",
    desc, (void*)a, a->used, a->sign, size, (int)sizeof(mp_digit));

  /* written by the same routine whose size was computed above */
  if (mp_toradix(a, buffer, 16) == MP_OKAY) {
    printf("  %s\n  ", buffer);
  }

  if (verbose) {
    int i;
    for(i=0; i<a->alloc * (int)sizeof(mp_digit); i++) {
      printf("%02x ", *(((byte*)a->dp) + i));
    }
    printf("\n");
  }

  XFREE(buffer, NULL, DYNAMIC_TYPE_TMP_BUFFER);
}
5456 #endif /* WOLFSSL_DEBUG_MATH */
5457 
5458 #endif /* WC_MP_TO_RADIX */
5459 
5460 #endif /* WOLFSSL_SP_MATH */
5461 
5462 #endif /* USE_FAST_MATH */
5463 
5464 #endif /* NO_BIG_INT */
5465