1 /* imath version 1.3 */
2 /*
3 Name: imath.c
4 Purpose: Arbitrary precision integer arithmetic routines.
5 Author: M. J. Fromberger <http://spinning-yarns.org/michael/sw/>
6 Info: Id: imath.c 21 2006-04-02 18:58:36Z sting
7
8 Copyright (C) 2002 Michael J. Fromberger, All Rights Reserved.
9
10 Permission is hereby granted, free of charge, to any person
11 obtaining a copy of this software and associated documentation files
12 (the "Software"), to deal in the Software without restriction,
13 including without limitation the rights to use, copy, modify, merge,
14 publish, distribute, sublicense, and/or sell copies of the Software,
15 and to permit persons to whom the Software is furnished to do so,
16 subject to the following conditions:
17
18 The above copyright notice and this permission notice shall be
19 included in all copies or substantial portions of the Software.
20
21 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
24 NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
25 BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
26 ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
27 CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
29 */
30 /* contrib/pgcrypto/imath.c */
31
32 #include "postgres.h"
33 #include "px.h"
34 #include "imath.h"
35
/* Replace any existing assert() with one that forwards to Assert(),
   which this file does not define (expected to come from the headers
   included above). */
#undef assert
#define assert(TEST) Assert(TEST)

/* Build-time switches.  A nonzero TRACEABLE_CLAMP makes CLAMP() call the
   s_clamp() function instead of expanding inline; a nonzero TRACEABLE_FREE
   makes s_free() a real function instead of an alias for px_free(). */
#define TRACEABLE_CLAMP 0
#define TRACEABLE_FREE 0
40
41 /* {{{ Constants */
42
43 const mp_result MP_OK = 0; /* no error, all is well */
44 const mp_result MP_FALSE = 0; /* boolean false */
45 const mp_result MP_TRUE = -1; /* boolean true */
46 const mp_result MP_MEMORY = -2; /* out of memory */
47 const mp_result MP_RANGE = -3; /* argument out of range */
48 const mp_result MP_UNDEF = -4; /* result undefined */
49 const mp_result MP_TRUNC = -5; /* output truncated */
50 const mp_result MP_BADARG = -6; /* invalid null argument */
51
52 const mp_sign MP_NEG = 1; /* value is strictly negative */
53 const mp_sign MP_ZPOS = 0; /* value is non-negative */
54
/* Fallback text for a result code that has no table entry. */
static const char *s_unknown_err = "unknown result code";

/* Human-readable text for the result codes.  Entry k lines up with the
   code whose value is -k (entry 1 is MP_TRUE, entry 2 is MP_MEMORY, ...);
   the list is NULL-terminated.  The lookup routine is not in this part of
   the file -- presumably it indexes by the negated code. */
static const char *s_error_msg[] = {
    "error code 0",
    "boolean true",
    "out of memory",
    "argument out of range",
    "result undefined",
    "output truncated",
    "invalid null argument",
    NULL
};
66
67 /* }}} */
68
/* Optional library flags */
#define MP_CAP_DIGITS 1     /* flag bit: emit letter digits in upper case */

/* Argument checking macros.
   CHECK() is meant for functions that return a value and NRCHECK() for
   those that do not; in this build both simply expand to assert(). */
#define CHECK(TEST)   assert(TEST)
#define NRCHECK(TEST) assert(TEST)
76
77 /* {{{ Logarithm table for computing output sizes */
78
/* Entry i of this table holds log_i(2), for i = 0 .. 64 (entries 0 and 1
   are placeholders set to zero).

   Writing an integer n in base i takes ceil(log_i(n)) digits.  lg(n) is
   cheap to obtain by counting bits, so the digit count follows from
   log_i(n) = lg(n) * log_i(2).
 */
static const double s_log2[] = {
    /*  0 ..  3 */ 0.000000000, 0.000000000, 1.000000000, 0.630929754,
    /*  4 ..  7 */ 0.500000000, 0.430676558, 0.386852807, 0.356207187,
    /*  8 .. 11 */ 0.333333333, 0.315464877, 0.301029996, 0.289064826,
    /* 12 .. 15 */ 0.278942946, 0.270238154, 0.262649535, 0.255958025,
    /* 16 .. 19 */ 0.250000000, 0.244650542, 0.239812467, 0.235408913,
    /* 20 .. 23 */ 0.231378213, 0.227670249, 0.224243824, 0.221064729,
    /* 24 .. 27 */ 0.218104292, 0.215338279, 0.212746054, 0.210309918,
    /* 28 .. 31 */ 0.208014598, 0.205846832, 0.203795047, 0.201849087,
    /* 32 .. 35 */ 0.200000000, 0.198239863, 0.196561632, 0.194959022,
    /* 36 .. 39 */ 0.193426404, 0.191958720, 0.190551412, 0.189200360,
    /* 40 .. 43 */ 0.187901825, 0.186652411, 0.185449023, 0.184288833,
    /* 44 .. 47 */ 0.183169251, 0.182087900, 0.181042597, 0.180031327,
    /* 48 .. 51 */ 0.179052232, 0.178103594, 0.177183820, 0.176291434,
    /* 52 .. 55 */ 0.175425064, 0.174583430, 0.173765343, 0.172969690,
    /* 56 .. 59 */ 0.172195434, 0.171441601, 0.170707280, 0.169991616,
    /* 60 .. 63 */ 0.169293808, 0.168613099, 0.167948779, 0.167300179,
    /* 64       */ 0.166666667
};
104
105 /* }}} */
106 /* {{{ Various macros */
107
/* Number of mp_digit cells needed to hold an object the size of V
   (sizeof(V) divided by sizeof(mp_digit), rounded up) */
#define MP_VALUE_DIGITS(V) \
    ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit))

/* Round precision P up to the next even number of digits */
#define ROUND_PREC(P) ((mp_size)(2 * (((P) + 1) / 2)))

/* Set array P of S digits to zero */
#define ZERO(P, S) \
    do { \
        mp_size i__ = (S) * sizeof(mp_digit); \
        mp_digit *p__ = (P); \
        memset(p__, 0, i__); \
    } while (0)

/* Copy S digits from array P to array Q (memcpy: no overlap allowed) */
#define COPY(P, Q, S) \
    do { \
        mp_size i__ = (S) * sizeof(mp_digit); \
        mp_digit *p__ = (P), *q__ = (Q); \
        memcpy(q__, p__, i__); \
    } while (0)

/* Reverse, in place, the N elements of type T in array A */
#define REV(T, A, N) \
    do { \
        T *u_ = (A), *v_ = u_ + (N) - 1; \
        while (u_ < v_) { \
            T xch = *u_; \
            *u_++ = *v_; \
            *v_-- = xch; \
        } \
    } while (0)
127
/* CLAMP(Z): drop high-order zero digits from Z by lowering its used
   count, always leaving at least one digit.  With TRACEABLE_CLAMP set the
   work is delegated to the s_clamp() function instead. */
#if TRACEABLE_CLAMP
#define CLAMP(Z) s_clamp(Z)
#else
#define CLAMP(Z) \
    do { \
        mp_int z_ = (Z); \
        mp_size uz_ = MP_USED(z_); \
        mp_digit *dz_ = MP_DIGITS(z_) + uz_ - 1; \
        while (uz_ > 1 && (*dz_-- == 0)) \
            --uz_; \
        MP_USED(z_) = uz_; \
    } while (0)
#endif
135
/* Local MIN/MAX override any earlier definitions.  Note that each
   argument may be evaluated twice. */
#undef MIN
#undef MAX
#define MIN(A, B) ((B) < (A) ? (B) : (A))
#define MAX(A, B) ((B) > (A) ? (B) : (A))

/* Exchange the lvalues A and B, both of type T */
#define SWAP(T, A, B) \
    do { \
        T t_ = (A); \
        A = (B); \
        B = t_; \
    } while (0)
141
/* TEMP(K): address of the K-th element of the caller's local array "temp" */
#define TEMP(K) (temp + (K))

/* SETUP(E, C): evaluate E into the caller's "res"; on anything other than
   MP_OK jump to the caller's CLEANUP label, otherwise increment C */
#define SETUP(E, C) \
    do { \
        if ((res = (E)) != MP_OK) \
            goto CLEANUP; \
        ++(C); \
    } while (0)

/* CMPZ(Z): 0 if Z is zero, -1 if its sign is MP_NEG, otherwise 1 */
#define CMPZ(Z) \
    (((Z)->used == 1 && (Z)->digits[0] == 0) ? 0 : \
     ((Z)->sign == MP_NEG) ? -1 : 1)

/* UMUL(X, Y, Z): unsigned product of X and Y into Z's digit buffer, which
   must already be large enough; the result of s_kmul is discarded */
#define UMUL(X, Y, Z) \
    do { \
        mp_size ua_ = MP_USED(X), ub_ = MP_USED(Y); \
        mp_size o_ = ua_ + ub_; \
        ZERO(MP_DIGITS(Z), o_); \
        (void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_); \
        MP_USED(Z) = o_; \
        CLAMP(Z); \
    } while (0)

/* USQR(X, Z): unsigned square of X into Z's digit buffer, which must
   already be large enough; the result of s_ksqr is discarded */
#define USQR(X, Z) \
    do { \
        mp_size ua_ = MP_USED(X), o_ = ua_ + ua_; \
        ZERO(MP_DIGITS(Z), o_); \
        (void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_); \
        MP_USED(Z) = o_; \
        CLAMP(Z); \
    } while (0)

/* Helpers for splitting and range-checking double-width words */
#define UPPER_HALF(W)           ((mp_word)((W) >> MP_DIGIT_BIT))
#define LOWER_HALF(W)           ((mp_digit)(W))
#define HIGH_BIT_SET(W)         ((W) >> (MP_WORD_BIT - 1))
#define ADD_WILL_OVERFLOW(W, V) ((MP_WORD_MAX - (V)) < (W))
163
164 /* }}} */
165
166 /* Default number of digits allocated to a new mp_int */
167 static mp_size default_precision = 64;
168
169 /* Minimum number of digits to invoke recursive multiply */
170 static mp_size multiply_threshold = 32;
171
172 /* Default library configuration flags */
173 static mp_word mp_flags = MP_CAP_DIGITS;
174
175 /* Allocate a buffer of (at least) num digits, or return
176 NULL if that couldn't be done. */
177 static mp_digit *s_alloc(mp_size num);
178
179 #if TRACEABLE_FREE
180 static void s_free(void *ptr);
181 #else
182 #define s_free(P) px_free(P)
183 #endif
184
185 /* Insure that z has at least min digits allocated, resizing if
186 necessary. Returns true if successful, false if out of memory. */
187 static int s_pad(mp_int z, mp_size min);
188
189 /* Normalize by removing leading zeroes (except when z = 0) */
190 #if TRACEABLE_CLAMP
191 static void s_clamp(mp_int z);
192 #endif
193
194 /* Fill in a "fake" mp_int on the stack with a given value */
195 static void s_fake(mp_int z, int value, mp_digit vbuf[]);
196
197 /* Compare two runs of digits of given length, returns <0, 0, >0 */
198 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len);
199
200 /* Pack the unsigned digits of v into array t */
201 static int s_vpack(int v, mp_digit t[]);
202
203 /* Compare magnitudes of a and b, returns <0, 0, >0 */
204 static int s_ucmp(mp_int a, mp_int b);
205
206 /* Compare magnitudes of a and v, returns <0, 0, >0 */
207 static int s_vcmp(mp_int a, int v);
208
209 /* Unsigned magnitude addition; assumes dc is big enough.
210 Carry out is returned (no memory allocated). */
211 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
212 mp_size size_a, mp_size size_b);
213
214 /* Unsigned magnitude subtraction. Assumes dc is big enough. */
215 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
216 mp_size size_a, mp_size size_b);
217
218 /* Unsigned recursive multiplication. Assumes dc is big enough. */
219 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
220 mp_size size_a, mp_size size_b);
221
222 /* Unsigned magnitude multiplication. Assumes dc is big enough. */
223 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
224 mp_size size_a, mp_size size_b);
225
226 /* Unsigned recursive squaring. Assumes dc is big enough. */
227 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);
228
229 /* Unsigned magnitude squaring. Assumes dc is big enough. */
230 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);
231
232 /* Single digit addition. Assumes a is big enough. */
233 static void s_dadd(mp_int a, mp_digit b);
234
235 /* Single digit multiplication. Assumes a is big enough. */
236 static void s_dmul(mp_int a, mp_digit b);
237
238 /* Single digit multiplication on buffers; assumes dc is big enough. */
239 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc,
240 mp_size size_a);
241
242 /* Single digit division. Replaces a with the quotient,
243 returns the remainder. */
244 static mp_digit s_ddiv(mp_int a, mp_digit b);
245
246 /* Quick division by a power of 2, replaces z (no allocation) */
247 static void s_qdiv(mp_int z, mp_size p2);
248
249 /* Quick remainder by a power of 2, replaces z (no allocation) */
250 static void s_qmod(mp_int z, mp_size p2);
251
252 /* Quick multiplication by a power of 2, replaces z.
253 Allocates if necessary; returns false in case this fails. */
254 static int s_qmul(mp_int z, mp_size p2);
255
256 /* Quick subtraction from a power of 2, replaces z.
257 Allocates if necessary; returns false in case this fails. */
258 static int s_qsub(mp_int z, mp_size p2);
259
260 /* Return maximum k such that 2^k divides z. */
261 static int s_dp2k(mp_int z);
262
263 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
264 static int s_isp2(mp_int z);
265
266 /* Set z to 2^k. May allocate; returns false in case this fails. */
267 static int s_2expt(mp_int z, int k);
268
269 /* Normalize a and b for division, returns normalization constant */
270 static int s_norm(mp_int a, mp_int b);
271
272 /* Compute constant mu for Barrett reduction, given modulus m, result
273 replaces z, m is untouched. */
274 static mp_result s_brmu(mp_int z, mp_int m);
275
276 /* Reduce a modulo m, using Barrett's algorithm. */
277 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
278
279 /* Modular exponentiation, using Barrett reduction */
280 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
281
282 /* Unsigned magnitude division. Assumes |a| > |b|. Allocates
283 temporaries; overwrites a with quotient, b with remainder. */
284 static mp_result s_udiv(mp_int a, mp_int b);
285
286 /* Compute the number of digits in radix r required to represent the
287 given value. Does not account for sign flags, terminators, etc. */
288 static int s_outlen(mp_int z, mp_size r);
289
290 /* Guess how many digits of precision will be needed to represent a
291 radix r value of the specified number of digits. Returns a value
292 guaranteed to be no smaller than the actual number required. */
293 static mp_size s_inlen(int len, mp_size r);
294
295 /* Convert a character to a digit value in radix r, or
296 -1 if out of range */
297 static int s_ch2val(char c, int r);
298
299 /* Convert a digit value to a character */
300 static char s_val2ch(int v, int caps);
301
302 /* Take 2's complement of a buffer in place */
303 static void s_2comp(unsigned char *buf, int len);
304
305 /* Convert a value to binary, ignoring sign. On input, *limpos is the
306 bound on how many bytes should be written to buf; on output, *limpos
307 is set to the number of bytes actually written. */
308 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
309
310 #if 0
311 /* Dump a representation of the mp_int to standard output */
312 void s_print(char *tag, mp_int z);
313 void s_print_buf(char *tag, mp_digit *buf, mp_size num);
314 #endif
315
316 /* {{{ get_default_precision() */
317
318 mp_size
mp_get_default_precision(void)319 mp_get_default_precision(void)
320 {
321 return default_precision;
322 }
323
324 /* }}} */
325
326 /* {{{ mp_set_default_precision(s) */
327
328 void
mp_set_default_precision(mp_size s)329 mp_set_default_precision(mp_size s)
330 {
331 NRCHECK(s > 0);
332
333 default_precision = (mp_size) ROUND_PREC(s);
334 }
335
336 /* }}} */
337
338 /* {{{ mp_get_multiply_threshold() */
339
340 mp_size
mp_get_multiply_threshold(void)341 mp_get_multiply_threshold(void)
342 {
343 return multiply_threshold;
344 }
345
346 /* }}} */
347
348 /* {{{ mp_set_multiply_threshold(s) */
349
350 void
mp_set_multiply_threshold(mp_size s)351 mp_set_multiply_threshold(mp_size s)
352 {
353 multiply_threshold = s;
354 }
355
356 /* }}} */
357
358 /* {{{ mp_int_init(z) */
359
360 mp_result
mp_int_init(mp_int z)361 mp_int_init(mp_int z)
362 {
363 return mp_int_init_size(z, default_precision);
364 }
365
366 /* }}} */
367
368 /* {{{ mp_int_alloc() */
369
370 mp_int
mp_int_alloc(void)371 mp_int_alloc(void)
372 {
373 mp_int out = px_alloc(sizeof(mpz_t));
374
375 assert(out != NULL);
376 out->digits = NULL;
377 out->used = 0;
378 out->alloc = 0;
379 out->sign = 0;
380
381 return out;
382 }
383
384 /* }}} */
385
386 /* {{{ mp_int_init_size(z, prec) */
387
388 mp_result
mp_int_init_size(mp_int z,mp_size prec)389 mp_int_init_size(mp_int z, mp_size prec)
390 {
391 CHECK(z != NULL);
392
393 prec = (mp_size) ROUND_PREC(prec);
394 prec = MAX(prec, default_precision);
395
396 if ((MP_DIGITS(z) = s_alloc(prec)) == NULL)
397 return MP_MEMORY;
398
399 z->digits[0] = 0;
400 MP_USED(z) = 1;
401 MP_ALLOC(z) = prec;
402 MP_SIGN(z) = MP_ZPOS;
403
404 return MP_OK;
405 }
406
407 /* }}} */
408
409 /* {{{ mp_int_init_copy(z, old) */
410
411 mp_result
mp_int_init_copy(mp_int z,mp_int old)412 mp_int_init_copy(mp_int z, mp_int old)
413 {
414 mp_result res;
415 mp_size uold,
416 target;
417
418 CHECK(z != NULL && old != NULL);
419
420 uold = MP_USED(old);
421 target = MAX(uold, default_precision);
422
423 if ((res = mp_int_init_size(z, target)) != MP_OK)
424 return res;
425
426 MP_USED(z) = uold;
427 MP_SIGN(z) = MP_SIGN(old);
428 COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
429
430 return MP_OK;
431 }
432
433 /* }}} */
434
435 /* {{{ mp_int_init_value(z, value) */
436
437 mp_result
mp_int_init_value(mp_int z,int value)438 mp_int_init_value(mp_int z, int value)
439 {
440 mp_result res;
441
442 CHECK(z != NULL);
443
444 if ((res = mp_int_init(z)) != MP_OK)
445 return res;
446
447 return mp_int_set_value(z, value);
448 }
449
450 /* }}} */
451
452 /* {{{ mp_int_set_value(z, value) */
453
454 mp_result
mp_int_set_value(mp_int z,int value)455 mp_int_set_value(mp_int z, int value)
456 {
457 mp_size ndig;
458
459 CHECK(z != NULL);
460
461 /* How many digits to copy */
462 ndig = (mp_size) MP_VALUE_DIGITS(value);
463
464 if (!s_pad(z, ndig))
465 return MP_MEMORY;
466
467 MP_USED(z) = (mp_size) s_vpack(value, MP_DIGITS(z));
468 MP_SIGN(z) = (value < 0) ? MP_NEG : MP_ZPOS;
469
470 return MP_OK;
471 }
472
473 /* }}} */
474
475 /* {{{ mp_int_clear(z) */
476
477 void
mp_int_clear(mp_int z)478 mp_int_clear(mp_int z)
479 {
480 if (z == NULL)
481 return;
482
483 if (MP_DIGITS(z) != NULL)
484 {
485 s_free(MP_DIGITS(z));
486 MP_DIGITS(z) = NULL;
487 }
488 }
489
490 /* }}} */
491
492 /* {{{ mp_int_free(z) */
493
494 void
mp_int_free(mp_int z)495 mp_int_free(mp_int z)
496 {
497 NRCHECK(z != NULL);
498
499 if (z->digits != NULL)
500 mp_int_clear(z);
501
502 px_free(z);
503 }
504
505 /* }}} */
506
507 /* {{{ mp_int_copy(a, c) */
508
509 mp_result
mp_int_copy(mp_int a,mp_int c)510 mp_int_copy(mp_int a, mp_int c)
511 {
512 CHECK(a != NULL && c != NULL);
513
514 if (a != c)
515 {
516 mp_size ua = MP_USED(a);
517 mp_digit *da,
518 *dc;
519
520 if (!s_pad(c, ua))
521 return MP_MEMORY;
522
523 da = MP_DIGITS(a);
524 dc = MP_DIGITS(c);
525 COPY(da, dc, ua);
526
527 MP_USED(c) = ua;
528 MP_SIGN(c) = MP_SIGN(a);
529 }
530
531 return MP_OK;
532 }
533
534 /* }}} */
535
536 /* {{{ mp_int_swap(a, c) */
537
538 void
mp_int_swap(mp_int a,mp_int c)539 mp_int_swap(mp_int a, mp_int c)
540 {
541 if (a != c)
542 {
543 mpz_t tmp = *a;
544
545 *a = *c;
546 *c = tmp;
547 }
548 }
549
550 /* }}} */
551
552 /* {{{ mp_int_zero(z) */
553
554 void
mp_int_zero(mp_int z)555 mp_int_zero(mp_int z)
556 {
557 NRCHECK(z != NULL);
558
559 z->digits[0] = 0;
560 MP_USED(z) = 1;
561 MP_SIGN(z) = MP_ZPOS;
562 }
563
564 /* }}} */
565
566 /* {{{ mp_int_abs(a, c) */
567
568 mp_result
mp_int_abs(mp_int a,mp_int c)569 mp_int_abs(mp_int a, mp_int c)
570 {
571 mp_result res;
572
573 CHECK(a != NULL && c != NULL);
574
575 if ((res = mp_int_copy(a, c)) != MP_OK)
576 return res;
577
578 MP_SIGN(c) = MP_ZPOS;
579 return MP_OK;
580 }
581
582 /* }}} */
583
584 /* {{{ mp_int_neg(a, c) */
585
586 mp_result
mp_int_neg(mp_int a,mp_int c)587 mp_int_neg(mp_int a, mp_int c)
588 {
589 mp_result res;
590
591 CHECK(a != NULL && c != NULL);
592
593 if ((res = mp_int_copy(a, c)) != MP_OK)
594 return res;
595
596 if (CMPZ(c) != 0)
597 MP_SIGN(c) = 1 - MP_SIGN(a);
598
599 return MP_OK;
600 }
601
602 /* }}} */
603
604 /* {{{ mp_int_add(a, b, c) */
605
606 mp_result
mp_int_add(mp_int a,mp_int b,mp_int c)607 mp_int_add(mp_int a, mp_int b, mp_int c)
608 {
609 mp_size ua,
610 ub,
611 uc,
612 max;
613
614 CHECK(a != NULL && b != NULL && c != NULL);
615
616 ua = MP_USED(a);
617 ub = MP_USED(b);
618 uc = MP_USED(c);
619 max = MAX(ua, ub);
620
621 if (MP_SIGN(a) == MP_SIGN(b))
622 {
623 /* Same sign -- add magnitudes, preserve sign of addends */
624 mp_digit carry;
625
626 if (!s_pad(c, max))
627 return MP_MEMORY;
628
629 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
630 uc = max;
631
632 if (carry)
633 {
634 if (!s_pad(c, max + 1))
635 return MP_MEMORY;
636
637 c->digits[max] = carry;
638 ++uc;
639 }
640
641 MP_USED(c) = uc;
642 MP_SIGN(c) = MP_SIGN(a);
643
644 }
645 else
646 {
647 /* Different signs -- subtract magnitudes, preserve sign of greater */
648 mp_int x,
649 y;
650 int cmp = s_ucmp(a, b); /* magnitude comparison, sign ignored */
651
652 /* Set x to max(a, b), y to min(a, b) to simplify later code */
653 if (cmp >= 0)
654 {
655 x = a;
656 y = b;
657 }
658 else
659 {
660 x = b;
661 y = a;
662 }
663
664 if (!s_pad(c, MP_USED(x)))
665 return MP_MEMORY;
666
667 /* Subtract smaller from larger */
668 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
669 MP_USED(c) = MP_USED(x);
670 CLAMP(c);
671
672 /* Give result the sign of the larger */
673 MP_SIGN(c) = MP_SIGN(x);
674 }
675
676 return MP_OK;
677 }
678
679 /* }}} */
680
681 /* {{{ mp_int_add_value(a, value, c) */
682
683 mp_result
mp_int_add_value(mp_int a,int value,mp_int c)684 mp_int_add_value(mp_int a, int value, mp_int c)
685 {
686 mpz_t vtmp;
687 mp_digit vbuf[MP_VALUE_DIGITS(value)];
688
689 s_fake(&vtmp, value, vbuf);
690
691 return mp_int_add(a, &vtmp, c);
692 }
693
694 /* }}} */
695
696 /* {{{ mp_int_sub(a, b, c) */
697
698 mp_result
mp_int_sub(mp_int a,mp_int b,mp_int c)699 mp_int_sub(mp_int a, mp_int b, mp_int c)
700 {
701 mp_size ua,
702 ub,
703 uc,
704 max;
705
706 CHECK(a != NULL && b != NULL && c != NULL);
707
708 ua = MP_USED(a);
709 ub = MP_USED(b);
710 uc = MP_USED(c);
711 max = MAX(ua, ub);
712
713 if (MP_SIGN(a) != MP_SIGN(b))
714 {
715 /* Different signs -- add magnitudes and keep sign of a */
716 mp_digit carry;
717
718 if (!s_pad(c, max))
719 return MP_MEMORY;
720
721 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
722 uc = max;
723
724 if (carry)
725 {
726 if (!s_pad(c, max + 1))
727 return MP_MEMORY;
728
729 c->digits[max] = carry;
730 ++uc;
731 }
732
733 MP_USED(c) = uc;
734 MP_SIGN(c) = MP_SIGN(a);
735
736 }
737 else
738 {
739 /* Same signs -- subtract magnitudes */
740 mp_int x,
741 y;
742 mp_sign osign;
743 int cmp = s_ucmp(a, b);
744
745 if (!s_pad(c, max))
746 return MP_MEMORY;
747
748 if (cmp >= 0)
749 {
750 x = a;
751 y = b;
752 osign = MP_ZPOS;
753 }
754 else
755 {
756 x = b;
757 y = a;
758 osign = MP_NEG;
759 }
760
761 if (MP_SIGN(a) == MP_NEG && cmp != 0)
762 osign = 1 - osign;
763
764 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
765 MP_USED(c) = MP_USED(x);
766 CLAMP(c);
767
768 MP_SIGN(c) = osign;
769 }
770
771 return MP_OK;
772 }
773
774 /* }}} */
775
776 /* {{{ mp_int_sub_value(a, value, c) */
777
778 mp_result
mp_int_sub_value(mp_int a,int value,mp_int c)779 mp_int_sub_value(mp_int a, int value, mp_int c)
780 {
781 mpz_t vtmp;
782 mp_digit vbuf[MP_VALUE_DIGITS(value)];
783
784 s_fake(&vtmp, value, vbuf);
785
786 return mp_int_sub(a, &vtmp, c);
787 }
788
789 /* }}} */
790
791 /* {{{ mp_int_mul(a, b, c) */
792
793 mp_result
mp_int_mul(mp_int a,mp_int b,mp_int c)794 mp_int_mul(mp_int a, mp_int b, mp_int c)
795 {
796 mp_digit *out;
797 mp_size osize,
798 ua,
799 ub,
800 p = 0;
801 mp_sign osign;
802
803 CHECK(a != NULL && b != NULL && c != NULL);
804
805 /* If either input is zero, we can shortcut multiplication */
806 if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0)
807 {
808 mp_int_zero(c);
809 return MP_OK;
810 }
811
812 /* Output is positive if inputs have same sign, otherwise negative */
813 osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
814
815 /*
816 * If the output is not equal to any of the inputs, we'll write the
817 * results there directly; otherwise, allocate a temporary space.
818 */
819 ua = MP_USED(a);
820 ub = MP_USED(b);
821 osize = MAX(ua, ub);
822 osize = 4 * ((osize + 1) / 2);
823
824 if (c == a || c == b)
825 {
826 p = ROUND_PREC(osize);
827 p = MAX(p, default_precision);
828
829 if ((out = s_alloc(p)) == NULL)
830 return MP_MEMORY;
831 }
832 else
833 {
834 if (!s_pad(c, osize))
835 return MP_MEMORY;
836
837 out = MP_DIGITS(c);
838 }
839 ZERO(out, osize);
840
841 if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
842 return MP_MEMORY;
843
844 /*
845 * If we allocated a new buffer, get rid of whatever memory c was already
846 * using, and fix up its fields to reflect that.
847 */
848 if (out != MP_DIGITS(c))
849 {
850 s_free(MP_DIGITS(c));
851 MP_DIGITS(c) = out;
852 MP_ALLOC(c) = p;
853 }
854
855 MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
856 CLAMP(c); /* ... right here */
857 MP_SIGN(c) = osign;
858
859 return MP_OK;
860 }
861
862 /* }}} */
863
864 /* {{{ mp_int_mul_value(a, value, c) */
865
866 mp_result
mp_int_mul_value(mp_int a,int value,mp_int c)867 mp_int_mul_value(mp_int a, int value, mp_int c)
868 {
869 mpz_t vtmp;
870 mp_digit vbuf[MP_VALUE_DIGITS(value)];
871
872 s_fake(&vtmp, value, vbuf);
873
874 return mp_int_mul(a, &vtmp, c);
875 }
876
877 /* }}} */
878
879 /* {{{ mp_int_mul_pow2(a, p2, c) */
880
881 mp_result
mp_int_mul_pow2(mp_int a,int p2,mp_int c)882 mp_int_mul_pow2(mp_int a, int p2, mp_int c)
883 {
884 mp_result res;
885
886 CHECK(a != NULL && c != NULL && p2 >= 0);
887
888 if ((res = mp_int_copy(a, c)) != MP_OK)
889 return res;
890
891 if (s_qmul(c, (mp_size) p2))
892 return MP_OK;
893 else
894 return MP_MEMORY;
895 }
896
897 /* }}} */
898
899 /* {{{ mp_int_sqr(a, c) */
900
901 mp_result
mp_int_sqr(mp_int a,mp_int c)902 mp_int_sqr(mp_int a, mp_int c)
903 {
904 mp_digit *out;
905 mp_size osize,
906 p = 0;
907
908 CHECK(a != NULL && c != NULL);
909
910 /* Get a temporary buffer big enough to hold the result */
911 osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2);
912
913 if (a == c)
914 {
915 p = ROUND_PREC(osize);
916 p = MAX(p, default_precision);
917
918 if ((out = s_alloc(p)) == NULL)
919 return MP_MEMORY;
920 }
921 else
922 {
923 if (!s_pad(c, osize))
924 return MP_MEMORY;
925
926 out = MP_DIGITS(c);
927 }
928 ZERO(out, osize);
929
930 s_ksqr(MP_DIGITS(a), out, MP_USED(a));
931
932 /*
933 * Get rid of whatever memory c was already using, and fix up its fields
934 * to reflect the new digit array it's using
935 */
936 if (out != MP_DIGITS(c))
937 {
938 s_free(MP_DIGITS(c));
939 MP_DIGITS(c) = out;
940 MP_ALLOC(c) = p;
941 }
942
943 MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
944 CLAMP(c); /* ... right here */
945 MP_SIGN(c) = MP_ZPOS;
946
947 return MP_OK;
948 }
949
950 /* }}} */
951
952 /* {{{ mp_int_div(a, b, q, r) */
953
954 mp_result
mp_int_div(mp_int a,mp_int b,mp_int q,mp_int r)955 mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
956 {
957 int cmp,
958 last = 0,
959 lg;
960 mp_result res = MP_OK;
961 mpz_t temp[2];
962 mp_int qout,
963 rout;
964 mp_sign sa = MP_SIGN(a),
965 sb = MP_SIGN(b);
966
967 CHECK(a != NULL && b != NULL && q != r);
968
969 if (CMPZ(b) == 0)
970 return MP_UNDEF;
971 else if ((cmp = s_ucmp(a, b)) < 0)
972 {
973 /*
974 * If |a| < |b|, no division is required: q = 0, r = a
975 */
976 if (r && (res = mp_int_copy(a, r)) != MP_OK)
977 return res;
978
979 if (q)
980 mp_int_zero(q);
981
982 return MP_OK;
983 }
984 else if (cmp == 0)
985 {
986 /*
987 * If |a| = |b|, no division is required: q = 1 or -1, r = 0
988 */
989 if (r)
990 mp_int_zero(r);
991
992 if (q)
993 {
994 mp_int_zero(q);
995 q->digits[0] = 1;
996
997 if (sa != sb)
998 MP_SIGN(q) = MP_NEG;
999 }
1000
1001 return MP_OK;
1002 }
1003
1004 /*
1005 * When |a| > |b|, real division is required. We need someplace to store
1006 * quotient and remainder, but q and r are allowed to be NULL or to
1007 * overlap with the inputs.
1008 */
1009 if ((lg = s_isp2(b)) < 0)
1010 {
1011 if (q && b != q && (res = mp_int_copy(a, q)) == MP_OK)
1012 {
1013 qout = q;
1014 }
1015 else
1016 {
1017 qout = TEMP(last);
1018 SETUP(mp_int_init_copy(TEMP(last), a), last);
1019 }
1020
1021 if (r && a != r && (res = mp_int_copy(b, r)) == MP_OK)
1022 {
1023 rout = r;
1024 }
1025 else
1026 {
1027 rout = TEMP(last);
1028 SETUP(mp_int_init_copy(TEMP(last), b), last);
1029 }
1030
1031 if ((res = s_udiv(qout, rout)) != MP_OK)
1032 goto CLEANUP;
1033 }
1034 else
1035 {
1036 if (q && (res = mp_int_copy(a, q)) != MP_OK)
1037 goto CLEANUP;
1038 if (r && (res = mp_int_copy(a, r)) != MP_OK)
1039 goto CLEANUP;
1040
1041 if (q)
1042 s_qdiv(q, (mp_size) lg);
1043 qout = q;
1044 if (r)
1045 s_qmod(r, (mp_size) lg);
1046 rout = r;
1047 }
1048
1049 /* Recompute signs for output */
1050 if (rout)
1051 {
1052 MP_SIGN(rout) = sa;
1053 if (CMPZ(rout) == 0)
1054 MP_SIGN(rout) = MP_ZPOS;
1055 }
1056 if (qout)
1057 {
1058 MP_SIGN(qout) = (sa == sb) ? MP_ZPOS : MP_NEG;
1059 if (CMPZ(qout) == 0)
1060 MP_SIGN(qout) = MP_ZPOS;
1061 }
1062
1063 if (q && (res = mp_int_copy(qout, q)) != MP_OK)
1064 goto CLEANUP;
1065 if (r && (res = mp_int_copy(rout, r)) != MP_OK)
1066 goto CLEANUP;
1067
1068 CLEANUP:
1069 while (--last >= 0)
1070 mp_int_clear(TEMP(last));
1071
1072 return res;
1073 }
1074
1075 /* }}} */
1076
1077 /* {{{ mp_int_mod(a, m, c) */
1078
1079 mp_result
mp_int_mod(mp_int a,mp_int m,mp_int c)1080 mp_int_mod(mp_int a, mp_int m, mp_int c)
1081 {
1082 mp_result res;
1083 mpz_t tmp;
1084 mp_int out;
1085
1086 if (m == c)
1087 {
1088 if ((res = mp_int_init(&tmp)) != MP_OK)
1089 return res;
1090
1091 out = &tmp;
1092 }
1093 else
1094 {
1095 out = c;
1096 }
1097
1098 if ((res = mp_int_div(a, m, NULL, out)) != MP_OK)
1099 goto CLEANUP;
1100
1101 if (CMPZ(out) < 0)
1102 res = mp_int_add(out, m, c);
1103 else
1104 res = mp_int_copy(out, c);
1105
1106 CLEANUP:
1107 if (out != c)
1108 mp_int_clear(&tmp);
1109
1110 return res;
1111 }
1112
1113 /* }}} */
1114
1115
1116 /* {{{ mp_int_div_value(a, value, q, r) */
1117
1118 mp_result
mp_int_div_value(mp_int a,int value,mp_int q,int * r)1119 mp_int_div_value(mp_int a, int value, mp_int q, int *r)
1120 {
1121 mpz_t vtmp,
1122 rtmp;
1123 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1124 mp_result res;
1125
1126 if ((res = mp_int_init(&rtmp)) != MP_OK)
1127 return res;
1128 s_fake(&vtmp, value, vbuf);
1129
1130 if ((res = mp_int_div(a, &vtmp, q, &rtmp)) != MP_OK)
1131 goto CLEANUP;
1132
1133 if (r)
1134 (void) mp_int_to_int(&rtmp, r); /* can't fail */
1135
1136 CLEANUP:
1137 mp_int_clear(&rtmp);
1138 return res;
1139 }
1140
1141 /* }}} */
1142
1143 /* {{{ mp_int_div_pow2(a, p2, q, r) */
1144
1145 mp_result
mp_int_div_pow2(mp_int a,int p2,mp_int q,mp_int r)1146 mp_int_div_pow2(mp_int a, int p2, mp_int q, mp_int r)
1147 {
1148 mp_result res = MP_OK;
1149
1150 CHECK(a != NULL && p2 >= 0 && q != r);
1151
1152 if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
1153 s_qdiv(q, (mp_size) p2);
1154
1155 if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
1156 s_qmod(r, (mp_size) p2);
1157
1158 return res;
1159 }
1160
1161 /* }}} */
1162
1163 /* {{{ mp_int_expt(a, b, c) */
1164
1165 mp_result
mp_int_expt(mp_int a,int b,mp_int c)1166 mp_int_expt(mp_int a, int b, mp_int c)
1167 {
1168 mpz_t t;
1169 mp_result res;
1170 unsigned int v = abs(b);
1171
1172 CHECK(b >= 0 && c != NULL);
1173
1174 if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1175 return res;
1176
1177 (void) mp_int_set_value(c, 1);
1178 while (v != 0)
1179 {
1180 if (v & 1)
1181 {
1182 if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1183 goto CLEANUP;
1184 }
1185
1186 v >>= 1;
1187 if (v == 0)
1188 break;
1189
1190 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1191 goto CLEANUP;
1192 }
1193
1194 CLEANUP:
1195 mp_int_clear(&t);
1196 return res;
1197 }
1198
1199 /* }}} */
1200
1201 /* {{{ mp_int_expt_value(a, b, c) */
1202
1203 mp_result
mp_int_expt_value(int a,int b,mp_int c)1204 mp_int_expt_value(int a, int b, mp_int c)
1205 {
1206 mpz_t t;
1207 mp_result res;
1208 unsigned int v = abs(b);
1209
1210 CHECK(b >= 0 && c != NULL);
1211
1212 if ((res = mp_int_init_value(&t, a)) != MP_OK)
1213 return res;
1214
1215 (void) mp_int_set_value(c, 1);
1216 while (v != 0)
1217 {
1218 if (v & 1)
1219 {
1220 if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1221 goto CLEANUP;
1222 }
1223
1224 v >>= 1;
1225 if (v == 0)
1226 break;
1227
1228 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1229 goto CLEANUP;
1230 }
1231
1232 CLEANUP:
1233 mp_int_clear(&t);
1234 return res;
1235 }
1236
1237 /* }}} */
1238
1239 /* {{{ mp_int_compare(a, b) */
1240
1241 int
mp_int_compare(mp_int a,mp_int b)1242 mp_int_compare(mp_int a, mp_int b)
1243 {
1244 mp_sign sa;
1245
1246 CHECK(a != NULL && b != NULL);
1247
1248 sa = MP_SIGN(a);
1249 if (sa == MP_SIGN(b))
1250 {
1251 int cmp = s_ucmp(a, b);
1252
1253 /*
1254 * If they're both zero or positive, the normal comparison applies; if
1255 * both negative, the sense is reversed.
1256 */
1257 if (sa != MP_ZPOS)
1258 INVERT_COMPARE_RESULT(cmp);
1259 return cmp;
1260 }
1261 else
1262 {
1263 if (sa == MP_ZPOS)
1264 return 1;
1265 else
1266 return -1;
1267 }
1268 }
1269
1270 /* }}} */
1271
1272 /* {{{ mp_int_compare_unsigned(a, b) */
1273
1274 int
mp_int_compare_unsigned(mp_int a,mp_int b)1275 mp_int_compare_unsigned(mp_int a, mp_int b)
1276 {
1277 NRCHECK(a != NULL && b != NULL);
1278
1279 return s_ucmp(a, b);
1280 }
1281
1282 /* }}} */
1283
1284 /* {{{ mp_int_compare_zero(z) */
1285
1286 int
mp_int_compare_zero(mp_int z)1287 mp_int_compare_zero(mp_int z)
1288 {
1289 NRCHECK(z != NULL);
1290
1291 if (MP_USED(z) == 1 && z->digits[0] == 0)
1292 return 0;
1293 else if (MP_SIGN(z) == MP_ZPOS)
1294 return 1;
1295 else
1296 return -1;
1297 }
1298
1299 /* }}} */
1300
1301 /* {{{ mp_int_compare_value(z, value) */
1302
1303 int
mp_int_compare_value(mp_int z,int value)1304 mp_int_compare_value(mp_int z, int value)
1305 {
1306 mp_sign vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1307 int cmp;
1308
1309 CHECK(z != NULL);
1310
1311 if (vsign == MP_SIGN(z))
1312 {
1313 cmp = s_vcmp(z, value);
1314
1315 if (vsign != MP_ZPOS)
1316 INVERT_COMPARE_RESULT(cmp);
1317 return cmp;
1318 }
1319 else
1320 {
1321 if (value < 0)
1322 return 1;
1323 else
1324 return -1;
1325 }
1326 }
1327
1328 /* }}} */
1329
1330 /* {{{ mp_int_exptmod(a, b, m, c) */
1331
1332 mp_result
mp_int_exptmod(mp_int a,mp_int b,mp_int m,mp_int c)1333 mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1334 {
1335 mp_result res;
1336 mp_size um;
1337 mpz_t temp[3];
1338 mp_int s;
1339 int last = 0;
1340
1341 CHECK(a != NULL && b != NULL && c != NULL && m != NULL);
1342
1343 /* Zero moduli and negative exponents are not considered. */
1344 if (CMPZ(m) == 0)
1345 return MP_UNDEF;
1346 if (CMPZ(b) < 0)
1347 return MP_RANGE;
1348
1349 um = MP_USED(m);
1350 SETUP(mp_int_init_size(TEMP(0), 2 * um), last);
1351 SETUP(mp_int_init_size(TEMP(1), 2 * um), last);
1352
1353 if (c == b || c == m)
1354 {
1355 SETUP(mp_int_init_size(TEMP(2), 2 * um), last);
1356 s = TEMP(2);
1357 }
1358 else
1359 {
1360 s = c;
1361 }
1362
1363 if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK)
1364 goto CLEANUP;
1365
1366 if ((res = s_brmu(TEMP(1), m)) != MP_OK)
1367 goto CLEANUP;
1368
1369 if ((res = s_embar(TEMP(0), b, m, TEMP(1), s)) != MP_OK)
1370 goto CLEANUP;
1371
1372 res = mp_int_copy(s, c);
1373
1374 CLEANUP:
1375 while (--last >= 0)
1376 mp_int_clear(TEMP(last));
1377
1378 return res;
1379 }
1380
1381 /* }}} */
1382
1383 /* {{{ mp_int_exptmod_evalue(a, value, m, c) */
1384
1385 mp_result
mp_int_exptmod_evalue(mp_int a,int value,mp_int m,mp_int c)1386 mp_int_exptmod_evalue(mp_int a, int value, mp_int m, mp_int c)
1387 {
1388 mpz_t vtmp;
1389 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1390
1391 s_fake(&vtmp, value, vbuf);
1392
1393 return mp_int_exptmod(a, &vtmp, m, c);
1394 }
1395
1396 /* }}} */
1397
1398 /* {{{ mp_int_exptmod_bvalue(v, b, m, c) */
1399
1400 mp_result
mp_int_exptmod_bvalue(int value,mp_int b,mp_int m,mp_int c)1401 mp_int_exptmod_bvalue(int value, mp_int b,
1402 mp_int m, mp_int c)
1403 {
1404 mpz_t vtmp;
1405 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1406
1407 s_fake(&vtmp, value, vbuf);
1408
1409 return mp_int_exptmod(&vtmp, b, m, c);
1410 }
1411
1412 /* }}} */
1413
1414 /* {{{ mp_int_exptmod_known(a, b, m, mu, c) */
1415
1416 mp_result
mp_int_exptmod_known(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)1417 mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
1418 {
1419 mp_result res;
1420 mp_size um;
1421 mpz_t temp[2];
1422 mp_int s;
1423 int last = 0;
1424
1425 CHECK(a && b && m && c);
1426
1427 /* Zero moduli and negative exponents are not considered. */
1428 if (CMPZ(m) == 0)
1429 return MP_UNDEF;
1430 if (CMPZ(b) < 0)
1431 return MP_RANGE;
1432
1433 um = MP_USED(m);
1434 SETUP(mp_int_init_size(TEMP(0), 2 * um), last);
1435
1436 if (c == b || c == m)
1437 {
1438 SETUP(mp_int_init_size(TEMP(1), 2 * um), last);
1439 s = TEMP(1);
1440 }
1441 else
1442 {
1443 s = c;
1444 }
1445
1446 if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK)
1447 goto CLEANUP;
1448
1449 if ((res = s_embar(TEMP(0), b, m, mu, s)) != MP_OK)
1450 goto CLEANUP;
1451
1452 res = mp_int_copy(s, c);
1453
1454 CLEANUP:
1455 while (--last >= 0)
1456 mp_int_clear(TEMP(last));
1457
1458 return res;
1459 }
1460
1461 /* }}} */
1462
1463 /* {{{ mp_int_redux_const(m, c) */
1464
1465 mp_result
mp_int_redux_const(mp_int m,mp_int c)1466 mp_int_redux_const(mp_int m, mp_int c)
1467 {
1468 CHECK(m != NULL && c != NULL && m != c);
1469
1470 return s_brmu(c, m);
1471 }
1472
1473 /* }}} */
1474
1475 /* {{{ mp_int_invmod(a, m, c) */
1476
1477 mp_result
mp_int_invmod(mp_int a,mp_int m,mp_int c)1478 mp_int_invmod(mp_int a, mp_int m, mp_int c)
1479 {
1480 mp_result res;
1481 mp_sign sa;
1482 int last = 0;
1483 mpz_t temp[2];
1484
1485 CHECK(a != NULL && m != NULL && c != NULL);
1486
1487 if (CMPZ(a) == 0 || CMPZ(m) <= 0)
1488 return MP_RANGE;
1489
1490 sa = MP_SIGN(a); /* need this for the result later */
1491
1492 for (last = 0; last < 2; ++last)
1493 if ((res = mp_int_init(TEMP(last))) != MP_OK)
1494 goto CLEANUP;
1495
1496 if ((res = mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)) != MP_OK)
1497 goto CLEANUP;
1498
1499 if (mp_int_compare_value(TEMP(0), 1) != 0)
1500 {
1501 res = MP_UNDEF;
1502 goto CLEANUP;
1503 }
1504
1505 /* It is first necessary to constrain the value to the proper range */
1506 if ((res = mp_int_mod(TEMP(1), m, TEMP(1))) != MP_OK)
1507 goto CLEANUP;
1508
1509 /*
1510 * Now, if 'a' was originally negative, the value we have is actually the
1511 * magnitude of the negative representative; to get the positive value we
1512 * have to subtract from the modulus. Otherwise, the value is okay as it
1513 * stands.
1514 */
1515 if (sa == MP_NEG)
1516 res = mp_int_sub(m, TEMP(1), c);
1517 else
1518 res = mp_int_copy(TEMP(1), c);
1519
1520 CLEANUP:
1521 while (--last >= 0)
1522 mp_int_clear(TEMP(last));
1523
1524 return res;
1525 }
1526
1527 /* }}} */
1528
1529 /* {{{ mp_int_gcd(a, b, c) */
1530
1531 /* Binary GCD algorithm due to Josef Stein, 1961 */
1532 mp_result
mp_int_gcd(mp_int a,mp_int b,mp_int c)1533 mp_int_gcd(mp_int a, mp_int b, mp_int c)
1534 {
1535 int ca,
1536 cb,
1537 k = 0;
1538 mpz_t u,
1539 v,
1540 t;
1541 mp_result res;
1542
1543 CHECK(a != NULL && b != NULL && c != NULL);
1544
1545 ca = CMPZ(a);
1546 cb = CMPZ(b);
1547 if (ca == 0 && cb == 0)
1548 return MP_UNDEF;
1549 else if (ca == 0)
1550 return mp_int_abs(b, c);
1551 else if (cb == 0)
1552 return mp_int_abs(a, c);
1553
1554 if ((res = mp_int_init(&t)) != MP_OK)
1555 return res;
1556 if ((res = mp_int_init_copy(&u, a)) != MP_OK)
1557 goto U;
1558 if ((res = mp_int_init_copy(&v, b)) != MP_OK)
1559 goto V;
1560
1561 MP_SIGN(&u) = MP_ZPOS;
1562 MP_SIGN(&v) = MP_ZPOS;
1563
1564 { /* Divide out common factors of 2 from u and v */
1565 int div2_u = s_dp2k(&u),
1566 div2_v = s_dp2k(&v);
1567
1568 k = MIN(div2_u, div2_v);
1569 s_qdiv(&u, (mp_size) k);
1570 s_qdiv(&v, (mp_size) k);
1571 }
1572
1573 if (mp_int_is_odd(&u))
1574 {
1575 if ((res = mp_int_neg(&v, &t)) != MP_OK)
1576 goto CLEANUP;
1577 }
1578 else
1579 {
1580 if ((res = mp_int_copy(&u, &t)) != MP_OK)
1581 goto CLEANUP;
1582 }
1583
1584 for (;;)
1585 {
1586 s_qdiv(&t, s_dp2k(&t));
1587
1588 if (CMPZ(&t) > 0)
1589 {
1590 if ((res = mp_int_copy(&t, &u)) != MP_OK)
1591 goto CLEANUP;
1592 }
1593 else
1594 {
1595 if ((res = mp_int_neg(&t, &v)) != MP_OK)
1596 goto CLEANUP;
1597 }
1598
1599 if ((res = mp_int_sub(&u, &v, &t)) != MP_OK)
1600 goto CLEANUP;
1601
1602 if (CMPZ(&t) == 0)
1603 break;
1604 }
1605
1606 if ((res = mp_int_abs(&u, c)) != MP_OK)
1607 goto CLEANUP;
1608 if (!s_qmul(c, (mp_size) k))
1609 res = MP_MEMORY;
1610
1611 CLEANUP:
1612 mp_int_clear(&v);
1613 V: mp_int_clear(&u);
1614 U: mp_int_clear(&t);
1615
1616 return res;
1617 }
1618
1619 /* }}} */
1620
1621 /* {{{ mp_int_egcd(a, b, c, x, y) */
1622
1623 /* This is the binary GCD algorithm again, but this time we keep track
1624 of the elementary matrix operations as we go, so we can get values
1625 x and y satisfying c = ax + by.
1626 */
1627 mp_result
mp_int_egcd(mp_int a,mp_int b,mp_int c,mp_int x,mp_int y)1628 mp_int_egcd(mp_int a, mp_int b, mp_int c,
1629 mp_int x, mp_int y)
1630 {
1631 int k,
1632 last = 0,
1633 ca,
1634 cb;
1635 mpz_t temp[8];
1636 mp_result res;
1637
1638 CHECK(a != NULL && b != NULL && c != NULL &&
1639 (x != NULL || y != NULL));
1640
1641 ca = CMPZ(a);
1642 cb = CMPZ(b);
1643 if (ca == 0 && cb == 0)
1644 return MP_UNDEF;
1645 else if (ca == 0)
1646 {
1647 if ((res = mp_int_abs(b, c)) != MP_OK)
1648 return res;
1649 mp_int_zero(x);
1650 (void) mp_int_set_value(y, 1);
1651 return MP_OK;
1652 }
1653 else if (cb == 0)
1654 {
1655 if ((res = mp_int_abs(a, c)) != MP_OK)
1656 return res;
1657 (void) mp_int_set_value(x, 1);
1658 mp_int_zero(y);
1659 return MP_OK;
1660 }
1661
1662 /*
1663 * Initialize temporaries: A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7
1664 */
1665 for (last = 0; last < 4; ++last)
1666 {
1667 if ((res = mp_int_init(TEMP(last))) != MP_OK)
1668 goto CLEANUP;
1669 }
1670 TEMP(0)->digits[0] = 1;
1671 TEMP(3)->digits[0] = 1;
1672
1673 SETUP(mp_int_init_copy(TEMP(4), a), last);
1674 SETUP(mp_int_init_copy(TEMP(5), b), last);
1675
1676 /* We will work with absolute values here */
1677 MP_SIGN(TEMP(4)) = MP_ZPOS;
1678 MP_SIGN(TEMP(5)) = MP_ZPOS;
1679
1680 { /* Divide out common factors of 2 from u and v */
1681 int div2_u = s_dp2k(TEMP(4)),
1682 div2_v = s_dp2k(TEMP(5));
1683
1684 k = MIN(div2_u, div2_v);
1685 s_qdiv(TEMP(4), k);
1686 s_qdiv(TEMP(5), k);
1687 }
1688
1689 SETUP(mp_int_init_copy(TEMP(6), TEMP(4)), last);
1690 SETUP(mp_int_init_copy(TEMP(7), TEMP(5)), last);
1691
1692 for (;;)
1693 {
1694 while (mp_int_is_even(TEMP(4)))
1695 {
1696 s_qdiv(TEMP(4), 1);
1697
1698 if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1)))
1699 {
1700 if ((res = mp_int_add(TEMP(0), TEMP(7), TEMP(0))) != MP_OK)
1701 goto CLEANUP;
1702 if ((res = mp_int_sub(TEMP(1), TEMP(6), TEMP(1))) != MP_OK)
1703 goto CLEANUP;
1704 }
1705
1706 s_qdiv(TEMP(0), 1);
1707 s_qdiv(TEMP(1), 1);
1708 }
1709
1710 while (mp_int_is_even(TEMP(5)))
1711 {
1712 s_qdiv(TEMP(5), 1);
1713
1714 if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3)))
1715 {
1716 if ((res = mp_int_add(TEMP(2), TEMP(7), TEMP(2))) != MP_OK)
1717 goto CLEANUP;
1718 if ((res = mp_int_sub(TEMP(3), TEMP(6), TEMP(3))) != MP_OK)
1719 goto CLEANUP;
1720 }
1721
1722 s_qdiv(TEMP(2), 1);
1723 s_qdiv(TEMP(3), 1);
1724 }
1725
1726 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0)
1727 {
1728 if ((res = mp_int_sub(TEMP(4), TEMP(5), TEMP(4))) != MP_OK)
1729 goto CLEANUP;
1730 if ((res = mp_int_sub(TEMP(0), TEMP(2), TEMP(0))) != MP_OK)
1731 goto CLEANUP;
1732 if ((res = mp_int_sub(TEMP(1), TEMP(3), TEMP(1))) != MP_OK)
1733 goto CLEANUP;
1734 }
1735 else
1736 {
1737 if ((res = mp_int_sub(TEMP(5), TEMP(4), TEMP(5))) != MP_OK)
1738 goto CLEANUP;
1739 if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK)
1740 goto CLEANUP;
1741 if ((res = mp_int_sub(TEMP(3), TEMP(1), TEMP(3))) != MP_OK)
1742 goto CLEANUP;
1743 }
1744
1745 if (CMPZ(TEMP(4)) == 0)
1746 {
1747 if (x && (res = mp_int_copy(TEMP(2), x)) != MP_OK)
1748 goto CLEANUP;
1749 if (y && (res = mp_int_copy(TEMP(3), y)) != MP_OK)
1750 goto CLEANUP;
1751 if (c)
1752 {
1753 if (!s_qmul(TEMP(5), k))
1754 {
1755 res = MP_MEMORY;
1756 goto CLEANUP;
1757 }
1758
1759 res = mp_int_copy(TEMP(5), c);
1760 }
1761
1762 break;
1763 }
1764 }
1765
1766 CLEANUP:
1767 while (--last >= 0)
1768 mp_int_clear(TEMP(last));
1769
1770 return res;
1771 }
1772
1773 /* }}} */
1774
1775 /* {{{ mp_int_divisible_value(a, v) */
1776
1777 int
mp_int_divisible_value(mp_int a,int v)1778 mp_int_divisible_value(mp_int a, int v)
1779 {
1780 int rem = 0;
1781
1782 if (mp_int_div_value(a, v, NULL, &rem) != MP_OK)
1783 return 0;
1784
1785 return rem == 0;
1786 }
1787
1788 /* }}} */
1789
1790 /* {{{ mp_int_is_pow2(z) */
1791
1792 int
mp_int_is_pow2(mp_int z)1793 mp_int_is_pow2(mp_int z)
1794 {
1795 CHECK(z != NULL);
1796
1797 return s_isp2(z);
1798 }
1799
1800 /* }}} */
1801
1802 /* {{{ mp_int_sqrt(a, c) */
1803
1804 mp_result
mp_int_sqrt(mp_int a,mp_int c)1805 mp_int_sqrt(mp_int a, mp_int c)
1806 {
1807 mp_result res = MP_OK;
1808 mpz_t temp[2];
1809 int last = 0;
1810
1811 CHECK(a != NULL && c != NULL);
1812
1813 /* The square root of a negative value does not exist in the integers. */
1814 if (MP_SIGN(a) == MP_NEG)
1815 return MP_UNDEF;
1816
1817 SETUP(mp_int_init_copy(TEMP(last), a), last);
1818 SETUP(mp_int_init(TEMP(last)), last);
1819
1820 for (;;)
1821 {
1822 if ((res = mp_int_sqr(TEMP(0), TEMP(1))) != MP_OK)
1823 goto CLEANUP;
1824
1825 if (mp_int_compare_unsigned(a, TEMP(1)) == 0)
1826 break;
1827
1828 if ((res = mp_int_copy(a, TEMP(1))) != MP_OK)
1829 goto CLEANUP;
1830 if ((res = mp_int_div(TEMP(1), TEMP(0), TEMP(1), NULL)) != MP_OK)
1831 goto CLEANUP;
1832 if ((res = mp_int_add(TEMP(0), TEMP(1), TEMP(1))) != MP_OK)
1833 goto CLEANUP;
1834 if ((res = mp_int_div_pow2(TEMP(1), 1, TEMP(1), NULL)) != MP_OK)
1835 goto CLEANUP;
1836
1837 if (mp_int_compare_unsigned(TEMP(0), TEMP(1)) == 0)
1838 break;
1839 if ((res = mp_int_sub_value(TEMP(0), 1, TEMP(0))) != MP_OK)
1840 goto CLEANUP;
1841 if (mp_int_compare_unsigned(TEMP(0), TEMP(1)) == 0)
1842 break;
1843
1844 if ((res = mp_int_copy(TEMP(1), TEMP(0))) != MP_OK)
1845 goto CLEANUP;
1846 }
1847
1848 res = mp_int_copy(TEMP(0), c);
1849
1850 CLEANUP:
1851 while (--last >= 0)
1852 mp_int_clear(TEMP(last));
1853
1854 return res;
1855 }
1856
1857 /* }}} */
1858
1859 /* {{{ mp_int_to_int(z, out) */
1860
1861 mp_result
mp_int_to_int(mp_int z,int * out)1862 mp_int_to_int(mp_int z, int *out)
1863 {
1864 unsigned int uv = 0;
1865 mp_size uz;
1866 mp_digit *dz;
1867 mp_sign sz;
1868
1869 CHECK(z != NULL);
1870
1871 /* Make sure the value is representable as an int */
1872 sz = MP_SIGN(z);
1873 if ((sz == MP_ZPOS && mp_int_compare_value(z, INT_MAX) > 0) ||
1874 mp_int_compare_value(z, INT_MIN) < 0)
1875 return MP_RANGE;
1876
1877 uz = MP_USED(z);
1878 dz = MP_DIGITS(z) + uz - 1;
1879
1880 while (uz > 0)
1881 {
1882 uv <<= MP_DIGIT_BIT / 2;
1883 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1884 --uz;
1885 }
1886
1887 if (out)
1888 *out = (sz == MP_NEG) ? -(int) uv : (int) uv;
1889
1890 return MP_OK;
1891 }
1892
1893 /* }}} */
1894
1895 /* {{{ mp_int_to_string(z, radix, str, limit) */
1896
1897 mp_result
mp_int_to_string(mp_int z,mp_size radix,char * str,int limit)1898 mp_int_to_string(mp_int z, mp_size radix,
1899 char *str, int limit)
1900 {
1901 mp_result res;
1902 int cmp = 0;
1903
1904 CHECK(z != NULL && str != NULL && limit >= 2);
1905
1906 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1907 return MP_RANGE;
1908
1909 if (CMPZ(z) == 0)
1910 {
1911 *str++ = s_val2ch(0, mp_flags & MP_CAP_DIGITS);
1912 }
1913 else
1914 {
1915 mpz_t tmp;
1916 char *h,
1917 *t;
1918
1919 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1920 return res;
1921
1922 if (MP_SIGN(z) == MP_NEG)
1923 {
1924 *str++ = '-';
1925 --limit;
1926 }
1927 h = str;
1928
1929 /* Generate digits in reverse order until finished or limit reached */
1930 for ( /* */ ; limit > 0; --limit)
1931 {
1932 mp_digit d;
1933
1934 if ((cmp = CMPZ(&tmp)) == 0)
1935 break;
1936
1937 d = s_ddiv(&tmp, (mp_digit) radix);
1938 *str++ = s_val2ch(d, mp_flags & MP_CAP_DIGITS);
1939 }
1940 t = str - 1;
1941
1942 /* Put digits back in correct output order */
1943 while (h < t)
1944 {
1945 char tc = *h;
1946
1947 *h++ = *t;
1948 *t-- = tc;
1949 }
1950
1951 mp_int_clear(&tmp);
1952 }
1953
1954 *str = '\0';
1955 if (cmp == 0)
1956 return MP_OK;
1957 else
1958 return MP_TRUNC;
1959 }
1960
1961 /* }}} */
1962
1963 /* {{{ mp_int_string_len(z, radix) */
1964
1965 mp_result
mp_int_string_len(mp_int z,mp_size radix)1966 mp_int_string_len(mp_int z, mp_size radix)
1967 {
1968 int len;
1969
1970 CHECK(z != NULL);
1971
1972 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1973 return MP_RANGE;
1974
1975 len = s_outlen(z, radix) + 1; /* for terminator */
1976
1977 /* Allow for sign marker on negatives */
1978 if (MP_SIGN(z) == MP_NEG)
1979 len += 1;
1980
1981 return len;
1982 }
1983
1984 /* }}} */
1985
1986 /* {{{ mp_int_read_string(z, radix, *str) */
1987
1988 /* Read zero-terminated string into z */
1989 mp_result
mp_int_read_string(mp_int z,mp_size radix,const char * str)1990 mp_int_read_string(mp_int z, mp_size radix, const char *str)
1991 {
1992 return mp_int_read_cstring(z, radix, str, NULL);
1993
1994 }
1995
1996 /* }}} */
1997
1998 /* {{{ mp_int_read_cstring(z, radix, *str, **end) */
1999
2000 mp_result
mp_int_read_cstring(mp_int z,mp_size radix,const char * str,char ** end)2001 mp_int_read_cstring(mp_int z, mp_size radix, const char *str, char **end)
2002 {
2003 int ch;
2004
2005 CHECK(z != NULL && str != NULL);
2006
2007 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
2008 return MP_RANGE;
2009
2010 /* Skip leading whitespace */
2011 while (isspace((unsigned char) *str))
2012 ++str;
2013
2014 /* Handle leading sign tag (+/-, positive default) */
2015 switch (*str)
2016 {
2017 case '-':
2018 MP_SIGN(z) = MP_NEG;
2019 ++str;
2020 break;
2021 case '+':
2022 ++str; /* fallthrough */
2023 default:
2024 MP_SIGN(z) = MP_ZPOS;
2025 break;
2026 }
2027
2028 /* Skip leading zeroes */
2029 while ((ch = s_ch2val(*str, radix)) == 0)
2030 ++str;
2031
2032 /* Make sure there is enough space for the value */
2033 if (!s_pad(z, s_inlen(strlen(str), radix)))
2034 return MP_MEMORY;
2035
2036 MP_USED(z) = 1;
2037 z->digits[0] = 0;
2038
2039 while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0))
2040 {
2041 s_dmul(z, (mp_digit) radix);
2042 s_dadd(z, (mp_digit) ch);
2043 ++str;
2044 }
2045
2046 CLAMP(z);
2047
2048 /* Override sign for zero, even if negative specified. */
2049 if (CMPZ(z) == 0)
2050 MP_SIGN(z) = MP_ZPOS;
2051
2052 if (end != NULL)
2053 *end = (char *) str;
2054
2055 /*
2056 * Return a truncation error if the string has unprocessed characters
2057 * remaining, so the caller can tell if the whole string was done
2058 */
2059 if (*str != '\0')
2060 return MP_TRUNC;
2061 else
2062 return MP_OK;
2063 }
2064
2065 /* }}} */
2066
2067 /* {{{ mp_int_count_bits(z) */
2068
2069 mp_result
mp_int_count_bits(mp_int z)2070 mp_int_count_bits(mp_int z)
2071 {
2072 mp_size nbits = 0,
2073 uz;
2074 mp_digit d;
2075
2076 CHECK(z != NULL);
2077
2078 uz = MP_USED(z);
2079 if (uz == 1 && z->digits[0] == 0)
2080 return 1;
2081
2082 --uz;
2083 nbits = uz * MP_DIGIT_BIT;
2084 d = z->digits[uz];
2085
2086 while (d != 0)
2087 {
2088 d >>= 1;
2089 ++nbits;
2090 }
2091
2092 return nbits;
2093 }
2094
2095 /* }}} */
2096
2097 /* {{{ mp_int_to_binary(z, buf, limit) */
2098
2099 mp_result
mp_int_to_binary(mp_int z,unsigned char * buf,int limit)2100 mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
2101 {
2102 static const int PAD_FOR_2C = 1;
2103
2104 mp_result res;
2105 int limpos = limit;
2106
2107 CHECK(z != NULL && buf != NULL);
2108
2109 res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
2110
2111 if (MP_SIGN(z) == MP_NEG)
2112 s_2comp(buf, limpos);
2113
2114 return res;
2115 }
2116
2117 /* }}} */
2118
2119 /* {{{ mp_int_read_binary(z, buf, len) */
2120
2121 mp_result
mp_int_read_binary(mp_int z,unsigned char * buf,int len)2122 mp_int_read_binary(mp_int z, unsigned char *buf, int len)
2123 {
2124 mp_size need,
2125 i;
2126 unsigned char *tmp;
2127 mp_digit *dz;
2128
2129 CHECK(z != NULL && buf != NULL && len > 0);
2130
2131 /* Figure out how many digits are needed to represent this value */
2132 need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2133 if (!s_pad(z, need))
2134 return MP_MEMORY;
2135
2136 mp_int_zero(z);
2137
2138 /*
2139 * If the high-order bit is set, take the 2's complement before reading
2140 * the value (it will be restored afterward)
2141 */
2142 if (buf[0] >> (CHAR_BIT - 1))
2143 {
2144 MP_SIGN(z) = MP_NEG;
2145 s_2comp(buf, len);
2146 }
2147
2148 dz = MP_DIGITS(z);
2149 for (tmp = buf, i = len; i > 0; --i, ++tmp)
2150 {
2151 s_qmul(z, (mp_size) CHAR_BIT);
2152 *dz |= *tmp;
2153 }
2154
2155 /* Restore 2's complement if we took it before */
2156 if (MP_SIGN(z) == MP_NEG)
2157 s_2comp(buf, len);
2158
2159 return MP_OK;
2160 }
2161
2162 /* }}} */
2163
2164 /* {{{ mp_int_binary_len(z) */
2165
2166 mp_result
mp_int_binary_len(mp_int z)2167 mp_int_binary_len(mp_int z)
2168 {
2169 mp_result res = mp_int_count_bits(z);
2170 int bytes = mp_int_unsigned_len(z);
2171
2172 if (res <= 0)
2173 return res;
2174
2175 bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2176
2177 /*
2178 * If the highest-order bit falls exactly on a byte boundary, we need to
2179 * pad with an extra byte so that the sign will be read correctly when
2180 * reading it back in.
2181 */
2182 if (bytes * CHAR_BIT == res)
2183 ++bytes;
2184
2185 return bytes;
2186 }
2187
2188 /* }}} */
2189
2190 /* {{{ mp_int_to_unsigned(z, buf, limit) */
2191
2192 mp_result
mp_int_to_unsigned(mp_int z,unsigned char * buf,int limit)2193 mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
2194 {
2195 static const int NO_PADDING = 0;
2196
2197 CHECK(z != NULL && buf != NULL);
2198
2199 return s_tobin(z, buf, &limit, NO_PADDING);
2200 }
2201
2202 /* }}} */
2203
2204 /* {{{ mp_int_read_unsigned(z, buf, len) */
2205
2206 mp_result
mp_int_read_unsigned(mp_int z,unsigned char * buf,int len)2207 mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
2208 {
2209 mp_size need,
2210 i;
2211 unsigned char *tmp;
2212 mp_digit *dz;
2213
2214 CHECK(z != NULL && buf != NULL && len > 0);
2215
2216 /* Figure out how many digits are needed to represent this value */
2217 need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2218 if (!s_pad(z, need))
2219 return MP_MEMORY;
2220
2221 mp_int_zero(z);
2222
2223 dz = MP_DIGITS(z);
2224 for (tmp = buf, i = len; i > 0; --i, ++tmp)
2225 {
2226 (void) s_qmul(z, CHAR_BIT);
2227 *dz |= *tmp;
2228 }
2229
2230 return MP_OK;
2231 }
2232
2233 /* }}} */
2234
2235 /* {{{ mp_int_unsigned_len(z) */
2236
2237 mp_result
mp_int_unsigned_len(mp_int z)2238 mp_int_unsigned_len(mp_int z)
2239 {
2240 mp_result res = mp_int_count_bits(z);
2241 int bytes;
2242
2243 if (res <= 0)
2244 return res;
2245
2246 bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2247
2248 return bytes;
2249 }
2250
2251 /* }}} */
2252
2253 /* {{{ mp_error_string(res) */
2254
2255 const char *
mp_error_string(mp_result res)2256 mp_error_string(mp_result res)
2257 {
2258 int ix;
2259
2260 if (res > 0)
2261 return s_unknown_err;
2262
2263 res = -res;
2264 for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
2265 ;
2266
2267 if (s_error_msg[ix] != NULL)
2268 return s_error_msg[ix];
2269 else
2270 return s_unknown_err;
2271 }
2272
2273 /* }}} */
2274
2275 /*------------------------------------------------------------------------*/
2276 /* Private functions for internal use. These make assumptions. */
2277
2278 /* {{{ s_alloc(num) */
2279
2280 static mp_digit *
s_alloc(mp_size num)2281 s_alloc(mp_size num)
2282 {
2283 mp_digit *out = px_alloc(num * sizeof(mp_digit));
2284
2285 assert(out != NULL); /* for debugging */
2286
2287 return out;
2288 }
2289
2290 /* }}} */
2291
2292 /* {{{ s_realloc(old, num) */
2293
2294 static mp_digit *
s_realloc(mp_digit * old,mp_size num)2295 s_realloc(mp_digit *old, mp_size num)
2296 {
2297 mp_digit *new = px_realloc(old, num * sizeof(mp_digit));
2298
2299 assert(new != NULL); /* for debugging */
2300
2301 return new;
2302 }
2303
2304 /* }}} */
2305
2306 /* {{{ s_free(ptr) */
2307
#if TRACEABLE_FREE
/*
 * Release a buffer through px_free().  This out-of-line definition is only
 * compiled when TRACEABLE_FREE is nonzero; with TRACEABLE_FREE defined as 0
 * at the top of this file it is compiled out.
 */
static void
s_free(void *ptr)
{
	px_free(ptr);
}
#endif
2315
2316 /* }}} */
2317
2318 /* {{{ s_pad(z, min) */
2319
2320 static int
s_pad(mp_int z,mp_size min)2321 s_pad(mp_int z, mp_size min)
2322 {
2323 if (MP_ALLOC(z) < min)
2324 {
2325 mp_size nsize = ROUND_PREC(min);
2326 mp_digit *tmp = s_realloc(MP_DIGITS(z), nsize);
2327
2328 if (tmp == NULL)
2329 return 0;
2330
2331 MP_DIGITS(z) = tmp;
2332 MP_ALLOC(z) = nsize;
2333 }
2334
2335 return 1;
2336 }
2337
2338 /* }}} */
2339
2340 /* {{{ s_clamp(z) */
2341
#if TRACEABLE_CLAMP
/*
 * Drop high-order zero digits from z by lowering its used count; at least
 * one digit is always kept.  Only compiled when TRACEABLE_CLAMP is nonzero.
 */
static void
s_clamp(mp_int z)
{
	mp_size		used = MP_USED(z);
	mp_digit   *digits = MP_DIGITS(z);

	while (used > 1 && digits[used - 1] == 0)
		--used;

	MP_USED(z) = used;
}
#endif
2355
2356 /* }}} */
2357
2358 /* {{{ s_fake(z, value, vbuf) */
2359
2360 static void
s_fake(mp_int z,int value,mp_digit vbuf[])2361 s_fake(mp_int z, int value, mp_digit vbuf[])
2362 {
2363 mp_size uv = (mp_size) s_vpack(value, vbuf);
2364
2365 z->used = uv;
2366 z->alloc = MP_VALUE_DIGITS(value);
2367 z->sign = (value < 0) ? MP_NEG : MP_ZPOS;
2368 z->digits = vbuf;
2369 }
2370
2371 /* }}} */
2372
2373 /* {{{ s_cdig(da, db, len) */
2374
2375 static int
s_cdig(mp_digit * da,mp_digit * db,mp_size len)2376 s_cdig(mp_digit *da, mp_digit *db, mp_size len)
2377 {
2378 mp_digit *dat = da + len - 1,
2379 *dbt = db + len - 1;
2380
2381 for ( /* */ ; len != 0; --len, --dat, --dbt)
2382 {
2383 if (*dat > *dbt)
2384 return 1;
2385 else if (*dat < *dbt)
2386 return -1;
2387 }
2388
2389 return 0;
2390 }
2391
2392 /* }}} */
2393
2394 /* {{{ s_vpack(v, t[]) */
2395
2396 static int
s_vpack(int v,mp_digit t[])2397 s_vpack(int v, mp_digit t[])
2398 {
2399 unsigned int uv = (unsigned int) ((v < 0) ? -v : v);
2400 int ndig = 0;
2401
2402 if (uv == 0)
2403 t[ndig++] = 0;
2404 else
2405 {
2406 while (uv != 0)
2407 {
2408 t[ndig++] = (mp_digit) uv;
2409 uv >>= MP_DIGIT_BIT / 2;
2410 uv >>= MP_DIGIT_BIT / 2;
2411 }
2412 }
2413
2414 return ndig;
2415 }
2416
2417 /* }}} */
2418
2419 /* {{{ s_ucmp(a, b) */
2420
2421 static int
s_ucmp(mp_int a,mp_int b)2422 s_ucmp(mp_int a, mp_int b)
2423 {
2424 mp_size ua = MP_USED(a),
2425 ub = MP_USED(b);
2426
2427 if (ua > ub)
2428 return 1;
2429 else if (ub > ua)
2430 return -1;
2431 else
2432 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2433 }
2434
2435 /* }}} */
2436
2437 /* {{{ s_vcmp(a, v) */
2438
2439 static int
s_vcmp(mp_int a,int v)2440 s_vcmp(mp_int a, int v)
2441 {
2442 mp_digit vdig[MP_VALUE_DIGITS(v)];
2443 int ndig = 0;
2444 mp_size ua = MP_USED(a);
2445
2446 ndig = s_vpack(v, vdig);
2447
2448 if (ua > ndig)
2449 return 1;
2450 else if (ua < ndig)
2451 return -1;
2452 else
2453 return s_cdig(MP_DIGITS(a), vdig, ndig);
2454 }
2455
2456 /* }}} */
2457
2458 /* {{{ s_uadd(da, db, dc, size_a, size_b) */
2459
2460 static mp_digit
s_uadd(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2461 s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
2462 mp_size size_a, mp_size size_b)
2463 {
2464 mp_size pos;
2465 mp_word w = 0;
2466
2467 /* Insure that da is the longer of the two to simplify later code */
2468 if (size_b > size_a)
2469 {
2470 SWAP(mp_digit *, da, db);
2471 SWAP(mp_size, size_a, size_b);
2472 }
2473
2474 /* Add corresponding digits until the shorter number runs out */
2475 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2476 {
2477 w = w + (mp_word) *da + (mp_word) *db;
2478 *dc = LOWER_HALF(w);
2479 w = UPPER_HALF(w);
2480 }
2481
2482 /* Propagate carries as far as necessary */
2483 for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2484 {
2485 w = w + *da;
2486
2487 *dc = LOWER_HALF(w);
2488 w = UPPER_HALF(w);
2489 }
2490
2491 /* Return carry out */
2492 return (mp_digit) w;
2493 }
2494
2495 /* }}} */
2496
2497 /* {{{ s_usub(da, db, dc, size_a, size_b) */
2498
2499 static void
s_usub(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2500 s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
2501 mp_size size_a, mp_size size_b)
2502 {
2503 mp_size pos;
2504 mp_word w = 0;
2505
2506 /* We assume that |a| >= |b| so this should definitely hold */
2507 assert(size_a >= size_b);
2508
2509 /* Subtract corresponding digits and propagate borrow */
2510 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2511 {
2512 w = ((mp_word) MP_DIGIT_MAX + 1 + /* MP_RADIX */
2513 (mp_word) *da) - w - (mp_word) *db;
2514
2515 *dc = LOWER_HALF(w);
2516 w = (UPPER_HALF(w) == 0);
2517 }
2518
2519 /* Finish the subtraction for remaining upper digits of da */
2520 for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2521 {
2522 w = ((mp_word) MP_DIGIT_MAX + 1 + /* MP_RADIX */
2523 (mp_word) *da) - w;
2524
2525 *dc = LOWER_HALF(w);
2526 w = (UPPER_HALF(w) == 0);
2527 }
2528
2529 /* If there is a borrow out at the end, it violates the precondition */
2530 assert(w == 0);
2531 }
2532
2533 /* }}} */
2534
2535 /* {{{ s_kmul(da, db, dc, size_a, size_b) */
2536
/*
 * Multiply the digit arrays da (size_a digits) and db (size_b digits) into
 * dc, using Karatsuba recursion for large operands and falling back to
 * s_umul otherwise.
 *
 * Returns 1 on success, or 0 if the temporary buffer for the recursive case
 * cannot be allocated.
 *
 * NOTE(review): s_umul accumulates into the existing contents of dc, so
 * callers are presumably expected to pass a zeroed dc with room for
 * size_a + size_b digits -- confirm against the callers.
 */
static int
s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
	   mp_size size_a, mp_size size_b)
{
	mp_size		bot_size;

	/* Make sure b is the smaller of the two input values */
	if (size_b > size_a)
	{
		SWAP(mp_digit *, da, db);
		SWAP(mp_size, size_a, size_b);
	}

	/*
	 * Insure that the bottom is the larger half in an odd-length split; the
	 * code below relies on this being true.
	 */
	bot_size = (size_a + 1) / 2;

	/*
	 * If the values are big enough to bother with recursion, use the
	 * Karatsuba algorithm to compute the product; otherwise use the normal
	 * multiplication algorithm
	 *
	 * (A multiply_threshold of zero disables the recursive path entirely,
	 * and b must extend past the split point for its top half to exist.)
	 */
	if (multiply_threshold &&
		size_a >= multiply_threshold &&
		size_b > bot_size)
	{

		mp_digit   *t1,
				   *t2,
				   *t3,
					carry;

		/* Top halves start bot_size digits into each operand */
		mp_digit   *a_top = da + bot_size;
		mp_digit   *b_top = db + bot_size;

		mp_size		at_size = size_a - bot_size;
		mp_size		bt_size = size_b - bot_size;
		mp_size		buf_size = 2 * bot_size;

		/*
		 * Do a single allocation for all three temporary buffers needed; each
		 * buffer must be big enough to hold the product of two bottom halves,
		 * and one buffer needs space for the completed product; twice the
		 * space is plenty.
		 */
		if ((t1 = s_alloc(4 * buf_size)) == NULL)
			return 0;
		t2 = t1 + buf_size;
		t3 = t2 + buf_size;
		ZERO(t1, 4 * buf_size);

		/*
		 * t1 and t2 are initially used as temporaries to compute the inner
		 * product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
		 */
		carry = s_uadd(da, a_top, t1, bot_size, at_size);	/* t1 = a1 + a0 */
		t1[bot_size] = carry;

		carry = s_uadd(db, b_top, t2, bot_size, bt_size);	/* t2 = b1 + b0 */
		t2[bot_size] = carry;

		/* Each sum may be one digit longer than a half, hence bot_size + 1 */
		(void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1);	/* t3 = t1 * t2 */

		/*
		 * Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so
		 * that we're left with only the pieces we want: t3 = a1b0 + a0b1
		 */
		ZERO(t1, buf_size);
		ZERO(t2, buf_size);
		(void) s_kmul(da, db, t1, bot_size, bot_size);	/* t1 = a0 * b0 */
		(void) s_kmul(a_top, b_top, t2, at_size, bt_size);	/* t2 = a1 * b1 */

		/* Subtract out t1 and t2 to get the inner product */
		s_usub(t3, t1, t3, buf_size + 2, buf_size);
		s_usub(t3, t2, t3, buf_size + 2, buf_size);

		/*
		 * Assemble the output value: a0b0 in the low digits, the inner
		 * product added in bot_size digits up, and a1b1 added in
		 * 2 * bot_size digits up.
		 */
		COPY(t1, dc, buf_size);
		carry = s_uadd(t3, dc + bot_size, dc + bot_size,
					   buf_size + 1, buf_size);
		assert(carry == 0);

		carry = s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size,
					   buf_size, buf_size);
		assert(carry == 0);

		s_free(t1);				/* note t2 and t3 are just internal pointers
								 * to t1 */
	}
	else
	{
		s_umul(da, db, dc, size_a, size_b);
	}

	return 1;
}
2635
2636 /* }}} */
2637
2638 /* {{{ s_umul(da, db, dc, size_a, size_b) */
2639
2640 static void
s_umul(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2641 s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
2642 mp_size size_a, mp_size size_b)
2643 {
2644 mp_size a,
2645 b;
2646 mp_word w;
2647
2648 for (a = 0; a < size_a; ++a, ++dc, ++da)
2649 {
2650 mp_digit *dct = dc;
2651 mp_digit *dbt = db;
2652
2653 if (*da == 0)
2654 continue;
2655
2656 w = 0;
2657 for (b = 0; b < size_b; ++b, ++dbt, ++dct)
2658 {
2659 w = (mp_word) *da * (mp_word) *dbt + w + (mp_word) *dct;
2660
2661 *dct = LOWER_HALF(w);
2662 w = UPPER_HALF(w);
2663 }
2664
2665 *dct = (mp_digit) w;
2666 }
2667 }
2668
2669 /* }}} */
2670
2671 /* {{{ s_ksqr(da, dc, size_a) */
2672
/*
 * Square the digit array da (size_a digits) into dc, splitting recursively
 * for large operands and falling back to s_usqr otherwise.
 *
 * With a = a1 * R + a0 (R being the radix raised to bot_size), the square is
 * assembled from a0^2, a1^2 and 2 * a0 * a1.
 *
 * Returns 1 on success, or 0 if the temporary buffer for the recursive case
 * cannot be allocated.
 */
static int
s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a)
{
	/* A multiply_threshold of zero disables the recursive path entirely */
	if (multiply_threshold && size_a > multiply_threshold)
	{
		mp_size		bot_size = (size_a + 1) / 2;
		mp_digit   *a_top = da + bot_size;
		mp_digit   *t1,
				   *t2,
				   *t3;
		mp_size		at_size = size_a - bot_size;
		mp_size		buf_size = 2 * bot_size;

		/* One allocation serves all three temporaries (t1, t2, t3) */
		if ((t1 = s_alloc(4 * buf_size)) == NULL)
			return 0;
		t2 = t1 + buf_size;
		t3 = t2 + buf_size;
		ZERO(t1, 4 * buf_size);

		(void) s_ksqr(da, t1, bot_size);	/* t1 = a0 ^ 2 */
		(void) s_ksqr(a_top, t2, at_size);	/* t2 = a1 ^ 2 */

		(void) s_kmul(da, a_top, t3, bot_size, at_size);	/* t3 = a0 * a1 */

		/* Quick multiply t3 by 2, shifting left (can't overflow) */
		{
			int			i,
						top = bot_size + at_size;
			mp_word		w,
						save = 0;

			/* 'save' carries the bit shifted out of each digit into the next */
			for (i = 0; i < top; ++i)
			{
				w = t3[i];
				w = (w << 1) | save;
				t3[i] = LOWER_HALF(w);
				save = UPPER_HALF(w);
			}
			t3[i] = LOWER_HALF(save);
		}

		/*
		 * Assemble the output value: a0^2 in the low digits, 2 * a0 * a1
		 * added in bot_size digits up, and a1^2 added in 2 * bot_size digits
		 * up.
		 */
		COPY(t1, dc, 2 * bot_size);
		(void) s_uadd(t3, dc + bot_size, dc + bot_size,
					  buf_size + 1, buf_size + 1);

		(void) s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size,
					  buf_size, buf_size);

		px_free(t1);			/* note that t2 and t3 are internal pointers
								 * only */

	}
	else
	{
		s_usqr(da, dc, size_a);
	}

	return 1;
}
2733
2734 /* }}} */
2735
2736 /* {{{ s_usqr(da, dc, size_a) */
2737
2738 static void
s_usqr(mp_digit * da,mp_digit * dc,mp_size size_a)2739 s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2740 {
2741 mp_size i,
2742 j;
2743 mp_word w;
2744
2745 for (i = 0; i < size_a; ++i, dc += 2, ++da)
2746 {
2747 mp_digit *dct = dc,
2748 *dat = da;
2749
2750 if (*da == 0)
2751 continue;
2752
2753 /* Take care of the first digit, no rollover */
2754 w = (mp_word) *dat * (mp_word) *dat + (mp_word) *dct;
2755 *dct = LOWER_HALF(w);
2756 w = UPPER_HALF(w);
2757 ++dat;
2758 ++dct;
2759
2760 for (j = i + 1; j < size_a; ++j, ++dat, ++dct)
2761 {
2762 mp_word t = (mp_word) *da * (mp_word) *dat;
2763 mp_word u = w + (mp_word) *dct,
2764 ov = 0;
2765
2766 /* Check if doubling t will overflow a word */
2767 if (HIGH_BIT_SET(t))
2768 ov = 1;
2769
2770 w = t + t;
2771
2772 /* Check if adding u to w will overflow a word */
2773 if (ADD_WILL_OVERFLOW(w, u))
2774 ov = 1;
2775
2776 w += u;
2777
2778 *dct = LOWER_HALF(w);
2779 w = UPPER_HALF(w);
2780 if (ov)
2781 {
2782 w += MP_DIGIT_MAX; /* MP_RADIX */
2783 ++w;
2784 }
2785 }
2786
2787 w = w + *dct;
2788 *dct = (mp_digit) w;
2789 while ((w = UPPER_HALF(w)) != 0)
2790 {
2791 ++dct;
2792 w = w + *dct;
2793 *dct = LOWER_HALF(w);
2794 }
2795
2796 assert(w == 0);
2797 }
2798 }
2799
2800 /* }}} */
2801
2802 /* {{{ s_dadd(a, b) */
2803
2804 static void
s_dadd(mp_int a,mp_digit b)2805 s_dadd(mp_int a, mp_digit b)
2806 {
2807 mp_word w = 0;
2808 mp_digit *da = MP_DIGITS(a);
2809 mp_size ua = MP_USED(a);
2810
2811 w = (mp_word) *da + b;
2812 *da++ = LOWER_HALF(w);
2813 w = UPPER_HALF(w);
2814
2815 for (ua -= 1; ua > 0; --ua, ++da)
2816 {
2817 w = (mp_word) *da + w;
2818
2819 *da = LOWER_HALF(w);
2820 w = UPPER_HALF(w);
2821 }
2822
2823 if (w)
2824 {
2825 *da = (mp_digit) w;
2826 MP_USED(a) += 1;
2827 }
2828 }
2829
2830 /* }}} */
2831
2832 /* {{{ s_dmul(a, b) */
2833
2834 static void
s_dmul(mp_int a,mp_digit b)2835 s_dmul(mp_int a, mp_digit b)
2836 {
2837 mp_word w = 0;
2838 mp_digit *da = MP_DIGITS(a);
2839 mp_size ua = MP_USED(a);
2840
2841 while (ua > 0)
2842 {
2843 w = (mp_word) *da * b + w;
2844 *da++ = LOWER_HALF(w);
2845 w = UPPER_HALF(w);
2846 --ua;
2847 }
2848
2849 if (w)
2850 {
2851 *da = (mp_digit) w;
2852 MP_USED(a) += 1;
2853 }
2854 }
2855
2856 /* }}} */
2857
2858 /* {{{ s_dbmul(da, b, dc, size_a) */
2859
2860 static void
s_dbmul(mp_digit * da,mp_digit b,mp_digit * dc,mp_size size_a)2861 s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a)
2862 {
2863 mp_word w = 0;
2864
2865 while (size_a > 0)
2866 {
2867 w = (mp_word) *da++ * (mp_word) b + w;
2868
2869 *dc++ = LOWER_HALF(w);
2870 w = UPPER_HALF(w);
2871 --size_a;
2872 }
2873
2874 if (w)
2875 *dc = LOWER_HALF(w);
2876 }
2877
2878 /* }}} */
2879
/* {{{ s_ddiv(a, b) */
2881
2882 static mp_digit
s_ddiv(mp_int a,mp_digit b)2883 s_ddiv(mp_int a, mp_digit b)
2884 {
2885 mp_word w = 0,
2886 qdigit;
2887 mp_size ua = MP_USED(a);
2888 mp_digit *da = MP_DIGITS(a) + ua - 1;
2889
2890 for ( /* */ ; ua > 0; --ua, --da)
2891 {
2892 w = (w << MP_DIGIT_BIT) | *da;
2893
2894 if (w >= b)
2895 {
2896 qdigit = w / b;
2897 w = w % b;
2898 }
2899 else
2900 {
2901 qdigit = 0;
2902 }
2903
2904 *da = (mp_digit) qdigit;
2905 }
2906
2907 CLAMP(a);
2908 return (mp_digit) w;
2909 }
2910
2911 /* }}} */
2912
2913 /* {{{ s_qdiv(z, p2) */
2914
2915 static void
s_qdiv(mp_int z,mp_size p2)2916 s_qdiv(mp_int z, mp_size p2)
2917 {
2918 mp_size ndig = p2 / MP_DIGIT_BIT,
2919 nbits = p2 % MP_DIGIT_BIT;
2920 mp_size uz = MP_USED(z);
2921
2922 if (ndig)
2923 {
2924 mp_size mark;
2925 mp_digit *to,
2926 *from;
2927
2928 if (ndig >= uz)
2929 {
2930 mp_int_zero(z);
2931 return;
2932 }
2933
2934 to = MP_DIGITS(z);
2935 from = to + ndig;
2936
2937 for (mark = ndig; mark < uz; ++mark)
2938 *to++ = *from++;
2939
2940 MP_USED(z) = uz - ndig;
2941 }
2942
2943 if (nbits)
2944 {
2945 mp_digit d = 0,
2946 *dz,
2947 save;
2948 mp_size up = MP_DIGIT_BIT - nbits;
2949
2950 uz = MP_USED(z);
2951 dz = MP_DIGITS(z) + uz - 1;
2952
2953 for ( /* */ ; uz > 0; --uz, --dz)
2954 {
2955 save = *dz;
2956
2957 *dz = (*dz >> nbits) | (d << up);
2958 d = save;
2959 }
2960
2961 CLAMP(z);
2962 }
2963
2964 if (MP_USED(z) == 1 && z->digits[0] == 0)
2965 MP_SIGN(z) = MP_ZPOS;
2966 }
2967
2968 /* }}} */
2969
2970 /* {{{ s_qmod(z, p2) */
2971
2972 static void
s_qmod(mp_int z,mp_size p2)2973 s_qmod(mp_int z, mp_size p2)
2974 {
2975 mp_size start = p2 / MP_DIGIT_BIT + 1,
2976 rest = p2 % MP_DIGIT_BIT;
2977 mp_size uz = MP_USED(z);
2978 mp_digit mask = (1 << rest) - 1;
2979
2980 if (start <= uz)
2981 {
2982 MP_USED(z) = start;
2983 z->digits[start - 1] &= mask;
2984 CLAMP(z);
2985 }
2986 }
2987
2988 /* }}} */
2989
2990 /* {{{ s_qmul(z, p2) */
2991
2992 static int
s_qmul(mp_int z,mp_size p2)2993 s_qmul(mp_int z, mp_size p2)
2994 {
2995 mp_size uz,
2996 need,
2997 rest,
2998 extra,
2999 i;
3000 mp_digit *from,
3001 *to,
3002 d;
3003
3004 if (p2 == 0)
3005 return 1;
3006
3007 uz = MP_USED(z);
3008 need = p2 / MP_DIGIT_BIT;
3009 rest = p2 % MP_DIGIT_BIT;
3010
3011 /*
3012 * Figure out if we need an extra digit at the top end; this occurs if the
3013 * topmost `rest' bits of the high-order digit of z are not zero, meaning
3014 * they will be shifted off the end if not preserved
3015 */
3016 extra = 0;
3017 if (rest != 0)
3018 {
3019 mp_digit *dz = MP_DIGITS(z) + uz - 1;
3020
3021 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
3022 extra = 1;
3023 }
3024
3025 if (!s_pad(z, uz + need + extra))
3026 return 0;
3027
3028 /*
3029 * If we need to shift by whole digits, do that in one pass, then to back
3030 * and shift by partial digits.
3031 */
3032 if (need > 0)
3033 {
3034 from = MP_DIGITS(z) + uz - 1;
3035 to = from + need;
3036
3037 for (i = 0; i < uz; ++i)
3038 *to-- = *from--;
3039
3040 ZERO(MP_DIGITS(z), need);
3041 uz += need;
3042 }
3043
3044 if (rest)
3045 {
3046 d = 0;
3047 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from)
3048 {
3049 mp_digit save = *from;
3050
3051 *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
3052 d = save;
3053 }
3054
3055 d >>= (MP_DIGIT_BIT - rest);
3056 if (d != 0)
3057 {
3058 *from = d;
3059 uz += extra;
3060 }
3061 }
3062
3063 MP_USED(z) = uz;
3064 CLAMP(z);
3065
3066 return 1;
3067 }
3068
3069 /* }}} */
3070
3071 /* {{{ s_qsub(z, p2) */
3072
3073 /* Subtract |z| from 2^p2, assuming 2^p2 > |z|, and set z to be positive */
3074 static int
s_qsub(mp_int z,mp_size p2)3075 s_qsub(mp_int z, mp_size p2)
3076 {
3077 mp_digit hi = (1 << (p2 % MP_DIGIT_BIT)),
3078 *zp;
3079 mp_size tdig = (p2 / MP_DIGIT_BIT),
3080 pos;
3081 mp_word w = 0;
3082
3083 if (!s_pad(z, tdig + 1))
3084 return 0;
3085
3086 for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp)
3087 {
3088 w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word) *zp;
3089
3090 *zp = LOWER_HALF(w);
3091 w = UPPER_HALF(w) ? 0 : 1;
3092 }
3093
3094 w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word) *zp;
3095 *zp = LOWER_HALF(w);
3096
3097 assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
3098
3099 MP_SIGN(z) = MP_ZPOS;
3100 CLAMP(z);
3101
3102 return 1;
3103 }
3104
3105 /* }}} */
3106
3107 /* {{{ s_dp2k(z) */
3108
3109 static int
s_dp2k(mp_int z)3110 s_dp2k(mp_int z)
3111 {
3112 int k = 0;
3113 mp_digit *dp = MP_DIGITS(z),
3114 d;
3115
3116 if (MP_USED(z) == 1 && *dp == 0)
3117 return 1;
3118
3119 while (*dp == 0)
3120 {
3121 k += MP_DIGIT_BIT;
3122 ++dp;
3123 }
3124
3125 d = *dp;
3126 while ((d & 1) == 0)
3127 {
3128 d >>= 1;
3129 ++k;
3130 }
3131
3132 return k;
3133 }
3134
3135 /* }}} */
3136
3137 /* {{{ s_isp2(z) */
3138
3139 static int
s_isp2(mp_int z)3140 s_isp2(mp_int z)
3141 {
3142 mp_size uz = MP_USED(z),
3143 k = 0;
3144 mp_digit *dz = MP_DIGITS(z),
3145 d;
3146
3147 while (uz > 1)
3148 {
3149 if (*dz++ != 0)
3150 return -1;
3151 k += MP_DIGIT_BIT;
3152 --uz;
3153 }
3154
3155 d = *dz;
3156 while (d > 1)
3157 {
3158 if (d & 1)
3159 return -1;
3160 ++k;
3161 d >>= 1;
3162 }
3163
3164 return (int) k;
3165 }
3166
3167 /* }}} */
3168
3169 /* {{{ s_2expt(z, k) */
3170
3171 static int
s_2expt(mp_int z,int k)3172 s_2expt(mp_int z, int k)
3173 {
3174 mp_size ndig,
3175 rest;
3176 mp_digit *dz;
3177
3178 ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
3179 rest = k % MP_DIGIT_BIT;
3180
3181 if (!s_pad(z, ndig))
3182 return 0;
3183
3184 dz = MP_DIGITS(z);
3185 ZERO(dz, ndig);
3186 *(dz + ndig - 1) = (1 << rest);
3187 MP_USED(z) = ndig;
3188
3189 return 1;
3190 }
3191
3192 /* }}} */
3193
3194 /* {{{ s_norm(a, b) */
3195
3196 static int
s_norm(mp_int a,mp_int b)3197 s_norm(mp_int a, mp_int b)
3198 {
3199 mp_digit d = b->digits[MP_USED(b) - 1];
3200 int k = 0;
3201
3202 while (d < (mp_digit) ((mp_digit) 1 << (MP_DIGIT_BIT - 1)))
3203 { /* d < (MP_RADIX / 2) */
3204 d <<= 1;
3205 ++k;
3206 }
3207
3208 /* These multiplications can't fail */
3209 if (k != 0)
3210 {
3211 (void) s_qmul(a, (mp_size) k);
3212 (void) s_qmul(b, (mp_size) k);
3213 }
3214
3215 return k;
3216 }
3217
3218 /* }}} */
3219
3220 /* {{{ s_brmu(z, m) */
3221
3222 static mp_result
s_brmu(mp_int z,mp_int m)3223 s_brmu(mp_int z, mp_int m)
3224 {
3225 mp_size um = MP_USED(m) * 2;
3226
3227 if (!s_pad(z, um))
3228 return MP_MEMORY;
3229
3230 s_2expt(z, MP_DIGIT_BIT * um);
3231 return mp_int_div(z, m, z, NULL);
3232 }
3233
3234 /* }}} */
3235
3236 /* {{{ s_reduce(x, m, mu, q1, q2) */
3237
3238 static int
s_reduce(mp_int x,mp_int m,mp_int mu,mp_int q1,mp_int q2)3239 s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
3240 {
3241 mp_size um = MP_USED(m),
3242 umb_p1,
3243 umb_m1;
3244
3245 umb_p1 = (um + 1) * MP_DIGIT_BIT;
3246 umb_m1 = (um - 1) * MP_DIGIT_BIT;
3247
3248 if (mp_int_copy(x, q1) != MP_OK)
3249 return 0;
3250
3251 /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
3252 s_qdiv(q1, umb_m1);
3253 UMUL(q1, mu, q2);
3254 s_qdiv(q2, umb_p1);
3255
3256 /* Set x = x mod b^(k+1) */
3257 s_qmod(x, umb_p1);
3258
3259 /*
3260 * Now, q is a guess for the quotient a / m. Compute x - q * m mod
3261 * b^(k+1), replacing x. This may be off by a factor of 2m, but no more
3262 * than that.
3263 */
3264 UMUL(q2, m, q1);
3265 s_qmod(q1, umb_p1);
3266 (void) mp_int_sub(x, q1, x); /* can't fail */
3267
3268 /*
3269 * The result may be < 0; if it is, add b^(k+1) to pin it in the proper
3270 * range.
3271 */
3272 if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
3273 return 0;
3274
3275 /*
3276 * If x > m, we need to back it off until it is in range. This will be
3277 * required at most twice.
3278 */
3279 if (mp_int_compare(x, m) >= 0)
3280 (void) mp_int_sub(x, m, x);
3281 if (mp_int_compare(x, m) >= 0)
3282 (void) mp_int_sub(x, m, x);
3283
3284 /* At this point, x has been properly reduced. */
3285 return 1;
3286 }
3287
3288 /* }}} */
3289
3290 /* {{{ s_embar(a, b, m, mu, c) */
3291
3292 /* Perform modular exponentiation using Barrett's method, where mu is
3293 the reduction constant for m. Assumes a < m, b > 0. */
3294 static mp_result
s_embar(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)3295 s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
3296 {
3297 mp_digit *db,
3298 *dbt,
3299 umu,
3300 d;
3301 mpz_t temp[3];
3302 mp_result res;
3303 int last = 0;
3304
3305 umu = MP_USED(mu);
3306 db = MP_DIGITS(b);
3307 dbt = db + MP_USED(b) - 1;
3308
3309 while (last < 3)
3310 {
3311 SETUP(mp_int_init_size(TEMP(last), 4 * umu), last);
3312 ZERO(MP_DIGITS(TEMP(last - 1)), MP_ALLOC(TEMP(last - 1)));
3313 }
3314
3315 (void) mp_int_set_value(c, 1);
3316
3317 /* Take care of low-order digits */
3318 while (db < dbt)
3319 {
3320 int i;
3321
3322 for (d = *db, i = MP_DIGIT_BIT; i > 0; --i, d >>= 1)
3323 {
3324 if (d & 1)
3325 {
3326 /* The use of a second temporary avoids allocation */
3327 UMUL(c, a, TEMP(0));
3328 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3329 {
3330 res = MP_MEMORY;
3331 goto CLEANUP;
3332 }
3333 mp_int_copy(TEMP(0), c);
3334 }
3335
3336
3337 USQR(a, TEMP(0));
3338 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3339 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3340 {
3341 res = MP_MEMORY;
3342 goto CLEANUP;
3343 }
3344 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3345 mp_int_copy(TEMP(0), a);
3346
3347
3348 }
3349
3350 ++db;
3351 }
3352
3353 /* Take care of highest-order digit */
3354 d = *dbt;
3355 for (;;)
3356 {
3357 if (d & 1)
3358 {
3359 UMUL(c, a, TEMP(0));
3360 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3361 {
3362 res = MP_MEMORY;
3363 goto CLEANUP;
3364 }
3365 mp_int_copy(TEMP(0), c);
3366 }
3367
3368 d >>= 1;
3369 if (!d)
3370 break;
3371
3372 USQR(a, TEMP(0));
3373 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3374 {
3375 res = MP_MEMORY;
3376 goto CLEANUP;
3377 }
3378 (void) mp_int_copy(TEMP(0), a);
3379 }
3380
3381 CLEANUP:
3382 while (--last >= 0)
3383 mp_int_clear(TEMP(last));
3384
3385 return res;
3386 }
3387
3388 /* }}} */
3389
3390 /* {{{ s_udiv(a, b) */
3391
3392 /* Precondition: a >= b and b > 0
3393 Postcondition: a' = a / b, b' = a % b
3394 */
3395 static mp_result
s_udiv(mp_int a,mp_int b)3396 s_udiv(mp_int a, mp_int b)
3397 {
3398 mpz_t q,
3399 r,
3400 t;
3401 mp_size ua,
3402 ub,
3403 qpos = 0;
3404 mp_digit *da,
3405 btop;
3406 mp_result res = MP_OK;
3407 int k,
3408 skip = 0;
3409
3410 /* Force signs to positive */
3411 MP_SIGN(a) = MP_ZPOS;
3412 MP_SIGN(b) = MP_ZPOS;
3413
3414 /* Normalize, per Knuth */
3415 k = s_norm(a, b);
3416
3417 ua = MP_USED(a);
3418 ub = MP_USED(b);
3419 btop = b->digits[ub - 1];
3420 if ((res = mp_int_init_size(&q, ua)) != MP_OK)
3421 return res;
3422 if ((res = mp_int_init_size(&t, ua + 1)) != MP_OK)
3423 goto CLEANUP;
3424
3425 da = MP_DIGITS(a);
3426 r.digits = da + ua - 1; /* The contents of r are shared with a */
3427 r.used = 1;
3428 r.sign = MP_ZPOS;
3429 r.alloc = MP_ALLOC(a);
3430 ZERO(t.digits, t.alloc);
3431
3432 /* Solve for quotient digits, store in q.digits in reverse order */
3433 while (r.digits >= da)
3434 {
3435 assert(qpos <= q.alloc);
3436
3437 if (s_ucmp(b, &r) > 0)
3438 {
3439 r.digits -= 1;
3440 r.used += 1;
3441
3442 if (++skip > 1)
3443 q.digits[qpos++] = 0;
3444
3445 CLAMP(&r);
3446 }
3447 else
3448 {
3449 mp_word pfx = r.digits[r.used - 1];
3450 mp_word qdigit;
3451
3452 if (r.used > 1 && (pfx < btop || r.digits[r.used - 2] == 0))
3453 {
3454 pfx <<= MP_DIGIT_BIT / 2;
3455 pfx <<= MP_DIGIT_BIT / 2;
3456 pfx |= r.digits[r.used - 2];
3457 }
3458
3459 qdigit = pfx / btop;
3460 if (qdigit > MP_DIGIT_MAX)
3461 qdigit = 1;
3462
3463 s_dbmul(MP_DIGITS(b), (mp_digit) qdigit, t.digits, ub);
3464 t.used = ub + 1;
3465 CLAMP(&t);
3466 while (s_ucmp(&t, &r) > 0)
3467 {
3468 --qdigit;
3469 (void) mp_int_sub(&t, b, &t); /* cannot fail */
3470 }
3471
3472 s_usub(r.digits, t.digits, r.digits, r.used, t.used);
3473 CLAMP(&r);
3474
3475 q.digits[qpos++] = (mp_digit) qdigit;
3476 ZERO(t.digits, t.used);
3477 skip = 0;
3478 }
3479 }
3480
3481 /* Put quotient digits in the correct order, and discard extra zeroes */
3482 q.used = qpos;
3483 REV(mp_digit, q.digits, qpos);
3484 CLAMP(&q);
3485
3486 /* Denormalize the remainder */
3487 CLAMP(a);
3488 if (k != 0)
3489 s_qdiv(a, k);
3490
3491 mp_int_copy(a, b); /* ok: 0 <= r < b */
3492 mp_int_copy(&q, a); /* ok: q <= a */
3493
3494 mp_int_clear(&t);
3495 CLEANUP:
3496 mp_int_clear(&q);
3497 return res;
3498 }
3499
3500 /* }}} */
3501
3502 /* {{{ s_outlen(z, r) */
3503
3504 /* Precondition: 2 <= r < 64 */
3505 static int
s_outlen(mp_int z,mp_size r)3506 s_outlen(mp_int z, mp_size r)
3507 {
3508 mp_result bits;
3509 double raw;
3510
3511 bits = mp_int_count_bits(z);
3512 raw = (double) bits * s_log2[r];
3513
3514 return (int) (raw + 0.999999);
3515 }
3516
3517 /* }}} */
3518
3519 /* {{{ s_inlen(len, r) */
3520
3521 static mp_size
s_inlen(int len,mp_size r)3522 s_inlen(int len, mp_size r)
3523 {
3524 double raw = (double) len / s_log2[r];
3525 mp_size bits = (mp_size) (raw + 0.5);
3526
3527 return (mp_size) ((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT);
3528 }
3529
3530 /* }}} */
3531
3532 /* {{{ s_ch2val(c, r) */
3533
/*
 * Convert the character c to its digit value in radix r.
 *
 * Decimal digits map to 0..9.  Letters are accepted only when r > 10 and
 * map case-insensitively to 10 and up.  Returns -1 for any other character
 * and for values that are not below r.
 */
static int
s_ch2val(char c, int r)
{
	unsigned char uc = (unsigned char) c;
	int			value;

	if (isdigit(uc))
		value = c - '0';
	else if (r <= 10 || !isalpha(uc))
		return -1;
	else
		value = toupper(uc) - 'A' + 10;

	return (value < r) ? value : -1;
}
3548
3549 /* }}} */
3550
3551 /* {{{ s_val2ch(v, caps) */
3552
/*
 * Convert the digit value v (which must be non-negative) to its character.
 *
 * Values below 10 yield '0'..'9'.  Larger values yield letters starting at
 * 'a', or their upper-case form when caps is non-zero.
 */
static char
s_val2ch(int v, int caps)
{
	char		letter;

	assert(v >= 0);

	if (v < 10)
		return v + '0';

	letter = (v - 10) + 'a';

	return caps ? toupper((unsigned char) letter) : letter;
}
3570
3571 /* }}} */
3572
3573 /* {{{ s_2comp(buf, len) */
3574
/*
 * Replace the len-byte value in buf with its two's complement, in place.
 *
 * The last byte of the buffer is treated as the least significant one:
 * every byte is inverted and a one is added starting from the end.  A
 * carry out of the first byte is dropped.
 */
static void
s_2comp(unsigned char *buf, int len)
{
	unsigned short carry = 1;	/* the "+ 1" of the two's complement */
	int			pos;

	for (pos = len; pos > 0; --pos)
	{
		unsigned char flipped = ~buf[pos - 1];

		carry = flipped + carry;
		buf[pos - 1] = carry & UCHAR_MAX;
		carry >>= CHAR_BIT;
	}

	/* last carry out is ignored */
}
3594
3595 /* }}} */
3596
/* {{{ s_tobin(z, buf, *limpos, pad) */
3598
3599 static mp_result
s_tobin(mp_int z,unsigned char * buf,int * limpos,int pad)3600 s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3601 {
3602 mp_size uz;
3603 mp_digit *dz;
3604 int pos = 0,
3605 limit = *limpos;
3606
3607 uz = MP_USED(z);
3608 dz = MP_DIGITS(z);
3609 while (uz > 0 && pos < limit)
3610 {
3611 mp_digit d = *dz++;
3612 int i;
3613
3614 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i)
3615 {
3616 buf[pos++] = (unsigned char) d;
3617 d >>= CHAR_BIT;
3618
3619 /* Don't write leading zeroes */
3620 if (d == 0 && uz == 1)
3621 i = 0; /* exit loop without signaling truncation */
3622 }
3623
3624 /* Detect truncation (loop exited with pos >= limit) */
3625 if (i > 0)
3626 break;
3627
3628 --uz;
3629 }
3630
3631 if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1)))
3632 {
3633 if (pos < limit)
3634 buf[pos++] = 0;
3635 else
3636 uz = 1;
3637 }
3638
3639 /* Digits are in reverse order, fix that */
3640 REV(unsigned char, buf, pos);
3641
3642 /* Return the number of bytes actually written */
3643 *limpos = pos;
3644
3645 return (uz == 0) ? MP_OK : MP_TRUNC;
3646 }
3647
3648 /* }}} */
3649
3650 /* {{{ s_print(tag, z) */
3651
#if 0
/*
 * Debugging aid (compiled out): write tag, the sign of z and its digits in
 * hexadecimal, most significant digit first, to stderr.
 */
void
s_print(char *tag, mp_int z)
{
	int			idx = MP_USED(z);

	fprintf(stderr, "%s: %c ", tag,
			(MP_SIGN(z) == MP_NEG) ? '-' : '+');

	while (--idx >= 0)
		fprintf(stderr, "%0*X", (int) (MP_DIGIT_BIT / 4), z->digits[idx]);

	fputc('\n', stderr);
}

/*
 * Debugging aid (compiled out): write tag and the num digits at buf in
 * hexadecimal, highest index first, to stderr.
 */
void
s_print_buf(char *tag, mp_digit *buf, mp_size num)
{
	int			idx = num;

	fprintf(stderr, "%s: ", tag);

	while (--idx >= 0)
		fprintf(stderr, "%0*X", (int) (MP_DIGIT_BIT / 4), buf[idx]);

	fputc('\n', stderr);
}
#endif
3681
3682 /* }}} */
3683
3684 /* HERE THERE BE DRAGONS */
3685