xref: /reactos/dll/win32/rsaenh/mpi.c (revision 5100859e)
1 /*
2  * dlls/rsaenh/mpi.c
3  * Multi Precision Integer functions
4  *
5  * Copyright 2004 Michael Jung
6  * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
7  *
8  * This library is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with this library; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
21  */
22 
23 /*
24  * This file contains code from the LibTomCrypt cryptographic
25  * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26  * is in the public domain. The code in this file is tailored to
27  * special requirements. Take a look at http://libtomcrypt.org for the
28  * original version.
29  */
30 
31 #include <stdarg.h>
32 
33 #include <windef.h>
34 #include <winbase.h>
35 #include "tomcrypt.h"
36 
/* Digit-count thresholds at which the Karatsuba algorithms take over.
 *
 * Known optimal configurations:
 *  CPU                    /Compiler     /MUL CUTOFF/SQR CUTOFF
 * -------------------------------------------------------------
 *  Intel P4 Northwood     /GCC v3.4.1   /        88/       128/LTM 0.32 ;-)
 */

/* Minimum number of digits before Karatsuba multiplication is used. */
static const int KARATSUBA_MUL_CUTOFF = 88;

/* Minimum number of digits before Karatsuba squaring is used. */
static const int KARATSUBA_SQR_CUTOFF = 128;
44 
45 
/* trim unused digits (drop leading zero digits; a zero value gets a
 * positive sign) */
static void mp_clamp(mp_int *a);

/* compare |a| to |b| */
static int mp_cmp_mag(const mp_int *a, const mp_int *b);

/* Counts the number of least significant zero bits before the first one
 * bit, i.e. the index of the lowest set bit (0 for a zero value) */
static int mp_cnt_lsb(const mp_int *a);

/* computes a = B**n mod b without division or multiplication, useful for
 * normalizing numbers in a Montgomery system.
 */
static int mp_montgomery_calc_normalization(mp_int *a, const mp_int *b);

/* computes x/R == x (mod N) via Montgomery Reduction */
static int mp_montgomery_reduce(mp_int *a, const mp_int *m, mp_digit mp);

/* sets up the montgomery reduction (computes the constant mp for modulus a) */
static int mp_montgomery_setup(const mp_int *a, mp_digit *mp);

/* Barrett Reduction, computes a (mod b) with a precomputed value c
 *
 * Assumes that 0 < a <= b*b, note if 0 > a > -(b*b) then you can merely
 * compute the reduction as -1 * mp_reduce(mp_abs(a)) [pseudo code].
 */
static int mp_reduce(mp_int *a, const mp_int *b, const mp_int *c);

/* reduces a modulo b where b is of the form 2**p - k [0 <= a] */
static int mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d);

/* determines k value for 2k reduction */
static int mp_reduce_2k_setup(const mp_int *a, mp_digit *d);

/* used to setup the Barrett reduction for a given modulus b */
static int mp_reduce_setup(mp_int *a, const mp_int *b);

/* set to a digit */
static void mp_set(mp_int *a, mp_digit b);

/* b = a*a  */
static int mp_sqr(const mp_int *a, mp_int *b);

/* c = a * a (mod b) */
static int mp_sqrmod(const mp_int *a, mp_int *b, mp_int *c);


/* Low-level helpers; their definitions are outside this part of the file. */

/* presumably reverses the len bytes of s in place - confirm at definition */
static void bn_reverse(unsigned char *s, int len);
/* c = |a| + |b|; magnitudes only, mp_add() sets the sign of c itself */
static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
/* presumably Y = G**X mod P using a non-Montgomery reduction - confirm */
static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
/* full product: c = a * b keeping a->used + b->used + 1 digits */
#define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
/* c = a * b producing only digs digits of the product
 * (compare fast_s_mp_mul_digs) */
static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
/* c = a * b producing only the digits at and above position digs
 * (compare fast_s_mp_mul_high_digs) */
static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
/* b = a*a (non-comba squaring) */
static int s_mp_sqr(const mp_int *a, mp_int *b);
/* c = |a| - |b|; magnitudes only, mp_add() calls it with |a| >= |b| */
static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
/* Y = G**X mod P; "mode" selects the reduction method - confirm values at
 * the definition */
static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
/* c = 1/a mod b for the general case; fast_mp_invmod() only handles odd b */
static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
/* Karatsuba multiplication, used at or above KARATSUBA_MUL_CUTOFF digits */
static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
/* Karatsuba squaring, used at or above KARATSUBA_SQR_CUTOFF digits */
static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
104 
105 /* grow as required */
106 static int mp_grow (mp_int * a, int size)
107 {
108   int     i;
109   mp_digit *tmp;
110 
111   /* if the alloc size is smaller alloc more ram */
112   if (a->alloc < size) {
113     /* ensure there are always at least MP_PREC digits extra on top */
114     size += (MP_PREC * 2) - (size % MP_PREC);
115 
116     /* reallocate the array a->dp
117      *
118      * We store the return in a temporary variable
119      * in case the operation failed we don't want
120      * to overwrite the dp member of a.
121      */
122     tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * size);
123     if (tmp == NULL) {
124       /* reallocation failed but "a" is still valid [can be freed] */
125       return MP_MEM;
126     }
127 
128     /* reallocation succeeded so set a->dp */
129     a->dp = tmp;
130 
131     /* zero excess digits */
132     i        = a->alloc;
133     a->alloc = size;
134     for (; i < a->alloc; i++) {
135       a->dp[i] = 0;
136     }
137   }
138   return MP_OKAY;
139 }
140 
141 /* b = a/2 */
142 static int mp_div_2(const mp_int * a, mp_int * b)
143 {
144   int     x, res, oldused;
145 
146   /* copy */
147   if (b->alloc < a->used) {
148     if ((res = mp_grow (b, a->used)) != MP_OKAY) {
149       return res;
150     }
151   }
152 
153   oldused = b->used;
154   b->used = a->used;
155   {
156     register mp_digit r, rr, *tmpa, *tmpb;
157 
158     /* source alias */
159     tmpa = a->dp + b->used - 1;
160 
161     /* dest alias */
162     tmpb = b->dp + b->used - 1;
163 
164     /* carry */
165     r = 0;
166     for (x = b->used - 1; x >= 0; x--) {
167       /* get the carry for the next iteration */
168       rr = *tmpa & 1;
169 
170       /* shift the current digit, add in carry and store */
171       *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
172 
173       /* forward carry to next iteration */
174       r = rr;
175     }
176 
177     /* zero excess digits */
178     tmpb = b->dp + b->used;
179     for (x = b->used; x < oldused; x++) {
180       *tmpb++ = 0;
181     }
182   }
183   b->sign = a->sign;
184   mp_clamp (b);
185   return MP_OKAY;
186 }
187 
188 /* swap the elements of two integers, for cases where you can't simply swap the
189  * mp_int pointers around
190  */
191 static void
192 mp_exch (mp_int * a, mp_int * b)
193 {
194   mp_int  t;
195 
196   t  = *a;
197   *a = *b;
198   *b = t;
199 }
200 
201 /* init a new mp_int */
202 static int mp_init (mp_int * a)
203 {
204   int i;
205 
206   /* allocate memory required and clear it */
207   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * MP_PREC);
208   if (a->dp == NULL) {
209     return MP_MEM;
210   }
211 
212   /* set the digits to zero */
213   for (i = 0; i < MP_PREC; i++) {
214       a->dp[i] = 0;
215   }
216 
217   /* set the used to zero, allocated digits to the default precision
218    * and sign to positive */
219   a->used  = 0;
220   a->alloc = MP_PREC;
221   a->sign  = MP_ZPOS;
222 
223   return MP_OKAY;
224 }
225 
226 /* init an mp_init for a given size */
227 static int mp_init_size (mp_int * a, int size)
228 {
229   int x;
230 
231   /* pad size so there are always extra digits */
232   size += (MP_PREC * 2) - (size % MP_PREC);
233 
234   /* alloc mem */
235   a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * size);
236   if (a->dp == NULL) {
237     return MP_MEM;
238   }
239 
240   /* set the members */
241   a->used  = 0;
242   a->alloc = size;
243   a->sign  = MP_ZPOS;
244 
245   /* zero the digits */
246   for (x = 0; x < size; x++) {
247       a->dp[x] = 0;
248   }
249 
250   return MP_OKAY;
251 }
252 
253 /* clear one (frees)  */
254 static void
255 mp_clear (mp_int * a)
256 {
257   int i;
258 
259   /* only do anything if a hasn't been freed previously */
260   if (a->dp != NULL) {
261     /* first zero the digits */
262     for (i = 0; i < a->used; i++) {
263         a->dp[i] = 0;
264     }
265 
266     /* free ram */
267     HeapFree(GetProcessHeap(), 0, a->dp);
268 
269     /* reset members to make debugging easier */
270     a->dp    = NULL;
271     a->alloc = a->used = 0;
272     a->sign  = MP_ZPOS;
273   }
274 }
275 
276 /* set to zero */
277 static void
278 mp_zero (mp_int * a)
279 {
280   a->sign = MP_ZPOS;
281   a->used = 0;
282   memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
283 }
284 
285 /* b = |a|
286  *
287  * Simple function copies the input and fixes the sign to positive
288  */
289 static int
290 mp_abs (const mp_int * a, mp_int * b)
291 {
292   int     res;
293 
294   /* copy a to b */
295   if (a != b) {
296      if ((res = mp_copy (a, b)) != MP_OKAY) {
297        return res;
298      }
299   }
300 
301   /* force the sign of b to positive */
302   b->sign = MP_ZPOS;
303 
304   return MP_OKAY;
305 }
306 
307 /* computes the modular inverse via binary extended euclidean algorithm,
308  * that is c = 1/a mod b
309  *
310  * Based on slow invmod except this is optimized for the case where b is
311  * odd as per HAC Note 14.64 on pp. 610
312  */
313 static int
314 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
315 {
316   mp_int  x, y, u, v, B, D;
317   int     res, neg;
318 
319   /* 2. [modified] b must be odd   */
320   if (mp_iseven (b) == 1) {
321     return MP_VAL;
322   }
323 
324   /* init all our temps */
325   if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
326      return res;
327   }
328 
329   /* x == modulus, y == value to invert */
330   if ((res = mp_copy (b, &x)) != MP_OKAY) {
331     goto __ERR;
332   }
333 
334   /* we need y = |a| */
335   if ((res = mp_abs (a, &y)) != MP_OKAY) {
336     goto __ERR;
337   }
338 
339   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
340   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
341     goto __ERR;
342   }
343   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
344     goto __ERR;
345   }
346   mp_set (&D, 1);
347 
348 top:
349   /* 4.  while u is even do */
350   while (mp_iseven (&u) == 1) {
351     /* 4.1 u = u/2 */
352     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
353       goto __ERR;
354     }
355     /* 4.2 if B is odd then */
356     if (mp_isodd (&B) == 1) {
357       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
358         goto __ERR;
359       }
360     }
361     /* B = B/2 */
362     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
363       goto __ERR;
364     }
365   }
366 
367   /* 5.  while v is even do */
368   while (mp_iseven (&v) == 1) {
369     /* 5.1 v = v/2 */
370     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
371       goto __ERR;
372     }
373     /* 5.2 if D is odd then */
374     if (mp_isodd (&D) == 1) {
375       /* D = (D-x)/2 */
376       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
377         goto __ERR;
378       }
379     }
380     /* D = D/2 */
381     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
382       goto __ERR;
383     }
384   }
385 
386   /* 6.  if u >= v then */
387   if (mp_cmp (&u, &v) != MP_LT) {
388     /* u = u - v, B = B - D */
389     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
390       goto __ERR;
391     }
392 
393     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
394       goto __ERR;
395     }
396   } else {
397     /* v - v - u, D = D - B */
398     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
399       goto __ERR;
400     }
401 
402     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
403       goto __ERR;
404     }
405   }
406 
407   /* if not zero goto step 4 */
408   if (mp_iszero (&u) == 0) {
409     goto top;
410   }
411 
412   /* now a = C, b = D, gcd == g*v */
413 
414   /* if v != 1 then there is no inverse */
415   if (mp_cmp_d (&v, 1) != MP_EQ) {
416     res = MP_VAL;
417     goto __ERR;
418   }
419 
420   /* b is now the inverse */
421   neg = a->sign;
422   while (D.sign == MP_NEG) {
423     if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
424       goto __ERR;
425     }
426   }
427   mp_exch (&D, c);
428   c->sign = neg;
429   res = MP_OKAY;
430 
431 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
432   return res;
433 }
434 
/* computes xR**-1 == x (mod N) via Montgomery Reduction
 *
 * This is an optimized implementation of montgomery_reduce
 * which uses the comba method to quickly calculate the columns of the
 * reduction.
 *
 * Based on Algorithm 14.32 on pp.601 of HAC.
 *
 * x   - value to reduce; overwritten in place with the result.  If the
 *       reduced value is still >= n, n is subtracted once at the end.
 * n   - the modulus (called "m" in the comments below)
 * rho - constant multiplied into each low word to form the digit mu
 *       (presumably the value computed by mp_montgomery_setup(); confirm
 *       at the callers)
 *
 * Returns MP_OKAY, or the error reported by mp_grow() / s_mp_sub().
 *
 * NOTE(review): W[] has MP_WARRAY entries.  This routine copies all x->used
 * digits into it and touches W[0 .. 2*n->used + 1].  The caller is assumed
 * to guarantee that this fits - confirm.
 */
static int
fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
{
  int     ix, res, olduse;
  mp_word W[MP_WARRAY];

  /* get old used count (digits above the result are cleared later) */
  olduse = x->used;

  /* grow x so it can hold the n->used + 1 result digits */
  if (x->alloc < n->used + 1) {
    if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
      return res;
    }
  }

  /* first we have to get the digits of the input into
   * an array of double precision words W[...]
   */
  {
    register mp_word *_W;
    register mp_digit *tmpx;

    /* alias for the W[] array */
    _W   = W;

    /* alias for the digits of x */
    tmpx = x->dp;

    /* copy the digits of x into W[0..x->used-1] */
    for (ix = 0; ix < x->used; ix++) {
      *_W++ = *tmpx++;
    }

    /* zero the high words W[x->used..n->used*2] */
    for (; ix < n->used * 2 + 1; ix++) {
      *_W++ = 0;
    }
  }

  /* now we proceed to zero successive digits
   * from the least significant upwards
   */
  for (ix = 0; ix < n->used; ix++) {
    /* mu = ai * m' mod b
     *
     * We avoid a double precision multiplication (which isn't required)
     * by casting the value down to a mp_digit.  Note this requires
     * that W[ix-1] have  the carry cleared (see after the inner loop)
     */
    register mp_digit mu;
    mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);

    /* a = a + mu * m * b**i
     *
     * This is computed in place and on the fly.  The multiplication
     * by b**i is handled by offsetting which columns the results
     * are added to.
     *
     * Note the comba method normally doesn't handle carries in the
     * inner loop In this case we fix the carry from the previous
     * column since the Montgomery reduction requires digits of the
     * result (so far) [see above] to work.  This is
     * handled by fixing up one carry after the inner loop.  The
     * carry fixups are done in order so after these loops the
     * first m->used words of W[] have the carries fixed
     */
    {
      register int iy;
      register mp_digit *tmpn;
      register mp_word *_W;

      /* alias for the digits of the modulus */
      tmpn = n->dp;

      /* Alias for the columns set by an offset of ix */
      _W = W + ix;

      /* inner loop: columns ix .. ix + n->used - 1 */
      for (iy = 0; iy < n->used; iy++) {
          *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
      }
    }

    /* now fix carry for next digit, W[ix+1] */
    W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
  }

  /* now we have to propagate the carries and
   * shift the words downward [all those least
   * significant digits we zeroed].
   */
  {
    register mp_digit *tmpx;
    register mp_word *_W, *_W1;

    /* now fix the rest of the carries (ix == n->used here) */

    /* alias for current word */
    _W1 = W + ix;

    /* alias for next word, where the carry goes */
    _W = W + ++ix;

    /* NOTE(review): the last iteration adds into W[2*n->used + 1], which the
     * zeroing loop above never initialised; that word is not copied out
     * below. */
    for (; ix <= n->used * 2 + 1; ix++) {
      *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
    }

    /* copy out, A = A/b**n
     *
     * The result is A/b**n but instead of converting from an
     * array of mp_word to mp_digit than calling mp_rshd
     * we just copy them in the right order
     */

    /* alias for destination word */
    tmpx = x->dp;

    /* alias for shifted double precision result */
    _W = W + n->used;

    /* W[n->used .. 2*n->used] become the n->used + 1 result digits */
    for (ix = 0; ix < n->used + 1; ix++) {
      *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
    }

    /* zero oldused digits, if the input a was larger than
     * m->used+1 we'll have to clear the digits
     */
    for (; ix < olduse; ix++) {
      *tmpx++ = 0;
    }
  }

  /* set the max used and clamp */
  x->used = n->used + 1;
  mp_clamp (x);

  /* if A >= m then A = A - m */
  if (mp_cmp_mag (x, n) != MP_LT) {
    return s_mp_sub (x, n, x);
  }
  return MP_OKAY;
}
586 
587 /* Fast (comba) multiplier
588  *
589  * This is the fast column-array [comba] multiplier.  It is
590  * designed to compute the columns of the product first
591  * then handle the carries afterwards.  This has the effect
592  * of making the nested loops that compute the columns very
593  * simple and schedulable on super-scalar processors.
594  *
595  * This has been modified to produce a variable number of
596  * digits of output so if say only a half-product is required
597  * you don't have to compute the upper half (a feature
598  * required for fast Barrett reduction).
599  *
600  * Based on Algorithm 14.12 on pp.595 of HAC.
601  *
602  */
603 static int
604 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
605 {
606   int     olduse, res, pa, ix, iz;
607   mp_digit W[MP_WARRAY];
608   register mp_word  _W;
609 
610   /* grow the destination as required */
611   if (c->alloc < digs) {
612     if ((res = mp_grow (c, digs)) != MP_OKAY) {
613       return res;
614     }
615   }
616 
617   /* number of output digits to produce */
618   pa = MIN(digs, a->used + b->used);
619 
620   /* clear the carry */
621   _W = 0;
622   for (ix = 0; ix <= pa; ix++) {
623       int      tx, ty;
624       int      iy;
625       mp_digit *tmpx, *tmpy;
626 
627       /* get offsets into the two bignums */
628       ty = MIN(b->used-1, ix);
629       tx = ix - ty;
630 
631       /* setup temp aliases */
632       tmpx = a->dp + tx;
633       tmpy = b->dp + ty;
634 
635       /* This is the number of times the loop will iterate, essentially it's
636          while (tx++ < a->used && ty-- >= 0) { ... }
637        */
638       iy = MIN(a->used-tx, ty+1);
639 
640       /* execute loop */
641       for (iz = 0; iz < iy; ++iz) {
642          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
643       }
644 
645       /* store term */
646       W[ix] = ((mp_digit)_W) & MP_MASK;
647 
648       /* make next carry */
649       _W = _W >> ((mp_word)DIGIT_BIT);
650   }
651 
652   /* setup dest */
653   olduse  = c->used;
654   c->used = digs;
655 
656   {
657     register mp_digit *tmpc;
658     tmpc = c->dp;
659     for (ix = 0; ix < digs; ix++) {
660       /* now extract the previous digit [below the carry] */
661       *tmpc++ = W[ix];
662     }
663 
664     /* clear unused digits [that existed in the old copy of c] */
665     for (; ix < olduse; ix++) {
666       *tmpc++ = 0;
667     }
668   }
669   mp_clamp (c);
670   return MP_OKAY;
671 }
672 
673 /* this is a modified version of fast_s_mul_digs that only produces
674  * output digits *above* digs.  See the comments for fast_s_mul_digs
675  * to see how it works.
676  *
677  * This is used in the Barrett reduction since for one of the multiplications
678  * only the higher digits were needed.  This essentially halves the work.
679  *
680  * Based on Algorithm 14.12 on pp.595 of HAC.
681  */
682 static int
683 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
684 {
685   int     olduse, res, pa, ix, iz;
686   mp_digit W[MP_WARRAY];
687   mp_word  _W;
688 
689   /* grow the destination as required */
690   pa = a->used + b->used;
691   if (c->alloc < pa) {
692     if ((res = mp_grow (c, pa)) != MP_OKAY) {
693       return res;
694     }
695   }
696 
697   /* number of output digits to produce */
698   pa = a->used + b->used;
699   _W = 0;
700   for (ix = digs; ix <= pa; ix++) {
701       int      tx, ty, iy;
702       mp_digit *tmpx, *tmpy;
703 
704       /* get offsets into the two bignums */
705       ty = MIN(b->used-1, ix);
706       tx = ix - ty;
707 
708       /* setup temp aliases */
709       tmpx = a->dp + tx;
710       tmpy = b->dp + ty;
711 
712       /* This is the number of times the loop will iterate, essentially it's
713          while (tx++ < a->used && ty-- >= 0) { ... }
714        */
715       iy = MIN(a->used-tx, ty+1);
716 
717       /* execute loop */
718       for (iz = 0; iz < iy; iz++) {
719          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
720       }
721 
722       /* store term */
723       W[ix] = ((mp_digit)_W) & MP_MASK;
724 
725       /* make next carry */
726       _W = _W >> ((mp_word)DIGIT_BIT);
727   }
728 
729   /* setup dest */
730   olduse  = c->used;
731   c->used = pa;
732 
733   {
734     register mp_digit *tmpc;
735 
736     tmpc = c->dp + digs;
737     for (ix = digs; ix <= pa; ix++) {
738       /* now extract the previous digit [below the carry] */
739       *tmpc++ = W[ix];
740     }
741 
742     /* clear unused digits [that existed in the old copy of c] */
743     for (; ix < olduse; ix++) {
744       *tmpc++ = 0;
745     }
746   }
747   mp_clamp (c);
748   return MP_OKAY;
749 }
750 
751 /* fast squaring
752  *
753  * This is the comba method where the columns of the product
754  * are computed first then the carries are computed.  This
755  * has the effect of making a very simple inner loop that
756  * is executed the most
757  *
758  * W2 represents the outer products and W the inner.
759  *
760  * A further optimizations is made because the inner
761  * products are of the form "A * B * 2".  The *2 part does
762  * not need to be computed until the end which is good
763  * because 64-bit shifts are slow!
764  *
765  * Based on Algorithm 14.16 on pp.597 of HAC.
766  *
767  */
768 /* the jist of squaring...
769 
770 you do like mult except the offset of the tmpx [one that starts closer to zero]
771 can't equal the offset of tmpy.  So basically you set up iy like before then you min it with
772 (ty-tx) so that it never happens.  You double all those you add in the inner loop
773 
774 After that loop you do the squares and add them in.
775 
776 Remove W2 and don't memset W
777 
778 */
779 
780 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
781 {
782   int       olduse, res, pa, ix, iz;
783   mp_digit   W[MP_WARRAY], *tmpx;
784   mp_word   W1;
785 
786   /* grow the destination as required */
787   pa = a->used + a->used;
788   if (b->alloc < pa) {
789     if ((res = mp_grow (b, pa)) != MP_OKAY) {
790       return res;
791     }
792   }
793 
794   /* number of output digits to produce */
795   W1 = 0;
796   for (ix = 0; ix <= pa; ix++) {
797       int      tx, ty, iy;
798       mp_word  _W;
799       mp_digit *tmpy;
800 
801       /* clear counter */
802       _W = 0;
803 
804       /* get offsets into the two bignums */
805       ty = MIN(a->used-1, ix);
806       tx = ix - ty;
807 
808       /* setup temp aliases */
809       tmpx = a->dp + tx;
810       tmpy = a->dp + ty;
811 
812       /* This is the number of times the loop will iterate, essentially it's
813          while (tx++ < a->used && ty-- >= 0) { ... }
814        */
815       iy = MIN(a->used-tx, ty+1);
816 
817       /* now for squaring tx can never equal ty
818        * we halve the distance since they approach at a rate of 2x
819        * and we have to round because odd cases need to be executed
820        */
821       iy = MIN(iy, (ty-tx+1)>>1);
822 
823       /* execute loop */
824       for (iz = 0; iz < iy; iz++) {
825          _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
826       }
827 
828       /* double the inner product and add carry */
829       _W = _W + _W + W1;
830 
831       /* even columns have the square term in them */
832       if ((ix&1) == 0) {
833          _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
834       }
835 
836       /* store it */
837       W[ix] = _W;
838 
839       /* make next carry */
840       W1 = _W >> ((mp_word)DIGIT_BIT);
841   }
842 
843   /* setup dest */
844   olduse  = b->used;
845   b->used = a->used+a->used;
846 
847   {
848     mp_digit *tmpb;
849     tmpb = b->dp;
850     for (ix = 0; ix < pa; ix++) {
851       *tmpb++ = W[ix] & MP_MASK;
852     }
853 
854     /* clear unused digits [that existed in the old copy of c] */
855     for (; ix < olduse; ix++) {
856       *tmpb++ = 0;
857     }
858   }
859   mp_clamp (b);
860   return MP_OKAY;
861 }
862 
863 /* computes a = 2**b
864  *
865  * Simple algorithm which zeroes the int, grows it then just sets one bit
866  * as required.
867  */
868 static int
869 mp_2expt (mp_int * a, int b)
870 {
871   int     res;
872 
873   /* zero a as per default */
874   mp_zero (a);
875 
876   /* grow a to accommodate the single bit */
877   if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
878     return res;
879   }
880 
881   /* set the used count of where the bit will go */
882   a->used = b / DIGIT_BIT + 1;
883 
884   /* put the single bit in its place */
885   a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
886 
887   return MP_OKAY;
888 }
889 
890 /* high level addition (handles signs) */
891 int mp_add (mp_int * a, mp_int * b, mp_int * c)
892 {
893   int     sa, sb, res;
894 
895   /* get sign of both inputs */
896   sa = a->sign;
897   sb = b->sign;
898 
899   /* handle two cases, not four */
900   if (sa == sb) {
901     /* both positive or both negative */
902     /* add their magnitudes, copy the sign */
903     c->sign = sa;
904     res = s_mp_add (a, b, c);
905   } else {
906     /* one positive, the other negative */
907     /* subtract the one with the greater magnitude from */
908     /* the one of the lesser magnitude.  The result gets */
909     /* the sign of the one with the greater magnitude. */
910     if (mp_cmp_mag (a, b) == MP_LT) {
911       c->sign = sb;
912       res = s_mp_sub (b, a, c);
913     } else {
914       c->sign = sa;
915       res = s_mp_sub (a, b, c);
916     }
917   }
918   return res;
919 }
920 
921 
922 /* single digit addition */
923 static int
924 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
925 {
926   int     res, ix, oldused;
927   mp_digit *tmpa, *tmpc, mu;
928 
929   /* grow c as required */
930   if (c->alloc < a->used + 1) {
931      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
932         return res;
933      }
934   }
935 
936   /* if a is negative and |a| >= b, call c = |a| - b */
937   if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
938      /* temporarily fix sign of a */
939      a->sign = MP_ZPOS;
940 
941      /* c = |a| - b */
942      res = mp_sub_d(a, b, c);
943 
944      /* fix sign  */
945      a->sign = c->sign = MP_NEG;
946 
947      return res;
948   }
949 
950   /* old number of used digits in c */
951   oldused = c->used;
952 
953   /* sign always positive */
954   c->sign = MP_ZPOS;
955 
956   /* source alias */
957   tmpa    = a->dp;
958 
959   /* destination alias */
960   tmpc    = c->dp;
961 
962   /* if a is positive */
963   if (a->sign == MP_ZPOS) {
964      /* add digit, after this we're propagating
965       * the carry.
966       */
967      *tmpc   = *tmpa++ + b;
968      mu      = *tmpc >> DIGIT_BIT;
969      *tmpc++ &= MP_MASK;
970 
971      /* now handle rest of the digits */
972      for (ix = 1; ix < a->used; ix++) {
973         *tmpc   = *tmpa++ + mu;
974         mu      = *tmpc >> DIGIT_BIT;
975         *tmpc++ &= MP_MASK;
976      }
977      /* set final carry */
978      ix++;
979      *tmpc++  = mu;
980 
981      /* setup size */
982      c->used = a->used + 1;
983   } else {
984      /* a was negative and |a| < b */
985      c->used  = 1;
986 
987      /* the result is a single digit */
988      if (a->used == 1) {
989         *tmpc++  =  b - a->dp[0];
990      } else {
991         *tmpc++  =  b;
992      }
993 
994      /* setup count so the clearing of oldused
995       * can fall through correctly
996       */
997      ix       = 1;
998   }
999 
1000   /* now zero to oldused */
1001   while (ix++ < oldused) {
1002      *tmpc++ = 0;
1003   }
1004   mp_clamp(c);
1005 
1006   return MP_OKAY;
1007 }
1008 
1009 /* trim unused digits
1010  *
1011  * This is used to ensure that leading zero digits are
1012  * trimed and the leading "used" digit will be non-zero
1013  * Typically very fast.  Also fixes the sign if there
1014  * are no more leading digits
1015  */
1016 void
1017 mp_clamp (mp_int * a)
1018 {
1019   /* decrease used while the most significant digit is
1020    * zero.
1021    */
1022   while (a->used > 0 && a->dp[a->used - 1] == 0) {
1023     --(a->used);
1024   }
1025 
1026   /* reset the sign flag if used == 0 */
1027   if (a->used == 0) {
1028     a->sign = MP_ZPOS;
1029   }
1030 }
1031 
1032 void mp_clear_multi(mp_int *mp, ...)
1033 {
1034     mp_int* next_mp = mp;
1035     va_list args;
1036     va_start(args, mp);
1037     while (next_mp != NULL) {
1038         mp_clear(next_mp);
1039         next_mp = va_arg(args, mp_int*);
1040     }
1041     va_end(args);
1042 }
1043 
1044 /* compare two ints (signed)*/
1045 int
1046 mp_cmp (const mp_int * a, const mp_int * b)
1047 {
1048   /* compare based on sign */
1049   if (a->sign != b->sign) {
1050      if (a->sign == MP_NEG) {
1051         return MP_LT;
1052      } else {
1053         return MP_GT;
1054      }
1055   }
1056 
1057   /* compare digits */
1058   if (a->sign == MP_NEG) {
1059      /* if negative compare opposite direction */
1060      return mp_cmp_mag(b, a);
1061   } else {
1062      return mp_cmp_mag(a, b);
1063   }
1064 }
1065 
1066 /* compare a digit */
1067 int mp_cmp_d(const mp_int * a, mp_digit b)
1068 {
1069   /* compare based on sign */
1070   if (a->sign == MP_NEG) {
1071     return MP_LT;
1072   }
1073 
1074   /* compare based on magnitude */
1075   if (a->used > 1) {
1076     return MP_GT;
1077   }
1078 
1079   /* compare the only digit of a to b */
1080   if (a->dp[0] > b) {
1081     return MP_GT;
1082   } else if (a->dp[0] < b) {
1083     return MP_LT;
1084   } else {
1085     return MP_EQ;
1086   }
1087 }
1088 
1089 /* compare maginitude of two ints (unsigned) */
1090 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1091 {
1092   int     n;
1093   mp_digit *tmpa, *tmpb;
1094 
1095   /* compare based on # of non-zero digits */
1096   if (a->used > b->used) {
1097     return MP_GT;
1098   }
1099 
1100   if (a->used < b->used) {
1101     return MP_LT;
1102   }
1103 
1104   /* alias for a */
1105   tmpa = a->dp + (a->used - 1);
1106 
1107   /* alias for b */
1108   tmpb = b->dp + (a->used - 1);
1109 
1110   /* compare based on digits  */
1111   for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1112     if (*tmpa > *tmpb) {
1113       return MP_GT;
1114     }
1115 
1116     if (*tmpa < *tmpb) {
1117       return MP_LT;
1118     }
1119   }
1120   return MP_EQ;
1121 }
1122 
1123 static const int lnz[16] = {
1124    4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1125 };
1126 
1127 /* Counts the number of lsbs which are zero before the first zero bit */
1128 int mp_cnt_lsb(const mp_int *a)
1129 {
1130    int x;
1131    mp_digit q, qq;
1132 
1133    /* easy out */
1134    if (mp_iszero(a) == 1) {
1135       return 0;
1136    }
1137 
1138    /* scan lower digits until non-zero */
1139    for (x = 0; x < a->used && a->dp[x] == 0; x++);
1140    q = a->dp[x];
1141    x *= DIGIT_BIT;
1142 
1143    /* now scan this digit until a 1 is found */
1144    if ((q & 1) == 0) {
1145       do {
1146          qq  = q & 15;
1147          x  += lnz[qq];
1148          q >>= 4;
1149       } while (qq == 0);
1150    }
1151    return x;
1152 }
1153 
1154 /* copy, b = a */
1155 int
1156 mp_copy (const mp_int * a, mp_int * b)
1157 {
1158   int     res, n;
1159 
1160   /* if dst == src do nothing */
1161   if (a == b) {
1162     return MP_OKAY;
1163   }
1164 
1165   /* grow dest */
1166   if (b->alloc < a->used) {
1167      if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1168         return res;
1169      }
1170   }
1171 
1172   /* zero b and copy the parameters over */
1173   {
1174     register mp_digit *tmpa, *tmpb;
1175 
1176     /* pointer aliases */
1177 
1178     /* source */
1179     tmpa = a->dp;
1180 
1181     /* destination */
1182     tmpb = b->dp;
1183 
1184     /* copy all the digits */
1185     for (n = 0; n < a->used; n++) {
1186       *tmpb++ = *tmpa++;
1187     }
1188 
1189     /* clear high digits */
1190     for (; n < b->used; n++) {
1191       *tmpb++ = 0;
1192     }
1193   }
1194 
1195   /* copy used count and sign */
1196   b->used = a->used;
1197   b->sign = a->sign;
1198   return MP_OKAY;
1199 }
1200 
1201 /* returns the number of bits in an int */
1202 int
1203 mp_count_bits (const mp_int * a)
1204 {
1205   int     r;
1206   mp_digit q;
1207 
1208   /* shortcut */
1209   if (a->used == 0) {
1210     return 0;
1211   }
1212 
1213   /* get number of digits and add that */
1214   r = (a->used - 1) * DIGIT_BIT;
1215 
1216   /* take the last digit and count the bits in it */
1217   q = a->dp[a->used - 1];
1218   while (q > 0) {
1219     ++r;
1220     q >>= ((mp_digit) 1);
1221   }
1222   return r;
1223 }
1224 
1225 /* calc a value mod 2**b */
1226 static int
1227 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1228 {
1229   int     x, res;
1230 
1231   /* if b is <= 0 then zero the int */
1232   if (b <= 0) {
1233     mp_zero (c);
1234     return MP_OKAY;
1235   }
1236 
1237   /* if the modulus is larger than the value than return */
1238   if (b > a->used * DIGIT_BIT) {
1239     res = mp_copy (a, c);
1240     return res;
1241   }
1242 
1243   /* copy */
1244   if ((res = mp_copy (a, c)) != MP_OKAY) {
1245     return res;
1246   }
1247 
1248   /* zero digits above the last digit of the modulus */
1249   for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1250     c->dp[x] = 0;
1251   }
1252   /* clear the digit that is not completely outside/inside the modulus */
1253   c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1254   mp_clamp (c);
1255   return MP_OKAY;
1256 }
1257 
1258 /* shift right a certain amount of digits */
1259 static void mp_rshd (mp_int * a, int b)
1260 {
1261   int     x;
1262 
1263   /* if b <= 0 then ignore it */
1264   if (b <= 0) {
1265     return;
1266   }
1267 
1268   /* if b > used then simply zero it and return */
1269   if (a->used <= b) {
1270     mp_zero (a);
1271     return;
1272   }
1273 
1274   {
1275     register mp_digit *bottom, *top;
1276 
1277     /* shift the digits down */
1278 
1279     /* bottom */
1280     bottom = a->dp;
1281 
1282     /* top [offset into digits] */
1283     top = a->dp + b;
1284 
1285     /* this is implemented as a sliding window where
1286      * the window is b-digits long and digits from
1287      * the top of the window are copied to the bottom
1288      *
1289      * e.g.
1290 
1291      b-2 | b-1 | b0 | b1 | b2 | ... | bb |   ---->
1292                  /\                   |      ---->
1293                   \-------------------/      ---->
1294      */
1295     for (x = 0; x < (a->used - b); x++) {
1296       *bottom++ = *top++;
1297     }
1298 
1299     /* zero the top digits */
1300     for (; x < a->used; x++) {
1301       *bottom++ = 0;
1302     }
1303   }
1304 
1305   /* remove excess digits */
1306   a->used -= b;
1307 }
1308 
1309 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1310 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1311 {
1312   mp_digit D, r, rr;
1313   int     x, res;
1314   mp_int  t;
1315 
1316 
1317   /* if the shift count is <= 0 then we do no work */
1318   if (b <= 0) {
1319     res = mp_copy (a, c);
1320     if (d != NULL) {
1321       mp_zero (d);
1322     }
1323     return res;
1324   }
1325 
1326   if ((res = mp_init (&t)) != MP_OKAY) {
1327     return res;
1328   }
1329 
1330   /* get the remainder */
1331   if (d != NULL) {
1332     if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1333       mp_clear (&t);
1334       return res;
1335     }
1336   }
1337 
1338   /* copy */
1339   if ((res = mp_copy (a, c)) != MP_OKAY) {
1340     mp_clear (&t);
1341     return res;
1342   }
1343 
1344   /* shift by as many digits in the bit count */
1345   if (b >= DIGIT_BIT) {
1346     mp_rshd (c, b / DIGIT_BIT);
1347   }
1348 
1349   /* shift any bit count < DIGIT_BIT */
1350   D = (mp_digit) (b % DIGIT_BIT);
1351   if (D != 0) {
1352     register mp_digit *tmpc, mask, shift;
1353 
1354     /* mask */
1355     mask = (((mp_digit)1) << D) - 1;
1356 
1357     /* shift for lsb */
1358     shift = DIGIT_BIT - D;
1359 
1360     /* alias */
1361     tmpc = c->dp + (c->used - 1);
1362 
1363     /* carry */
1364     r = 0;
1365     for (x = c->used - 1; x >= 0; x--) {
1366       /* get the lower  bits of this word in a temp */
1367       rr = *tmpc & mask;
1368 
1369       /* shift the current word and mix in the carry bits from the previous word */
1370       *tmpc = (*tmpc >> D) | (r << shift);
1371       --tmpc;
1372 
1373       /* set the carry to the carry bits of the current word found above */
1374       r = rr;
1375     }
1376   }
1377   mp_clamp (c);
1378   if (d != NULL) {
1379     mp_exch (&t, d);
1380   }
1381   mp_clear (&t);
1382   return MP_OKAY;
1383 }
1384 
1385 /* shift left a certain amount of digits */
1386 static int mp_lshd (mp_int * a, int b)
1387 {
1388   int     x, res;
1389 
1390   /* if it's less than zero return */
1391   if (b <= 0) {
1392     return MP_OKAY;
1393   }
1394 
1395   /* grow to fit the new digits */
1396   if (a->alloc < a->used + b) {
1397      if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1398        return res;
1399      }
1400   }
1401 
1402   {
1403     register mp_digit *top, *bottom;
1404 
1405     /* increment the used by the shift amount then copy upwards */
1406     a->used += b;
1407 
1408     /* top */
1409     top = a->dp + a->used - 1;
1410 
1411     /* base */
1412     bottom = a->dp + a->used - 1 - b;
1413 
1414     /* much like mp_rshd this is implemented using a sliding window
1415      * except the window goes the other way around.  Copying from
1416      * the bottom to the top.  see bn_mp_rshd.c for more info.
1417      */
1418     for (x = a->used - 1; x >= b; x--) {
1419       *top-- = *bottom--;
1420     }
1421 
1422     /* zero the lower digits */
1423     top = a->dp;
1424     for (x = 0; x < b; x++) {
1425       *top++ = 0;
1426     }
1427   }
1428   return MP_OKAY;
1429 }
1430 
1431 /* shift left by a certain bit count */
1432 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1433 {
1434   mp_digit d;
1435   int      res;
1436 
1437   /* copy */
1438   if (a != c) {
1439      if ((res = mp_copy (a, c)) != MP_OKAY) {
1440        return res;
1441      }
1442   }
1443 
1444   if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1445      if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1446        return res;
1447      }
1448   }
1449 
1450   /* shift by as many digits in the bit count */
1451   if (b >= DIGIT_BIT) {
1452     if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1453       return res;
1454     }
1455   }
1456 
1457   /* shift any bit count < DIGIT_BIT */
1458   d = (mp_digit) (b % DIGIT_BIT);
1459   if (d != 0) {
1460     register mp_digit *tmpc, shift, mask, r, rr;
1461     register int x;
1462 
1463     /* bitmask for carries */
1464     mask = (((mp_digit)1) << d) - 1;
1465 
1466     /* shift for msbs */
1467     shift = DIGIT_BIT - d;
1468 
1469     /* alias */
1470     tmpc = c->dp;
1471 
1472     /* carry */
1473     r    = 0;
1474     for (x = 0; x < c->used; x++) {
1475       /* get the higher bits of the current word */
1476       rr = (*tmpc >> shift) & mask;
1477 
1478       /* shift the current word and OR in the carry */
1479       *tmpc = ((*tmpc << d) | r) & MP_MASK;
1480       ++tmpc;
1481 
1482       /* set the carry to the carry bits of the current word */
1483       r = rr;
1484     }
1485 
1486     /* set final carry */
1487     if (r != 0) {
1488        c->dp[(c->used)++] = r;
1489     }
1490   }
1491   mp_clamp (c);
1492   return MP_OKAY;
1493 }
1494 
1495 /* multiply by a digit */
1496 static int
1497 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1498 {
1499   mp_digit u, *tmpa, *tmpc;
1500   mp_word  r;
1501   int      ix, res, olduse;
1502 
1503   /* make sure c is big enough to hold a*b */
1504   if (c->alloc < a->used + 1) {
1505     if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1506       return res;
1507     }
1508   }
1509 
1510   /* get the original destinations used count */
1511   olduse = c->used;
1512 
1513   /* set the sign */
1514   c->sign = a->sign;
1515 
1516   /* alias for a->dp [source] */
1517   tmpa = a->dp;
1518 
1519   /* alias for c->dp [dest] */
1520   tmpc = c->dp;
1521 
1522   /* zero carry */
1523   u = 0;
1524 
1525   /* compute columns */
1526   for (ix = 0; ix < a->used; ix++) {
1527     /* compute product and carry sum for this term */
1528     r       = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1529 
1530     /* mask off higher bits to get a single digit */
1531     *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1532 
1533     /* send carry into next iteration */
1534     u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1535   }
1536 
1537   /* store final carry [if any] */
1538   *tmpc++ = u;
1539 
1540   /* now zero digits above the top */
1541   while (ix++ < olduse) {
1542      *tmpc++ = 0;
1543   }
1544 
1545   /* set used count */
1546   c->used = a->used + 1;
1547   mp_clamp(c);
1548 
1549   return MP_OKAY;
1550 }
1551 
1552 /* integer signed division.
1553  * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1554  * HAC pp.598 Algorithm 14.20
1555  *
1556  * Note that the description in HAC is horribly
1557  * incomplete.  For example, it doesn't consider
1558  * the case where digits are removed from 'x' in
1559  * the inner loop.  It also doesn't consider the
1560  * case that y has fewer than three digits, etc..
1561  *
1562  * The overall algorithm is as described as
1563  * 14.20 from HAC but fixed to treat these cases.
1564 */
1565 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1566 {
1567   mp_int  q, x, y, t1, t2;
1568   int     res, n, t, i, norm, neg;
1569 
1570   /* is divisor zero ? */
1571   if (mp_iszero (b) == 1) {
1572     return MP_VAL;
1573   }
1574 
1575   /* if a < b then q=0, r = a */
1576   if (mp_cmp_mag (a, b) == MP_LT) {
1577     if (d != NULL) {
1578       res = mp_copy (a, d);
1579     } else {
1580       res = MP_OKAY;
1581     }
1582     if (c != NULL) {
1583       mp_zero (c);
1584     }
1585     return res;
1586   }
1587 
1588   if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1589     return res;
1590   }
1591   q.used = a->used + 2;
1592 
1593   if ((res = mp_init (&t1)) != MP_OKAY) {
1594     goto __Q;
1595   }
1596 
1597   if ((res = mp_init (&t2)) != MP_OKAY) {
1598     goto __T1;
1599   }
1600 
1601   if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1602     goto __T2;
1603   }
1604 
1605   if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1606     goto __X;
1607   }
1608 
1609   /* fix the sign */
1610   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1611   x.sign = y.sign = MP_ZPOS;
1612 
1613   /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1614   norm = mp_count_bits(&y) % DIGIT_BIT;
1615   if (norm < DIGIT_BIT-1) {
1616      norm = (DIGIT_BIT-1) - norm;
1617      if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1618        goto __Y;
1619      }
1620      if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1621        goto __Y;
1622      }
1623   } else {
1624      norm = 0;
1625   }
1626 
1627   /* note hac does 0 based, so if used==5 then it's 0,1,2,3,4, e.g. use 4 */
1628   n = x.used - 1;
1629   t = y.used - 1;
1630 
1631   /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1632   if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1633     goto __Y;
1634   }
1635 
1636   while (mp_cmp (&x, &y) != MP_LT) {
1637     ++(q.dp[n - t]);
1638     if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1639       goto __Y;
1640     }
1641   }
1642 
1643   /* reset y by shifting it back down */
1644   mp_rshd (&y, n - t);
1645 
1646   /* step 3. for i from n down to (t + 1) */
1647   for (i = n; i >= (t + 1); i--) {
1648     if (i > x.used) {
1649       continue;
1650     }
1651 
1652     /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1653      * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1654     if (x.dp[i] == y.dp[t]) {
1655       q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1656     } else {
1657       mp_word tmp;
1658       tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1659       tmp |= ((mp_word) x.dp[i - 1]);
1660       tmp /= ((mp_word) y.dp[t]);
1661       if (tmp > (mp_word) MP_MASK)
1662         tmp = MP_MASK;
1663       q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1664     }
1665 
1666     /* while (q{i-t-1} * (yt * b + y{t-1})) >
1667              xi * b**2 + xi-1 * b + xi-2
1668 
1669        do q{i-t-1} -= 1;
1670     */
1671     q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1672     do {
1673       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1674 
1675       /* find left hand */
1676       mp_zero (&t1);
1677       t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1678       t1.dp[1] = y.dp[t];
1679       t1.used = 2;
1680       if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1681         goto __Y;
1682       }
1683 
1684       /* find right hand */
1685       t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1686       t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1687       t2.dp[2] = x.dp[i];
1688       t2.used = 3;
1689     } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1690 
1691     /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1692     if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1693       goto __Y;
1694     }
1695 
1696     if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1697       goto __Y;
1698     }
1699 
1700     if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1701       goto __Y;
1702     }
1703 
1704     /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1705     if (x.sign == MP_NEG) {
1706       if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1707         goto __Y;
1708       }
1709       if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1710         goto __Y;
1711       }
1712       if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1713         goto __Y;
1714       }
1715 
1716       q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1717     }
1718   }
1719 
1720   /* now q is the quotient and x is the remainder
1721    * [which we have to normalize]
1722    */
1723 
1724   /* get sign before writing to c */
1725   x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1726 
1727   if (c != NULL) {
1728     mp_clamp (&q);
1729     mp_exch (&q, c);
1730     c->sign = neg;
1731   }
1732 
1733   if (d != NULL) {
1734     mp_div_2d (&x, norm, &x, NULL);
1735     mp_exch (&x, d);
1736   }
1737 
1738   res = MP_OKAY;
1739 
1740 __Y:mp_clear (&y);
1741 __X:mp_clear (&x);
1742 __T2:mp_clear (&t2);
1743 __T1:mp_clear (&t1);
1744 __Q:mp_clear (&q);
1745   return res;
1746 }
1747 
1748 static BOOL s_is_power_of_two(mp_digit b, int *p)
1749 {
1750    int x;
1751 
1752    for (x = 1; x < DIGIT_BIT; x++) {
1753       if (b == (((mp_digit)1)<<x)) {
1754          *p = x;
1755          return TRUE;
1756       }
1757    }
1758    return FALSE;
1759 }
1760 
1761 /* single digit division (based on routine from MPI) */
1762 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1763 {
1764   mp_int  q;
1765   mp_word w;
1766   mp_digit t;
1767   int     res, ix;
1768 
1769   /* cannot divide by zero */
1770   if (b == 0) {
1771      return MP_VAL;
1772   }
1773 
1774   /* quick outs */
1775   if (b == 1 || mp_iszero(a) == 1) {
1776      if (d != NULL) {
1777         *d = 0;
1778      }
1779      if (c != NULL) {
1780         return mp_copy(a, c);
1781      }
1782      return MP_OKAY;
1783   }
1784 
1785   /* power of two ? */
1786   if (s_is_power_of_two(b, &ix)) {
1787      if (d != NULL) {
1788         *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1789      }
1790      if (c != NULL) {
1791         return mp_div_2d(a, ix, c, NULL);
1792      }
1793      return MP_OKAY;
1794   }
1795 
1796   /* no easy answer [c'est la vie].  Just division */
1797   if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1798      return res;
1799   }
1800 
1801   q.used = a->used;
1802   q.sign = a->sign;
1803   w = 0;
1804   for (ix = a->used - 1; ix >= 0; ix--) {
1805      w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1806 
1807      if (w >= b) {
1808         t = (mp_digit)(w / b);
1809         w -= ((mp_word)t) * ((mp_word)b);
1810       } else {
1811         t = 0;
1812       }
1813       q.dp[ix] = t;
1814   }
1815 
1816   if (d != NULL) {
1817      *d = (mp_digit)w;
1818   }
1819 
1820   if (c != NULL) {
1821      mp_clamp(&q);
1822      mp_exch(&q, c);
1823   }
1824   mp_clear(&q);
1825 
1826   return res;
1827 }
1828 
1829 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1830  *
1831  * Based on algorithm from the paper
1832  *
1833  * "Generating Efficient Primes for Discrete Log Cryptosystems"
1834  *                 Chae Hoon Lim, Pil Loong Lee,
1835  *          POSTECH Information Research Laboratories
1836  *
1837  * The modulus must be of a special format [see manual]
1838  *
1839  * Has been modified to use algorithm 7.10 from the LTM book instead
1840  *
1841  * Input x must be in the range 0 <= x <= (n-1)**2
1842  */
1843 static int
1844 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1845 {
1846   int      err, i, m;
1847   mp_word  r;
1848   mp_digit mu, *tmpx1, *tmpx2;
1849 
1850   /* m = digits in modulus */
1851   m = n->used;
1852 
1853   /* ensure that "x" has at least 2m digits */
1854   if (x->alloc < m + m) {
1855     if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1856       return err;
1857     }
1858   }
1859 
1860 /* top of loop, this is where the code resumes if
1861  * another reduction pass is required.
1862  */
1863 top:
1864   /* aliases for digits */
1865   /* alias for lower half of x */
1866   tmpx1 = x->dp;
1867 
1868   /* alias for upper half of x, or x/B**m */
1869   tmpx2 = x->dp + m;
1870 
1871   /* set carry to zero */
1872   mu = 0;
1873 
1874   /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1875   for (i = 0; i < m; i++) {
1876       r         = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1877       *tmpx1++  = (mp_digit)(r & MP_MASK);
1878       mu        = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1879   }
1880 
1881   /* set final carry */
1882   *tmpx1++ = mu;
1883 
1884   /* zero words above m */
1885   for (i = m + 1; i < x->used; i++) {
1886       *tmpx1++ = 0;
1887   }
1888 
1889   /* clamp, sub and return */
1890   mp_clamp (x);
1891 
1892   /* if x >= n then subtract and reduce again
1893    * Each successive "recursion" makes the input smaller and smaller.
1894    */
1895   if (mp_cmp_mag (x, n) != MP_LT) {
1896     s_mp_sub(x, n, x);
1897     goto top;
1898   }
1899   return MP_OKAY;
1900 }
1901 
1902 /* sets the value of "d" required for mp_dr_reduce */
1903 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1904 {
1905    /* the casts are required if DIGIT_BIT is one less than
1906     * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1907     */
1908    *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
1909         ((mp_word)a->dp[0]));
1910 }
1911 
1912 /* this is a shell function that calls either the normal or Montgomery
1913  * exptmod functions.  Originally the call to the montgomery code was
1914  * embedded in the normal function but that wasted a lot of stack space
1915  * for nothing (since 99% of the time the Montgomery code would be called)
1916  */
1917 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1918 {
1919   int dr;
1920 
1921   /* modulus P must be positive */
1922   if (P->sign == MP_NEG) {
1923      return MP_VAL;
1924   }
1925 
1926   /* if exponent X is negative we have to recurse */
1927   if (X->sign == MP_NEG) {
1928      mp_int tmpG, tmpX;
1929      int err;
1930 
1931      /* first compute 1/G mod P */
1932      if ((err = mp_init(&tmpG)) != MP_OKAY) {
1933         return err;
1934      }
1935      if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1936         mp_clear(&tmpG);
1937         return err;
1938      }
1939 
1940      /* now get |X| */
1941      if ((err = mp_init(&tmpX)) != MP_OKAY) {
1942         mp_clear(&tmpG);
1943         return err;
1944      }
1945      if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1946         mp_clear_multi(&tmpG, &tmpX, NULL);
1947         return err;
1948      }
1949 
1950      /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1951      err = mp_exptmod(&tmpG, &tmpX, P, Y);
1952      mp_clear_multi(&tmpG, &tmpX, NULL);
1953      return err;
1954   }
1955 
1956   dr = 0;
1957 
1958   /* if the modulus is odd use the fast method */
1959   if (mp_isodd (P) == 1) {
1960     return mp_exptmod_fast (G, X, P, Y, dr);
1961   } else {
1962     /* otherwise use the generic Barrett reduction technique */
1963     return s_mp_exptmod (G, X, P, Y);
1964   }
1965 }
1966 
/* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
 *
 * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
 * The value of k changes based on the size of the exponent.
 *
 * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
 *
 * redmode selects the reduction: 0 = Montgomery, 1 = Diminished Radix
 * (mp_dr_reduce), any other value = mp_reduce_2k.
 * Returns MP_OKAY or the first error reported by a helper; Y is only
 * written (via mp_exch) on success.
 */

int
mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
{
  mp_int  M[256], res;
  mp_digit buf, mp;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;

  /* use a pointer to the reduction algorithm.  This allows us to use
   * one of many reduction algorithms without modding the guts of
   * the code with if statements everywhere.
   */
  int     (*redux)(mp_int*,const mp_int*,mp_digit);

  /* find window size (larger exponents use wider windows) */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

  /* init M array */
  /* init first cell */
  if ((err = mp_init(&M[1])) != MP_OKAY) {
     return err;
  }

  /* now init the second half of the array; on failure release whatever
   * was already initialized */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init(&M[x])) != MP_OKAY) {
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);
      return err;
    }
  }

  /* determine and setup reduction code */
  if (redmode == 0) {
     /* now setup montgomery  */
     if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
        goto __M;
     }

     /* automatically pick the comba one if available (saves quite a few calls/ifs).
      * NOTE(review): the size checks presumably keep the comba routine's
      * column accumulators within MP_WARRAY and from overflowing mp_word. */
     if (((P->used * 2 + 1) < MP_WARRAY) &&
          P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
        redux = fast_mp_montgomery_reduce;
     } else {
        /* use slower baseline Montgomery method */
        redux = mp_montgomery_reduce;
     }
  } else if (redmode == 1) {
     /* setup DR reduction for moduli of the form B**k - b */
     mp_dr_setup(P, &mp);
     redux = mp_dr_reduce;
  } else {
     /* setup mp_reduce_2k reduction for moduli of the form 2**k - b */
     if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
        goto __M;
     }
     redux = mp_reduce_2k;
  }

  /* setup result */
  if ((err = mp_init (&res)) != MP_OKAY) {
    goto __M;
  }

  /* create M table
   *
   * M[x] holds G**x mod P (kept in Montgomery form when redmode == 0).
   *
   * The first half of the table is not computed though, except for M[1];
   * M[0] is never initialized or used.
   */

  if (redmode == 0) {
     /* now we need R mod m; res starts as this value, which is 1 in
      * Montgomery form */
     if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
       goto __RES;
     }

     /* now set M[1] to G * R mod m */
     if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
       goto __RES;
     }
  } else {
     /* plain representation: accumulator starts at 1, M[1] = G mod P */
     mp_set(&res, 1);
     if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
        goto __RES;
     }
  }

  /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
  if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
    goto __RES;
  }

  for (x = 0; x < (winsize - 1); x++) {
    if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
      goto __RES;
    }
    if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
      goto __RES;
    }
  }

  /* create upper table: M[x] = M[x-1] * M[1] */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto __RES;
    }
    if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
      goto __RES;
    }
  }

  /* set initial mode and bit cnt
   * mode 0: still skipping leading zero bits of X
   * mode 1: no window open, zero bits just square
   * mode 2: collecting bits into the current window (bitbuf/bitcpy)
   */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits so break */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (buf >> (DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto __RES;
      }
      continue;
    }

    /* else we add it to the window */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto __RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto __RES;
        }
      }

      /* then multiply */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto __RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = redux (&res, P, mp)) != MP_OKAY) {
        goto __RES;
      }

      /* get next bit of the window */
      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto __RES;
        }
        if ((err = redux (&res, P, mp)) != MP_OKAY) {
          goto __RES;
        }
      }
    }
  }

  if (redmode == 0) {
     /* fixup result if Montgomery reduction is used
      * recall that any value in a Montgomery system is
      * actually multiplied by R mod n.  So we have
      * to reduce one more time to cancel out the factor
      * of R.
      */
     if ((err = redux(&res, P, mp)) != MP_OKAY) {
       goto __RES;
     }
  }

  /* swap res with Y */
  mp_exch (&res, Y);
  err = MP_OKAY;
__RES:mp_clear (&res);
__M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }
  return err;
}
2226 
2227 /* Greatest Common Divisor using the binary method */
2228 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2229 {
2230   mp_int  u, v;
2231   int     k, u_lsb, v_lsb, res;
2232 
2233   /* either zero than gcd is the largest */
2234   if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2235     return mp_abs (b, c);
2236   }
2237   if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2238     return mp_abs (a, c);
2239   }
2240 
2241   /* optimized.  At this point if a == 0 then
2242    * b must equal zero too
2243    */
2244   if (mp_iszero (a) == 1) {
2245     mp_zero(c);
2246     return MP_OKAY;
2247   }
2248 
2249   /* get copies of a and b we can modify */
2250   if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2251     return res;
2252   }
2253 
2254   if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2255     goto __U;
2256   }
2257 
2258   /* must be positive for the remainder of the algorithm */
2259   u.sign = v.sign = MP_ZPOS;
2260 
2261   /* B1.  Find the common power of two for u and v */
2262   u_lsb = mp_cnt_lsb(&u);
2263   v_lsb = mp_cnt_lsb(&v);
2264   k     = MIN(u_lsb, v_lsb);
2265 
2266   if (k > 0) {
2267      /* divide the power of two out */
2268      if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2269         goto __V;
2270      }
2271 
2272      if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2273         goto __V;
2274      }
2275   }
2276 
2277   /* divide any remaining factors of two out */
2278   if (u_lsb != k) {
2279      if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2280         goto __V;
2281      }
2282   }
2283 
2284   if (v_lsb != k) {
2285      if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2286         goto __V;
2287      }
2288   }
2289 
2290   while (mp_iszero(&v) == 0) {
2291      /* make sure v is the largest */
2292      if (mp_cmp_mag(&u, &v) == MP_GT) {
2293         /* swap u and v to make sure v is >= u */
2294         mp_exch(&u, &v);
2295      }
2296 
2297      /* subtract smallest from largest */
2298      if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2299         goto __V;
2300      }
2301 
2302      /* Divide out all factors of two */
2303      if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2304         goto __V;
2305      }
2306   }
2307 
2308   /* multiply by 2**k which we divided out at the beginning */
2309   if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2310      goto __V;
2311   }
2312   c->sign = MP_ZPOS;
2313   res = MP_OKAY;
2314 __V:mp_clear (&u);
2315 __U:mp_clear (&v);
2316   return res;
2317 }
2318 
2319 /* get the lower 32-bits of an mp_int */
2320 unsigned long mp_get_int(const mp_int * a)
2321 {
2322   int i;
2323   unsigned long res;
2324 
2325   if (a->used == 0) {
2326      return 0;
2327   }
2328 
2329   /* get number of digits of the lsb we have to read */
2330   i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2331 
2332   /* get most significant digit of result */
2333   res = DIGIT(a,i);
2334 
2335   while (--i >= 0) {
2336     res = (res << DIGIT_BIT) | DIGIT(a,i);
2337   }
2338 
2339   /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2340   return res & 0xFFFFFFFFUL;
2341 }
2342 
2343 /* creates "a" then copies b into it */
2344 int mp_init_copy (mp_int * a, const mp_int * b)
2345 {
2346   int     res;
2347 
2348   if ((res = mp_init (a)) != MP_OKAY) {
2349     return res;
2350   }
2351   return mp_copy (b, a);
2352 }
2353 
2354 int mp_init_multi(mp_int *mp, ...)
2355 {
2356     mp_err res = MP_OKAY;      /* Assume ok until proven otherwise */
2357     int n = 0;                 /* Number of ok inits */
2358     mp_int* cur_arg = mp;
2359     va_list args;
2360 
2361     va_start(args, mp);        /* init args to next argument from caller */
2362     while (cur_arg != NULL) {
2363         if (mp_init(cur_arg) != MP_OKAY) {
2364             /* Oops - error! Back-track and mp_clear what we already
2365                succeeded in init-ing, then return error.
2366             */
2367             va_list clean_args;
2368 
2369             /* end the current list */
2370             va_end(args);
2371 
2372             /* now start cleaning up */
2373             cur_arg = mp;
2374             va_start(clean_args, mp);
2375             while (n--) {
2376                 mp_clear(cur_arg);
2377                 cur_arg = va_arg(clean_args, mp_int*);
2378             }
2379             va_end(clean_args);
2380             res = MP_MEM;
2381             break;
2382         }
2383         n++;
2384         cur_arg = va_arg(args, mp_int*);
2385     }
2386     va_end(args);
2387     return res;                /* Assumed ok, if error flagged above. */
2388 }
2389 
2390 /* hac 14.61, pp608 */
2391 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2392 {
2393   /* b cannot be negative */
2394   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2395     return MP_VAL;
2396   }
2397 
2398   /* if the modulus is odd we can use a faster routine instead */
2399   if (mp_isodd (b) == 1) {
2400     return fast_mp_invmod (a, b, c);
2401   }
2402 
2403   return mp_invmod_slow(a, b, c);
2404 }
2405 
2406 /* hac 14.61, pp608 */
2407 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2408 {
2409   mp_int  x, y, u, v, A, B, C, D;
2410   int     res;
2411 
2412   /* b cannot be negative */
2413   if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2414     return MP_VAL;
2415   }
2416 
2417   /* init temps */
2418   if ((res = mp_init_multi(&x, &y, &u, &v,
2419                            &A, &B, &C, &D, NULL)) != MP_OKAY) {
2420      return res;
2421   }
2422 
2423   /* x = a, y = b */
2424   if ((res = mp_copy (a, &x)) != MP_OKAY) {
2425     goto __ERR;
2426   }
2427   if ((res = mp_copy (b, &y)) != MP_OKAY) {
2428     goto __ERR;
2429   }
2430 
2431   /* 2. [modified] if x,y are both even then return an error! */
2432   if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2433     res = MP_VAL;
2434     goto __ERR;
2435   }
2436 
2437   /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2438   if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2439     goto __ERR;
2440   }
2441   if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2442     goto __ERR;
2443   }
2444   mp_set (&A, 1);
2445   mp_set (&D, 1);
2446 
2447 top:
2448   /* 4.  while u is even do */
2449   while (mp_iseven (&u) == 1) {
2450     /* 4.1 u = u/2 */
2451     if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2452       goto __ERR;
2453     }
2454     /* 4.2 if A or B is odd then */
2455     if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2456       /* A = (A+y)/2, B = (B-x)/2 */
2457       if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2458          goto __ERR;
2459       }
2460       if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2461          goto __ERR;
2462       }
2463     }
2464     /* A = A/2, B = B/2 */
2465     if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2466       goto __ERR;
2467     }
2468     if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2469       goto __ERR;
2470     }
2471   }
2472 
2473   /* 5.  while v is even do */
2474   while (mp_iseven (&v) == 1) {
2475     /* 5.1 v = v/2 */
2476     if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2477       goto __ERR;
2478     }
2479     /* 5.2 if C or D is odd then */
2480     if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2481       /* C = (C+y)/2, D = (D-x)/2 */
2482       if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2483          goto __ERR;
2484       }
2485       if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2486          goto __ERR;
2487       }
2488     }
2489     /* C = C/2, D = D/2 */
2490     if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2491       goto __ERR;
2492     }
2493     if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2494       goto __ERR;
2495     }
2496   }
2497 
2498   /* 6.  if u >= v then */
2499   if (mp_cmp (&u, &v) != MP_LT) {
2500     /* u = u - v, A = A - C, B = B - D */
2501     if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2502       goto __ERR;
2503     }
2504 
2505     if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2506       goto __ERR;
2507     }
2508 
2509     if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2510       goto __ERR;
2511     }
2512   } else {
2513     /* v - v - u, C = C - A, D = D - B */
2514     if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2515       goto __ERR;
2516     }
2517 
2518     if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2519       goto __ERR;
2520     }
2521 
2522     if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2523       goto __ERR;
2524     }
2525   }
2526 
2527   /* if not zero goto step 4 */
2528   if (mp_iszero (&u) == 0)
2529     goto top;
2530 
2531   /* now a = C, b = D, gcd == g*v */
2532 
2533   /* if v != 1 then there is no inverse */
2534   if (mp_cmp_d (&v, 1) != MP_EQ) {
2535     res = MP_VAL;
2536     goto __ERR;
2537   }
2538 
2539   /* if it's too low */
2540   while (mp_cmp_d(&C, 0) == MP_LT) {
2541       if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2542          goto __ERR;
2543       }
2544   }
2545 
2546   /* too big */
2547   while (mp_cmp_mag(&C, b) != MP_LT) {
2548       if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2549          goto __ERR;
2550       }
2551   }
2552 
2553   /* C is now the inverse */
2554   mp_exch (&C, c);
2555   res = MP_OKAY;
2556 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2557   return res;
2558 }
2559 
/* c = |a| * |b| using Karatsuba Multiplication using
 * three half size multiplications
 *
 * Let B represent the radix [e.g. 2**DIGIT_BIT] and
 * let n represent half of the number of digits in
 * the min(a,b)
 *
 * a = a1 * B**n + a0
 * b = b1 * B**n + b0
 *
 * Then, a * b =>
 *   a1b1 * B**2n + (a0b0 + a1b1 - (a1 - a0)(b1 - b0)) * B**n + a0b0
 *
 * which holds because (a1 - a0)(b1 - b0) = a1b1 - a1b0 - a0b1 + a0b0,
 * so the middle coefficient equals a1b0 + a0b1.
 *
 * Note that a1b1 and a0b0 are used twice and only need to be
 * computed once.  So in total three half size (half # of
 * digit) multiplications are performed, a0b0, a1b1 and
 * (a1-a0)(b1-b0)
 *
 * Note that a multiplication of half the digits requires
 * 1/4th the number of single precision multiplications so in
 * total after one call 25% of the single precision multiplications
 * are saved.  Note also that the call to mp_mul can end up back
 * in this function if the a0, a1, b0, or b1 are above the threshold.
 * This is known as divide-and-conquer and leads to the famous
 * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
 * the standard O(N**2) that the baseline/comba methods use.
 * Generally though the overhead of this method doesn't pay off
 * until a certain size (N ~ 80) is reached.
 *
 * In the code below the local variable B is the split point n (a digit
 * count), not the radix.  Any failure of an internal mp_* call is
 * reported as MP_MEM.
 */
int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
{
  mp_int  x0, x1, y0, y1, t1, x0y0, x1y1;
  int     B, err;

  /* default the return code to an error */
  err = MP_MEM;

  /* min # of digits */
  B = MIN (a->used, b->used);

  /* now divide in two: B is the number of digits in the low halves */
  B = B >> 1;

  /* init copy all the temps */
  if (mp_init_size (&x0, B) != MP_OKAY)
    goto ERR;
  if (mp_init_size (&x1, a->used - B) != MP_OKAY)
    goto X0;
  if (mp_init_size (&y0, B) != MP_OKAY)
    goto X1;
  if (mp_init_size (&y1, b->used - B) != MP_OKAY)
    goto Y0;

  /* init temps */
  if (mp_init_size (&t1, B * 2) != MP_OKAY)
    goto Y1;
  if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
    goto T1;
  if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
    goto X0Y0;

  /* now shift the digits */
  x0.used = y0.used = B;
  x1.used = a->used - B;
  y1.used = b->used - B;

  {
    register int x;
    register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;

    /* we copy the digits directly instead of using higher level functions
     * since we also need to shift the digits
     */
    tmpa = a->dp;
    tmpb = b->dp;

    tmpx = x0.dp;
    tmpy = y0.dp;
    for (x = 0; x < B; x++) {
      *tmpx++ = *tmpa++;
      *tmpy++ = *tmpb++;
    }

    tmpx = x1.dp;
    for (x = B; x < a->used; x++) {
      *tmpx++ = *tmpa++;
    }

    tmpy = y1.dp;
    for (x = B; x < b->used; x++) {
      *tmpy++ = *tmpb++;
    }
  }

  /* only need to clamp the lower words since by definition the
   * upper words x1/y1 must have a known number of digits
   */
  mp_clamp (&x0);
  mp_clamp (&y0);

  /* now calc the products x0y0 and x1y1 */
  /* x0 is still read for t1 = x1 - x0 below; only after that is it
   * reused as the scratch value t2 [x0==t2] */
  if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)
    goto X1Y1;          /* x0y0 = x0*y0 */
  if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
    goto X1Y1;          /* x1y1 = x1*y1 */

  /* now calc x1-x0 and y1-y0 */
  if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
    goto X1Y1;          /* t1 = x1 - x0 */
  if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
    goto X1Y1;          /* t2 = y1 - y0 */
  if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
    goto X1Y1;          /* t1 = (x1 - x0) * (y1 - y0) */

  /* middle coefficient: t2 (held in x0) = x0y0 + x1y1, minus t1 */
  if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
    goto X1Y1;          /* t2 = x0y0 + x1y1 */
  if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
    goto X1Y1;          /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */

  /* shift by B */
  if (mp_lshd (&t1, B) != MP_OKAY)
    goto X1Y1;          /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
  if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
    goto X1Y1;          /* x1y1 = x1y1 << 2*B */

  if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
    goto X1Y1;          /* t1 = x0y0 + t1 */
  if (mp_add (&t1, &x1y1, c) != MP_OKAY)
    goto X1Y1;          /* c = x0y0 + t1 + x1y1 */

  /* Algorithm succeeded set the return code to MP_OKAY */
  err = MP_OKAY;

  /* release temporaries in reverse order of creation */
X1Y1:mp_clear (&x1y1);
X0Y0:mp_clear (&x0y0);
T1:mp_clear (&t1);
Y1:mp_clear (&y1);
Y0:mp_clear (&y0);
X1:mp_clear (&x1);
X0:mp_clear (&x0);
ERR:
  return err;
}
2705 
/* Karatsuba squaring, computes b = a*a using three
 * half size squarings
 *
 * Writing a = x1 * R**B + x0 (R the digit radix, B = a->used / 2):
 *   a*a = x1x1 * R**2B + (x0x0 + x1x1 - (x1 - x0)**2) * R**B + x0x0
 * since x0x0 + x1x1 - (x1 - x0)**2 == 2*x0*x1.
 *
 * See comments of karatsuba_mul for details.  It
 * is essentially the same algorithm but merely
 * tuned to perform recursive squarings.
 *
 * Any failure of an internal mp_* call is reported as MP_MEM.
 */
int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
{
  mp_int  x0, x1, t1, t2, x0x0, x1x1;
  int     B, err;

  err = MP_MEM;

  /* number of digits of a ... */
  B = a->used;

  /* ... halved: B is the digit count of the low half x0 */
  B = B >> 1;

  /* init copy all the temps */
  if (mp_init_size (&x0, B) != MP_OKAY)
    goto ERR;
  if (mp_init_size (&x1, a->used - B) != MP_OKAY)
    goto X0;

  /* init temps */
  if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
    goto X1;
  if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
    goto T1;
  if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
    goto T2;
  if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
    goto X0X0;

  {
    register int x;
    register mp_digit *dst, *src;

    src = a->dp;

    /* now shift the digits */
    dst = x0.dp;
    for (x = 0; x < B; x++) {
      *dst++ = *src++;
    }

    dst = x1.dp;
    for (x = B; x < a->used; x++) {
      *dst++ = *src++;
    }
  }

  x0.used = B;
  x1.used = a->used - B;

  /* only the low half can have leading zero digits */
  mp_clamp (&x0);

  /* now calc the products x0*x0 and x1*x1 */
  if (mp_sqr (&x0, &x0x0) != MP_OKAY)
    goto X1X1;           /* x0x0 = x0*x0 */
  if (mp_sqr (&x1, &x1x1) != MP_OKAY)
    goto X1X1;           /* x1x1 = x1*x1 */

  /* now calc (x1-x0)**2 */
  if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
    goto X1X1;           /* t1 = x1 - x0 */
  if (mp_sqr (&t1, &t1) != MP_OKAY)
    goto X1X1;           /* t1 = (x1 - x0) * (x1 - x0) */

  /* middle coefficient: x0x0 + x1x1 - (x1-x0)**2 */
  if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
    goto X1X1;           /* t2 = x0x0 + x1x1 */
  if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
    goto X1X1;           /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */

  /* shift by B */
  if (mp_lshd (&t1, B) != MP_OKAY)
    goto X1X1;           /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
  if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
    goto X1X1;           /* x1x1 = x1x1 << 2*B */

  if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
    goto X1X1;           /* t1 = x0x0 + t1 */
  if (mp_add (&t1, &x1x1, b) != MP_OKAY)
    goto X1X1;           /* b = x0x0 + t1 + x1x1 */

  err = MP_OKAY;

  /* release temporaries in reverse order of creation */
X1X1:mp_clear (&x1x1);
X0X0:mp_clear (&x0x0);
T2:mp_clear (&t2);
T1:mp_clear (&t1);
X1:mp_clear (&x1);
X0:mp_clear (&x0);
ERR:
  return err;
}
2805 
2806 /* computes least common multiple as |a*b|/(a, b) */
2807 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2808 {
2809   int     res;
2810   mp_int  t1, t2;
2811 
2812 
2813   if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2814     return res;
2815   }
2816 
2817   /* t1 = get the GCD of the two inputs */
2818   if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2819     goto __T;
2820   }
2821 
2822   /* divide the smallest by the GCD */
2823   if (mp_cmp_mag(a, b) == MP_LT) {
2824      /* store quotient in t2 so that t2 * b is the LCM */
2825      if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2826         goto __T;
2827      }
2828      res = mp_mul(b, &t2, c);
2829   } else {
2830      /* store quotient in t2 so that t2 * a is the LCM */
2831      if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2832         goto __T;
2833      }
2834      res = mp_mul(a, &t2, c);
2835   }
2836 
2837   /* fix the sign to positive */
2838   c->sign = MP_ZPOS;
2839 
2840 __T:
2841   mp_clear_multi (&t1, &t2, NULL);
2842   return res;
2843 }
2844 
2845 /* c = a mod b, 0 <= c < b */
2846 int
2847 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2848 {
2849   mp_int  t;
2850   int     res;
2851 
2852   if ((res = mp_init (&t)) != MP_OKAY) {
2853     return res;
2854   }
2855 
2856   if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2857     mp_clear (&t);
2858     return res;
2859   }
2860 
2861   if (t.sign != b->sign) {
2862     res = mp_add (b, &t, c);
2863   } else {
2864     res = MP_OKAY;
2865     mp_exch (&t, c);
2866   }
2867 
2868   mp_clear (&t);
2869   return res;
2870 }
2871 
2872 static int
2873 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2874 {
2875   return mp_div_d(a, b, NULL, c);
2876 }
2877 
2878 /* b = a*2 */
2879 static int mp_mul_2(const mp_int * a, mp_int * b)
2880 {
2881   int     x, res, oldused;
2882 
2883   /* grow to accommodate result */
2884   if (b->alloc < a->used + 1) {
2885     if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2886       return res;
2887     }
2888   }
2889 
2890   oldused = b->used;
2891   b->used = a->used;
2892 
2893   {
2894     register mp_digit r, rr, *tmpa, *tmpb;
2895 
2896     /* alias for source */
2897     tmpa = a->dp;
2898 
2899     /* alias for dest */
2900     tmpb = b->dp;
2901 
2902     /* carry */
2903     r = 0;
2904     for (x = 0; x < a->used; x++) {
2905 
2906       /* get what will be the *next* carry bit from the
2907        * MSB of the current digit
2908        */
2909       rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2910 
2911       /* now shift up this digit, add in the carry [from the previous] */
2912       *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2913 
2914       /* copy the carry that would be from the source
2915        * digit into the next iteration
2916        */
2917       r = rr;
2918     }
2919 
2920     /* new leading digit? */
2921     if (r != 0) {
2922       /* add a MSB which is always 1 at this point */
2923       *tmpb = 1;
2924       ++(b->used);
2925     }
2926 
2927     /* now zero any excess digits on the destination
2928      * that we didn't write to
2929      */
2930     tmpb = b->dp + b->used;
2931     for (x = b->used; x < oldused; x++) {
2932       *tmpb++ = 0;
2933     }
2934   }
2935   b->sign = a->sign;
2936   return MP_OKAY;
2937 }
2938 
2939 /*
2940  * shifts with subtractions when the result is greater than b.
2941  *
2942  * The method is slightly modified to shift B unconditionally up to just under
2943  * the leading bit of b.  This saves a lot of multiple precision shifting.
2944  */
2945 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2946 {
2947   int     x, bits, res;
2948 
2949   /* how many bits of last digit does b use */
2950   bits = mp_count_bits (b) % DIGIT_BIT;
2951 
2952 
2953   if (b->used > 1) {
2954      if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2955         return res;
2956      }
2957   } else {
2958      mp_set(a, 1);
2959      bits = 1;
2960   }
2961 
2962 
2963   /* now compute C = A * B mod b */
2964   for (x = bits - 1; x < DIGIT_BIT; x++) {
2965     if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2966       return res;
2967     }
2968     if (mp_cmp_mag (a, b) != MP_LT) {
2969       if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2970         return res;
2971       }
2972     }
2973   }
2974 
2975   return MP_OKAY;
2976 }
2977 
/* Montgomery reduction, in place: x = x * R**-1 (mod n), where
 * R = 2**(DIGIT_BIT * n->used) (one digit is eliminated per outer
 * iteration and x is finally shifted right by n->used digits).
 *
 * rho must be -1/n[0] mod 2**DIGIT_BIT, as computed by mp_montgomery_setup().
 *
 * NOTE(review): the single conditional subtraction at the end only yields
 * 0 <= x < n if the input satisfies x < n*R (e.g. x is a product of two
 * values already reduced mod n) - confirm callers uphold this.
 *
 * Returns MP_OKAY, or the result of fast_mp_montgomery_reduce / mp_grow /
 * s_mp_sub.
 */
int
mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
{
  int     ix, res, digs;
  mp_digit mu;

  /* can the fast reduction [comba] method be used?
   *
   * Note that unlike in mul you're safely allowed *less*
   * than the available columns [255 per default] since carries
   * are fixed up in the inner loop.
   */
  digs = n->used * 2 + 1;
  if ((digs < MP_WARRAY) &&
      n->used <
      (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
    return fast_mp_montgomery_reduce (x, n, rho);
  }

  /* grow the input as required */
  if (x->alloc < digs) {
    if ((res = mp_grow (x, digs)) != MP_OKAY) {
      return res;
    }
  }
  /* work over exactly 2*n->used + 1 digits.
   * NOTE(review): assumes digits of x above its previous 'used' are zero. */
  x->used = digs;

  for (ix = 0; ix < n->used; ix++) {
    /* mu = ai * rho mod b
     *
     * The value of rho must be precalculated via
     * montgomery_setup() such that
     * it equals -1/n0 mod b this allows the
     * following inner loop to reduce the
     * input one digit at a time
     */
    mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);

    /* a = a + mu * m * b**i */
    {
      register int iy;
      register mp_digit *tmpn, *tmpx, u;
      register mp_word r;

      /* alias for digits of the modulus */
      tmpn = n->dp;

      /* alias for the digits of x [the input] */
      tmpx = x->dp + ix;

      /* set the carry to zero */
      u = 0;

      /* Multiply and add in place */
      for (iy = 0; iy < n->used; iy++) {
        /* compute product and sum */
        r       = ((mp_word)mu) * ((mp_word)*tmpn++) +
                  ((mp_word) u) + ((mp_word) * tmpx);

        /* get carry */
        u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));

        /* fix digit */
        *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
      }
      /* At this point the ix'th digit of x should be zero */


      /* propagate carries upwards as required
       * NOTE(review): relies on the input bound above so the carry stops
       * within the digs digits allocated earlier */
      while (u) {
        *tmpx   += u;
        u        = *tmpx >> DIGIT_BIT;
        *tmpx++ &= MP_MASK;
      }
    }
  }

  /* at this point the n.used'th least
   * significant digits of x are all zero
   * which means we can shift x to the
   * right by n.used digits and the
   * residue is unchanged.
   */

  /* x = x/b**n.used */
  mp_clamp(x);
  mp_rshd (x, n->used);

  /* if x >= n then x = x - n */
  if (mp_cmp_mag (x, n) != MP_LT) {
    return s_mp_sub (x, n, x);
  }

  return MP_OKAY;
}
3074 
3075 /* setups the montgomery reduction stuff */
3076 int
3077 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3078 {
3079   mp_digit x, b;
3080 
3081 /* fast inversion mod 2**k
3082  *
3083  * Based on the fact that
3084  *
3085  * XA = 1 (mod 2**n)  =>  (X(2-XA)) A = 1 (mod 2**2n)
3086  *                    =>  2*X*A - X*X*A*A = 1
3087  *                    =>  2*(1) - (1)     = 1
3088  */
3089   b = n->dp[0];
3090 
3091   if ((b & 1) == 0) {
3092     return MP_VAL;
3093   }
3094 
3095   x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3096   x *= 2 - b * x;               /* here x*a==1 mod 2**8 */
3097   x *= 2 - b * x;               /* here x*a==1 mod 2**16 */
3098   x *= 2 - b * x;               /* here x*a==1 mod 2**32 */
3099 
3100   /* rho = -1/m mod b */
3101   *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3102 
3103   return MP_OKAY;
3104 }
3105 
3106 /* high level multiplication (handles sign) */
3107 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3108 {
3109   int     res, neg;
3110   neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3111 
3112   /* use Karatsuba? */
3113   if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3114     res = mp_karatsuba_mul (a, b, c);
3115   } else
3116   {
3117     /* can we use the fast multiplier?
3118      *
3119      * The fast multiplier can be used if the output will
3120      * have less than MP_WARRAY digits and the number of
3121      * digits won't affect carry propagation
3122      */
3123     int     digs = a->used + b->used + 1;
3124 
3125     if ((digs < MP_WARRAY) &&
3126         MIN(a->used, b->used) <=
3127         (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3128       res = fast_s_mp_mul_digs (a, b, c, digs);
3129     } else
3130       res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3131   }
3132   c->sign = (c->used > 0) ? neg : MP_ZPOS;
3133   return res;
3134 }
3135 
3136 /* d = a * b (mod c) */
3137 int
3138 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3139 {
3140   int     res;
3141   mp_int  t;
3142 
3143   if ((res = mp_init (&t)) != MP_OKAY) {
3144     return res;
3145   }
3146 
3147   if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3148     mp_clear (&t);
3149     return res;
3150   }
3151   res = mp_mod (&t, c, d);
3152   mp_clear (&t);
3153   return res;
3154 }
3155 
/* Table of the first 256 primes, 2 (0x0002) through 1619 (0x0653).
 * Used for trial division (mp_prime_is_divisible) and as the Miller-Rabin
 * bases (mp_prime_is_prime).  Callers index it with ix < PRIME_SIZE, so
 * PRIME_SIZE must not exceed the 256 entries here.
 * NOTE(review): PRIME_SIZE is defined in tomcrypt.h, presumably as 256 -
 * keep the two in sync.
 */
static const mp_digit __prime_tab[] = {
  0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
  0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
  0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
  0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
  0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
  0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
  0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
  0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,

  0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
  0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
  0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
  0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
  0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
  0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
  0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
  0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,

  0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
  0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
  0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
  0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
  0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
  0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
  0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
  0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,

  0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
  0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
  0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
  0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
  0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
  0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
  0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
  0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
};
3194 
3195 /* determines if an integers is divisible by one
3196  * of the first PRIME_SIZE primes or not
3197  *
3198  * sets result to 0 if not, 1 if yes
3199  */
3200 static int mp_prime_is_divisible (const mp_int * a, int *result)
3201 {
3202   int     err, ix;
3203   mp_digit res;
3204 
3205   /* default to not */
3206   *result = MP_NO;
3207 
3208   for (ix = 0; ix < PRIME_SIZE; ix++) {
3209     /* what is a mod __prime_tab[ix] */
3210     if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3211       return err;
3212     }
3213 
3214     /* is the residue zero? */
3215     if (res == 0) {
3216       *result = MP_YES;
3217       return MP_OKAY;
3218     }
3219   }
3220 
3221   return MP_OKAY;
3222 }
3223 
3224 /* Miller-Rabin test of "a" to the base of "b" as described in
3225  * HAC pp. 139 Algorithm 4.24
3226  *
3227  * Sets result to 0 if definitely composite or 1 if probably prime.
3228  * Randomly the chance of error is no more than 1/4 and often
3229  * very much lower.
3230  */
3231 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3232 {
3233   mp_int  n1, y, r;
3234   int     s, j, err;
3235 
3236   /* default */
3237   *result = MP_NO;
3238 
3239   /* ensure b > 1 */
3240   if (mp_cmp_d(b, 1) != MP_GT) {
3241      return MP_VAL;
3242   }
3243 
3244   /* get n1 = a - 1 */
3245   if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3246     return err;
3247   }
3248   if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3249     goto __N1;
3250   }
3251 
3252   /* set 2**s * r = n1 */
3253   if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3254     goto __N1;
3255   }
3256 
3257   /* count the number of least significant bits
3258    * which are zero
3259    */
3260   s = mp_cnt_lsb(&r);
3261 
3262   /* now divide n - 1 by 2**s */
3263   if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3264     goto __R;
3265   }
3266 
3267   /* compute y = b**r mod a */
3268   if ((err = mp_init (&y)) != MP_OKAY) {
3269     goto __R;
3270   }
3271   if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3272     goto __Y;
3273   }
3274 
3275   /* if y != 1 and y != n1 do */
3276   if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3277     j = 1;
3278     /* while j <= s-1 and y != n1 */
3279     while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3280       if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3281          goto __Y;
3282       }
3283 
3284       /* if y == 1 then composite */
3285       if (mp_cmp_d (&y, 1) == MP_EQ) {
3286          goto __Y;
3287       }
3288 
3289       ++j;
3290     }
3291 
3292     /* if y != n1 then composite */
3293     if (mp_cmp (&y, &n1) != MP_EQ) {
3294       goto __Y;
3295     }
3296   }
3297 
3298   /* probably prime now */
3299   *result = MP_YES;
3300 __Y:mp_clear (&y);
3301 __R:mp_clear (&r);
3302 __N1:mp_clear (&n1);
3303   return err;
3304 }
3305 
3306 /* performs a variable number of rounds of Miller-Rabin
3307  *
3308  * Probability of error after t rounds is no more than
3309 
3310  *
3311  * Sets result to 1 if probably prime, 0 otherwise
3312  */
3313 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3314 {
3315   mp_int  b;
3316   int     ix, err, res;
3317 
3318   /* default to no */
3319   *result = MP_NO;
3320 
3321   /* valid value of t? */
3322   if (t <= 0 || t > PRIME_SIZE) {
3323     return MP_VAL;
3324   }
3325 
3326   /* is the input equal to one of the primes in the table? */
3327   for (ix = 0; ix < PRIME_SIZE; ix++) {
3328       if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3329          *result = 1;
3330          return MP_OKAY;
3331       }
3332   }
3333 
3334   /* first perform trial division */
3335   if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3336     return err;
3337   }
3338 
3339   /* return if it was trivially divisible */
3340   if (res == MP_YES) {
3341     return MP_OKAY;
3342   }
3343 
3344   /* now perform the miller-rabin rounds */
3345   if ((err = mp_init (&b)) != MP_OKAY) {
3346     return err;
3347   }
3348 
3349   for (ix = 0; ix < t; ix++) {
3350     /* set the prime */
3351     mp_set (&b, __prime_tab[ix]);
3352 
3353     if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3354       goto __B;
3355     }
3356 
3357     if (res == MP_NO) {
3358       goto __B;
3359     }
3360   }
3361 
3362   /* passed the test */
3363   *result = MP_YES;
3364 __B:mp_clear (&b);
3365   return err;
3366 }
3367 
/* Recommended Rabin-Miller round counts per bit size, sorted by k. */
static const struct {
   int k, t;
} sizes[] = {
{   128,    28 },
{   256,    16 },
{   384,    10 },
{   512,     7 },
{   640,     6 },
{   768,     5 },
{   896,     4 },
{  1024,     4 }
};

/* Return the number of Rabin-Miller trials for a prime of "size" bits.
 *
 * An exact table hit returns its count.  A size between two entries uses
 * the smaller entry's count, and sizes below the first entry use the first
 * count.  Sizes beyond the last entry get the last count plus one.
 */
int mp_prime_rabin_miller_trials(int size)
{
   const int count = (int)(sizeof(sizes) / sizeof(sizes[0]));
   int idx = 0;

   /* first entry whose bit size is >= size */
   while (idx < count && sizes[idx].k < size) {
      ++idx;
   }

   if (idx == count) {
      return sizes[count - 1].t + 1;
   }
   if (idx == 0 || sizes[idx].k == size) {
      return sizes[idx].t;
   }
   return sizes[idx - 1].t;
}
3395 
3396 /* makes a truly random prime of a given size (bits),
3397  *
3398  * Flags are as follows:
3399  *
3400  *   LTM_PRIME_BBS      - make prime congruent to 3 mod 4
3401  *   LTM_PRIME_SAFE     - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3402  *   LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3403  *   LTM_PRIME_2MSB_ON  - make the 2nd highest bit one
3404  *
3405  * You have to supply a callback which fills in a buffer with random bytes.  "dat" is a parameter you can
3406  * have passed to the callback (e.g. a state or something).  This function doesn't use "dat" itself
3407  * so it can be NULL
3408  *
3409  */
3410 
3411 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3412 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3413 {
3414    unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3415    int res, err, bsize, maskOR_msb_offset;
3416 
3417    /* sanity check the input */
3418    if (size <= 1 || t <= 0) {
3419       return MP_VAL;
3420    }
3421 
3422    /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3423    if (flags & LTM_PRIME_SAFE) {
3424       flags |= LTM_PRIME_BBS;
3425    }
3426 
3427    /* calc the byte size */
3428    bsize = (size>>3)+((size&7)?1:0);
3429 
3430    /* we need a buffer of bsize bytes */
3431    tmp = HeapAlloc(GetProcessHeap(), 0, bsize);
3432    if (tmp == NULL) {
3433       return MP_MEM;
3434    }
3435 
3436    /* calc the maskAND value for the MSbyte*/
3437    maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7)));
3438 
3439    /* calc the maskOR_msb */
3440    maskOR_msb        = 0;
3441    maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3442    if (flags & LTM_PRIME_2MSB_ON) {
3443       maskOR_msb     |= 1 << ((size - 2) & 7);
3444    } else if (flags & LTM_PRIME_2MSB_OFF) {
3445       maskAND        &= ~(1 << ((size - 2) & 7));
3446    }
3447 
3448    /* get the maskOR_lsb */
3449    maskOR_lsb         = 0;
3450    if (flags & LTM_PRIME_BBS) {
3451       maskOR_lsb     |= 3;
3452    }
3453 
3454    do {
3455       /* read the bytes */
3456       if (cb(tmp, bsize, dat) != bsize) {
3457          err = MP_VAL;
3458          goto error;
3459       }
3460 
3461       /* work over the MSbyte */
3462       tmp[0]    &= maskAND;
3463       tmp[0]    |= 1 << ((size - 1) & 7);
3464 
3465       /* mix in the maskORs */
3466       tmp[maskOR_msb_offset]   |= maskOR_msb;
3467       tmp[bsize-1]             |= maskOR_lsb;
3468 
3469       /* read it in */
3470       if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY)     { goto error; }
3471 
3472       /* is it prime? */
3473       if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)           { goto error; }
3474       if (res == MP_NO) {
3475          continue;
3476       }
3477 
3478       if (flags & LTM_PRIME_SAFE) {
3479          /* see if (a-1)/2 is prime */
3480          if ((err = mp_sub_d(a, 1, a)) != MP_OKAY)                    { goto error; }
3481          if ((err = mp_div_2(a, a)) != MP_OKAY)                       { goto error; }
3482 
3483          /* is it prime? */
3484          if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY)        { goto error; }
3485       }
3486    } while (res == MP_NO);
3487 
3488    if (flags & LTM_PRIME_SAFE) {
3489       /* restore a to the original value */
3490       if ((err = mp_mul_2(a, a)) != MP_OKAY)                          { goto error; }
3491       if ((err = mp_add_d(a, 1, a)) != MP_OKAY)                       { goto error; }
3492    }
3493 
3494    err = MP_OKAY;
3495 error:
3496    HeapFree(GetProcessHeap(), 0, tmp);
3497    return err;
3498 }
3499 
3500 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3501 int
3502 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3503 {
3504   int     res;
3505 
3506   /* make sure there are at least two digits */
3507   if (a->alloc < 2) {
3508      if ((res = mp_grow(a, 2)) != MP_OKAY) {
3509         return res;
3510      }
3511   }
3512 
3513   /* zero the int */
3514   mp_zero (a);
3515 
3516   /* read the bytes in */
3517   while (c-- > 0) {
3518     if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3519       return res;
3520     }
3521 
3522     a->dp[0] |= *b++;
3523     a->used += 1;
3524   }
3525   mp_clamp (a);
3526   return MP_OKAY;
3527 }
3528 
3529 /* reduces x mod m, assumes 0 < x < m**2, mu is
3530  * precomputed via mp_reduce_setup.
3531  * From HAC pp.604 Algorithm 14.42
3532  */
3533 int
3534 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3535 {
3536   mp_int  q;
3537   int     res, um = m->used;
3538 
3539   /* q = x */
3540   if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3541     return res;
3542   }
3543 
3544   /* q1 = x / b**(k-1)  */
3545   mp_rshd (&q, um - 1);
3546 
3547   /* according to HAC this optimization is ok */
3548   if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3549     if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3550       goto CLEANUP;
3551     }
3552   } else {
3553     if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3554       goto CLEANUP;
3555     }
3556   }
3557 
3558   /* q3 = q2 / b**(k+1) */
3559   mp_rshd (&q, um + 1);
3560 
3561   /* x = x mod b**(k+1), quick (no division) */
3562   if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3563     goto CLEANUP;
3564   }
3565 
3566   /* q = q * m mod b**(k+1), quick (no division) */
3567   if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3568     goto CLEANUP;
3569   }
3570 
3571   /* x = x - q */
3572   if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3573     goto CLEANUP;
3574   }
3575 
3576   /* If x < 0, add b**(k+1) to it */
3577   if (mp_cmp_d (x, 0) == MP_LT) {
3578     mp_set (&q, 1);
3579     if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3580       goto CLEANUP;
3581     if ((res = mp_add (x, &q, x)) != MP_OKAY)
3582       goto CLEANUP;
3583   }
3584 
3585   /* Back off if it's too big */
3586   while (mp_cmp (x, m) != MP_LT) {
3587     if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3588       goto CLEANUP;
3589     }
3590   }
3591 
3592 CLEANUP:
3593   mp_clear (&q);
3594 
3595   return res;
3596 }
3597 
3598 /* reduces a modulo n where n is of the form 2**p - d */
3599 int
3600 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3601 {
3602    mp_int q;
3603    int    p, res;
3604 
3605    if ((res = mp_init(&q)) != MP_OKAY) {
3606       return res;
3607    }
3608 
3609    p = mp_count_bits(n);
3610 top:
3611    /* q = a/2**p, a = a mod 2**p */
3612    if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3613       goto ERR;
3614    }
3615 
3616    if (d != 1) {
3617       /* q = q * d */
3618       if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) {
3619          goto ERR;
3620       }
3621    }
3622 
3623    /* a = a + q */
3624    if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3625       goto ERR;
3626    }
3627 
3628    if (mp_cmp_mag(a, n) != MP_LT) {
3629       s_mp_sub(a, n, a);
3630       goto top;
3631    }
3632 
3633 ERR:
3634    mp_clear(&q);
3635    return res;
3636 }
3637 
3638 /* determines the setup value */
3639 static int
3640 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3641 {
3642    int res, p;
3643    mp_int tmp;
3644 
3645    if ((res = mp_init(&tmp)) != MP_OKAY) {
3646       return res;
3647    }
3648 
3649    p = mp_count_bits(a);
3650    if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3651       mp_clear(&tmp);
3652       return res;
3653    }
3654 
3655    if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3656       mp_clear(&tmp);
3657       return res;
3658    }
3659 
3660    *d = tmp.dp[0];
3661    mp_clear(&tmp);
3662    return MP_OKAY;
3663 }
3664 
3665 /* pre-calculate the value required for Barrett reduction
3666  * For a given modulus "b" it calculates the value required in "a"
3667  */
3668 int mp_reduce_setup (mp_int * a, const mp_int * b)
3669 {
3670   int     res;
3671 
3672   if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3673     return res;
3674   }
3675   return mp_div (a, b, a, NULL);
3676 }
3677 
3678 /* set to a digit */
3679 void mp_set (mp_int * a, mp_digit b)
3680 {
3681   mp_zero (a);
3682   a->dp[0] = b & MP_MASK;
3683   a->used  = (a->dp[0] != 0) ? 1 : 0;
3684 }
3685 
3686 /* set a 32-bit const */
3687 int mp_set_int (mp_int * a, unsigned long b)
3688 {
3689   int     x, res;
3690 
3691   mp_zero (a);
3692 
3693   /* set four bits at a time */
3694   for (x = 0; x < 8; x++) {
3695     /* shift the number up four bits */
3696     if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3697       return res;
3698     }
3699 
3700     /* OR in the top four bits of the source */
3701     a->dp[0] |= (b >> 28) & 15;
3702 
3703     /* shift the source up to the next four bits */
3704     b <<= 4;
3705 
3706     /* ensure that digits are not clamped off */
3707     a->used += 1;
3708   }
3709   mp_clamp (a);
3710   return MP_OKAY;
3711 }
3712 
3713 /* shrink a bignum */
3714 int mp_shrink (mp_int * a)
3715 {
3716   mp_digit *tmp;
3717   if (a->alloc != a->used && a->used > 0) {
3718     if ((tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3719       return MP_MEM;
3720     }
3721     a->dp    = tmp;
3722     a->alloc = a->used;
3723   }
3724   return MP_OKAY;
3725 }
3726 
3727 /* computes b = a*a */
3728 int
3729 mp_sqr (const mp_int * a, mp_int * b)
3730 {
3731   int     res;
3732 
3733 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3734     res = mp_karatsuba_sqr (a, b);
3735   } else
3736   {
3737     /* can we use the fast comba multiplier? */
3738     if ((a->used * 2 + 1) < MP_WARRAY &&
3739          a->used <
3740          (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3741       res = fast_s_mp_sqr (a, b);
3742     } else
3743       res = s_mp_sqr (a, b);
3744   }
3745   b->sign = MP_ZPOS;
3746   return res;
3747 }
3748 
3749 /* c = a * a (mod b) */
3750 int
3751 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3752 {
3753   int     res;
3754   mp_int  t;
3755 
3756   if ((res = mp_init (&t)) != MP_OKAY) {
3757     return res;
3758   }
3759 
3760   if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3761     mp_clear (&t);
3762     return res;
3763   }
3764   res = mp_mod (&t, b, c);
3765   mp_clear (&t);
3766   return res;
3767 }
3768 
3769 /* high level subtraction (handles signs) */
3770 int
3771 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3772 {
3773   int     sa, sb, res;
3774 
3775   sa = a->sign;
3776   sb = b->sign;
3777 
3778   if (sa != sb) {
3779     /* subtract a negative from a positive, OR */
3780     /* subtract a positive from a negative. */
3781     /* In either case, ADD their magnitudes, */
3782     /* and use the sign of the first number. */
3783     c->sign = sa;
3784     res = s_mp_add (a, b, c);
3785   } else {
3786     /* subtract a positive from a positive, OR */
3787     /* subtract a negative from a negative. */
3788     /* First, take the difference between their */
3789     /* magnitudes, then... */
3790     if (mp_cmp_mag (a, b) != MP_LT) {
3791       /* Copy the sign from the first */
3792       c->sign = sa;
3793       /* The first has a larger or equal magnitude */
3794       res = s_mp_sub (a, b, c);
3795     } else {
3796       /* The result has the *opposite* sign from */
3797       /* the first number. */
3798       c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3799       /* The second has a larger magnitude */
3800       res = s_mp_sub (b, a, c);
3801     }
3802   }
3803   return res;
3804 }
3805 
3806 /* single digit subtraction */
3807 int
3808 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3809 {
3810   mp_digit *tmpa, *tmpc, mu;
3811   int       res, ix, oldused;
3812 
3813   /* grow c as required */
3814   if (c->alloc < a->used + 1) {
3815      if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3816         return res;
3817      }
3818   }
3819 
3820   /* if a is negative just do an unsigned
3821    * addition [with fudged signs]
3822    */
3823   if (a->sign == MP_NEG) {
3824      a->sign = MP_ZPOS;
3825      res     = mp_add_d(a, b, c);
3826      a->sign = c->sign = MP_NEG;
3827      return res;
3828   }
3829 
3830   /* setup regs */
3831   oldused = c->used;
3832   tmpa    = a->dp;
3833   tmpc    = c->dp;
3834 
3835   /* if a <= b simply fix the single digit */
3836   if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3837      if (a->used == 1) {
3838         *tmpc++ = b - *tmpa;
3839      } else {
3840         *tmpc++ = b;
3841      }
3842      ix      = 1;
3843 
3844      /* negative/1digit */
3845      c->sign = MP_NEG;
3846      c->used = 1;
3847   } else {
3848      /* positive/size */
3849      c->sign = MP_ZPOS;
3850      c->used = a->used;
3851 
3852      /* subtract first digit */
3853      *tmpc    = *tmpa++ - b;
3854      mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3855      *tmpc++ &= MP_MASK;
3856 
3857      /* handle rest of the digits */
3858      for (ix = 1; ix < a->used; ix++) {
3859         *tmpc    = *tmpa++ - mu;
3860         mu       = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3861         *tmpc++ &= MP_MASK;
3862      }
3863   }
3864 
3865   /* zero excess digits */
3866   while (ix++ < oldused) {
3867      *tmpc++ = 0;
3868   }
3869   mp_clamp(c);
3870   return MP_OKAY;
3871 }
3872 
3873 /* store in unsigned [big endian] format */
3874 int
3875 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3876 {
3877   int     x, res;
3878   mp_int  t;
3879 
3880   if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3881     return res;
3882   }
3883 
3884   x = 0;
3885   while (mp_iszero (&t) == 0) {
3886     b[x++] = (unsigned char) (t.dp[0] & 255);
3887     if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3888       mp_clear (&t);
3889       return res;
3890     }
3891   }
3892   bn_reverse (b, x);
3893   mp_clear (&t);
3894   return MP_OKAY;
3895 }
3896 
3897 /* get the size for an unsigned equivalent */
3898 int
3899 mp_unsigned_bin_size (const mp_int * a)
3900 {
3901   int     size = mp_count_bits (a);
3902   return (size / 8 + ((size & 7) != 0 ? 1 : 0));
3903 }
3904 
/* reverses the first len bytes of s in place; len <= 1 leaves s unchanged */
static void
bn_reverse (unsigned char *s, int len)
{
  int lo, hi;

  for (lo = 0, hi = len - 1; lo < hi; lo++, hi--) {
    unsigned char swap = s[lo];

    s[lo] = s[hi];
    s[hi] = swap;
  }
}
3922 
3923 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3924 static int
3925 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3926 {
3927   mp_int *x;
3928   int     olduse, res, min, max;
3929 
3930   /* find sizes, we let |a| <= |b| which means we have to sort
3931    * them.  "x" will point to the input with the most digits
3932    */
3933   if (a->used > b->used) {
3934     min = b->used;
3935     max = a->used;
3936     x = a;
3937   } else {
3938     min = a->used;
3939     max = b->used;
3940     x = b;
3941   }
3942 
3943   /* init result */
3944   if (c->alloc < max + 1) {
3945     if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3946       return res;
3947     }
3948   }
3949 
3950   /* get old used digit count and set new one */
3951   olduse = c->used;
3952   c->used = max + 1;
3953 
3954   {
3955     register mp_digit u, *tmpa, *tmpb, *tmpc;
3956     register int i;
3957 
3958     /* alias for digit pointers */
3959 
3960     /* first input */
3961     tmpa = a->dp;
3962 
3963     /* second input */
3964     tmpb = b->dp;
3965 
3966     /* destination */
3967     tmpc = c->dp;
3968 
3969     /* zero the carry */
3970     u = 0;
3971     for (i = 0; i < min; i++) {
3972       /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3973       *tmpc = *tmpa++ + *tmpb++ + u;
3974 
3975       /* U = carry bit of T[i] */
3976       u = *tmpc >> ((mp_digit)DIGIT_BIT);
3977 
3978       /* take away carry bit from T[i] */
3979       *tmpc++ &= MP_MASK;
3980     }
3981 
3982     /* now copy higher words if any, that is in A+B
3983      * if A or B has more digits add those in
3984      */
3985     if (min != max) {
3986       for (; i < max; i++) {
3987         /* T[i] = X[i] + U */
3988         *tmpc = x->dp[i] + u;
3989 
3990         /* U = carry bit of T[i] */
3991         u = *tmpc >> ((mp_digit)DIGIT_BIT);
3992 
3993         /* take away carry bit from T[i] */
3994         *tmpc++ &= MP_MASK;
3995       }
3996     }
3997 
3998     /* add carry */
3999     *tmpc++ = u;
4000 
4001     /* clear digits above oldused */
4002     for (i = c->used; i < olduse; i++) {
4003       *tmpc++ = 0;
4004     }
4005   }
4006 
4007   mp_clamp (c);
4008   return MP_OKAY;
4009 }
4010 
/* Computes Y = G**X mod P using left-to-right sliding-window exponentiation
 * with Barrett reduction (mp_reduce, mu from mp_reduce_setup).
 *
 * The window size (2..8 bits) is chosen from the bit count of X.  The table
 * M holds G**x mod P for x == 1 and for x in [2**(winsize-1), 2**winsize);
 * other entries, including M[0], are never initialized or used.
 *
 * On success the result is swapped into Y with mp_exch and MP_OKAY is
 * returned; otherwise the error from the failing mp_* helper is returned.
 *
 * NOTE(review): the sequence of squarings and multiplications (and the table
 * index used) depends on the bits of X, so the running time is not constant
 * in the exponent.  If X can be a private exponent, confirm whether callers
 * apply blinding.
 */
static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
{
  mp_int  M[256], res, mu;
  mp_digit buf;
  int     err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;

  /* find window size */
  x = mp_count_bits (X);
  if (x <= 7) {
    winsize = 2;
  } else if (x <= 36) {
    winsize = 3;
  } else if (x <= 140) {
    winsize = 4;
  } else if (x <= 450) {
    winsize = 5;
  } else if (x <= 1303) {
    winsize = 6;
  } else if (x <= 3529) {
    winsize = 7;
  } else {
    winsize = 8;
  }

  /* init M array */
  /* init first cell */
  if ((err = mp_init(&M[1])) != MP_OKAY) {
     return err;
  }

  /* now init the second half of the array */
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    if ((err = mp_init(&M[x])) != MP_OKAY) {
      /* undo the entries initialized so far */
      for (y = 1<<(winsize-1); y < x; y++) {
        mp_clear (&M[y]);
      }
      mp_clear(&M[1]);
      return err;
    }
  }

  /* create mu, used for Barrett reduction */
  if ((err = mp_init (&mu)) != MP_OKAY) {
    goto __M;
  }
  if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
    goto __MU;
  }

  /* create M table
   *
   * The M table contains powers of the base,
   * e.g. M[x] = G**x mod P
   *
   * The first half of the table is not
   * computed, except for M[1]
   */
  if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
    goto __MU;
  }

  /* compute the value at M[1<<(winsize-1)] by squaring
   * M[1] (winsize-1) times
   */
  if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
    goto __MU;
  }

  for (x = 0; x < (winsize - 1); x++) {
    if ((err = mp_sqr (&M[1 << (winsize - 1)],
                       &M[1 << (winsize - 1)])) != MP_OKAY) {
      goto __MU;
    }
    if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
      goto __MU;
    }
  }

  /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
   * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
   */
  for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
    if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
      goto __MU;
    }
    if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
      goto __MU;
    }
  }

  /* setup result */
  if ((err = mp_init (&res)) != MP_OKAY) {
    goto __MU;
  }
  mp_set (&res, 1);

  /* set initial mode and bit cnt
   *
   * mode 0: still skipping leading zero bits of X
   * mode 1: between windows; a 0 bit costs one squaring
   * mode 2: collecting bits of a window that began with a 1 bit
   *
   * bitcnt == 1 forces the first iteration to load the top digit of X.
   */
  mode   = 0;
  bitcnt = 1;
  buf    = 0;
  digidx = X->used - 1;
  bitcpy = 0;
  bitbuf = 0;

  for (;;) {
    /* grab next digit as required */
    if (--bitcnt == 0) {
      /* if digidx == -1 we are out of digits */
      if (digidx == -1) {
        break;
      }
      /* read next digit and reset the bitcnt */
      buf    = X->dp[digidx--];
      bitcnt = DIGIT_BIT;
    }

    /* grab the next msb from the exponent */
    y     = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
    buf <<= (mp_digit)1;

    /* if the bit is zero and mode == 0 then we ignore it
     * These represent the leading zero bits before the first 1 bit
     * in the exponent.  Technically this opt is not required but it
     * does lower the # of trivial squaring/reductions used
     */
    if (mode == 0 && y == 0) {
      continue;
    }

    /* if the bit is zero and mode == 1 then we square */
    if (mode == 1 && y == 0) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
        goto __RES;
      }
      continue;
    }

    /* else we add it to the window */
    bitbuf |= (y << (winsize - ++bitcpy));
    mode    = 2;

    if (bitcpy == winsize) {
      /* ok window is filled so square as required and multiply  */
      /* square first */
      for (x = 0; x < winsize; x++) {
        if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
          goto __RES;
        }
        if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
          goto __RES;
        }
      }

      /* then multiply; the window starts with a 1 bit, so bitbuf is in
       * the initialized upper half of M */
      if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
        goto __RES;
      }

      /* empty window and reset */
      bitcpy = 0;
      bitbuf = 0;
      mode   = 1;
    }
  }

  /* if bits remain then square/multiply */
  if (mode == 2 && bitcpy > 0) {
    /* square then multiply if the bit is set */
    for (x = 0; x < bitcpy; x++) {
      if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
        goto __RES;
      }
      if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
        goto __RES;
      }

      /* shift the next collected bit into position winsize and test it */
      bitbuf <<= 1;
      if ((bitbuf & (1 << winsize)) != 0) {
        /* then multiply */
        if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
          goto __RES;
        }
        if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
          goto __RES;
        }
      }
    }
  }

  mp_exch (&res, Y);
  err = MP_OKAY;
__RES:mp_clear (&res);
__MU:mp_clear (&mu);
__M:
  mp_clear(&M[1]);
  for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
    mp_clear (&M[x]);
  }
  return err;
}
4217 
/* multiplies |a| * |b| and only computes up to digs digits of result
 * HAC pp. 595, Algorithm 14.12  Modified so you can control how
 * many digits of output are created.
 *
 * The result is the product truncated to its low digs digits: columns at
 * or above digs are never formed and carries out of column digs-1 are
 * dropped.  When digs < MP_WARRAY and the shorter operand is small enough
 * (the MIN(...) bound below), the work is delegated to fast_s_mp_mul_digs.
 * Otherwise the schoolbook product is built in the temporary t and swapped
 * into c.  Returns MP_OKAY or the error from mp_init_size.
 */
static int
s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
{
  mp_int  t;
  int     res, pa, pb, ix, iy;
  mp_digit u;
  mp_word r;
  mp_digit tmpx, *tmpt, *tmpy;

  /* can we use the fast multiplier? */
  if (((digs) < MP_WARRAY) &&
      MIN (a->used, b->used) <
          (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
    return fast_s_mp_mul_digs (a, b, c, digs);
  }

  if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
    return res;
  }
  t.used = digs;

  /* compute the digits of the product directly */
  pa = a->used;
  for (ix = 0; ix < pa; ix++) {
    /* set the carry to zero */
    u = 0;

    /* limit ourselves to making digs digits of output */
    pb = MIN (b->used, digs - ix);

    /* setup some aliases */
    /* copy of the digit from a used within the nested loop */
    tmpx = a->dp[ix];

    /* an alias for the destination shifted ix places */
    tmpt = t.dp + ix;

    /* an alias for the digits of b */
    tmpy = b->dp;

    /* compute the columns of the output and propagate the carry */
    for (iy = 0; iy < pb; iy++) {
      /* compute the column as a mp_word */
      r       = ((mp_word)*tmpt) +
                ((mp_word)tmpx) * ((mp_word)*tmpy++) +
                ((mp_word) u);

      /* the new column is the lower part of the result */
      *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));

      /* get the carry word from the result */
      u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
    }
    /* set carry if it is placed below digs */
    if (ix + iy < digs) {
      *tmpt = u;
    }
  }

  mp_clamp (&t);
  mp_exch (&t, c);

  mp_clear (&t);
  return MP_OKAY;
}
4287 
4288 /* multiplies |a| * |b| and does not compute the lower digs digits
4289  * [meant to get the higher part of the product]
4290  */
4291 static int
4292 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4293 {
4294   mp_int  t;
4295   int     res, pa, pb, ix, iy;
4296   mp_digit u;
4297   mp_word r;
4298   mp_digit tmpx, *tmpt, *tmpy;
4299 
4300   /* can we use the fast multiplier? */
4301   if (((a->used + b->used + 1) < MP_WARRAY)
4302       && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4303     return fast_s_mp_mul_high_digs (a, b, c, digs);
4304   }
4305 
4306   if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4307     return res;
4308   }
4309   t.used = a->used + b->used + 1;
4310 
4311   pa = a->used;
4312   pb = b->used;
4313   for (ix = 0; ix < pa; ix++) {
4314     /* clear the carry */
4315     u = 0;
4316 
4317     /* left hand side of A[ix] * B[iy] */
4318     tmpx = a->dp[ix];
4319 
4320     /* alias to the address of where the digits will be stored */
4321     tmpt = &(t.dp[digs]);
4322 
4323     /* alias for where to read the right hand side from */
4324     tmpy = b->dp + (digs - ix);
4325 
4326     for (iy = digs - ix; iy < pb; iy++) {
4327       /* calculate the double precision result */
4328       r       = ((mp_word)*tmpt) +
4329                 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4330                 ((mp_word) u);
4331 
4332       /* get the lower part */
4333       *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4334 
4335       /* carry the carry */
4336       u       = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4337     }
4338     *tmpt = u;
4339   }
4340   mp_clamp (&t);
4341   mp_exch (&t, c);
4342   mp_clear (&t);
4343   return MP_OKAY;
4344 }
4345 
/* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16
 *
 * For each digit ix, the square a[ix]**2 is added at column 2*ix.  Each
 * cross product a[ix]*a[iy] (iy > ix) is computed once and added twice at
 * column ix+iy, and the carry is rippled upward.  The product is built in a
 * temporary of 2*a->used + 1 digits and exchanged into b, so b may be the
 * same object as a.  Returns MP_OKAY or the error from mp_init_size.
 */
static int
s_mp_sqr (const mp_int * a, mp_int * b)
{
  mp_int  t;
  int     res, ix, iy, pa;
  mp_word r;
  mp_digit u, tmpx, *tmpt;

  pa = a->used;
  if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
    return res;
  }

  /* default used is maximum possible size */
  t.used = 2*pa + 1;

  for (ix = 0; ix < pa; ix++) {
    /* first calculate the digit at 2*ix */
    /* calculate double precision result */
    r = ((mp_word) t.dp[2*ix]) +
        ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);

    /* store lower part in result */
    t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));

    /* get the carry */
    u           = (mp_digit)(r >> ((mp_word) DIGIT_BIT));

    /* left hand side of A[ix] * A[iy] */
    tmpx        = a->dp[ix];

    /* alias for where to store the results */
    tmpt        = t.dp + (2*ix + 1);

    for (iy = ix + 1; iy < pa; iy++) {
      /* first calculate the product */
      r       = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);

      /* now calculate the double precision result, note we use
       * addition instead of *2 since it's easier to optimize
       */
      r       = ((mp_word) *tmpt) + r + r + ((mp_word) u);

      /* store lower part */
      *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));

      /* get carry */
      u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
    }
    /* propagate upwards until the carry is absorbed */
    while (u != 0) {
      r       = ((mp_word) *tmpt) + ((mp_word) u);
      *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
      u       = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
    }
  }

  mp_clamp (&t);
  mp_exch (&t, b);
  mp_clear (&t);
  return MP_OKAY;
}
4409 
4410 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4411 int
4412 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4413 {
4414   int     olduse, res, min, max;
4415 
4416   /* find sizes */
4417   min = b->used;
4418   max = a->used;
4419 
4420   /* init result */
4421   if (c->alloc < max) {
4422     if ((res = mp_grow (c, max)) != MP_OKAY) {
4423       return res;
4424     }
4425   }
4426   olduse = c->used;
4427   c->used = max;
4428 
4429   {
4430     register mp_digit u, *tmpa, *tmpb, *tmpc;
4431     register int i;
4432 
4433     /* alias for digit pointers */
4434     tmpa = a->dp;
4435     tmpb = b->dp;
4436     tmpc = c->dp;
4437 
4438     /* set carry to zero */
4439     u = 0;
4440     for (i = 0; i < min; i++) {
4441       /* T[i] = A[i] - B[i] - U */
4442       *tmpc = *tmpa++ - *tmpb++ - u;
4443 
4444       /* U = carry bit of T[i]
4445        * Note this saves performing an AND operation since
4446        * if a carry does occur it will propagate all the way to the
4447        * MSB.  As a result a single shift is enough to get the carry
4448        */
4449       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4450 
4451       /* Clear carry from T[i] */
4452       *tmpc++ &= MP_MASK;
4453     }
4454 
4455     /* now copy higher words if any, e.g. if A has more digits than B  */
4456     for (; i < max; i++) {
4457       /* T[i] = A[i] - U */
4458       *tmpc = *tmpa++ - u;
4459 
4460       /* U = carry bit of T[i] */
4461       u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4462 
4463       /* Clear carry from T[i] */
4464       *tmpc++ &= MP_MASK;
4465     }
4466 
4467     /* clear digits above used (since we may not have grown result above) */
4468     for (i = c->used; i < olduse; i++) {
4469       *tmpc++ = 0;
4470     }
4471   }
4472 
4473   mp_clamp (c);
4474   return MP_OKAY;
4475 }
4476