/*-------------------------------------------------------------------------
 *
 * imath.c
 *
 * Last synchronized from https://github.com/creachadair/imath/tree/v1.29,
 * using the following procedure:
 *
 * 1. Download imath.c and imath.h of the last synchronized version. Remove
 *    "#ifdef __cplusplus" blocks, which upset pgindent. Run pgindent on the
 *    two files. Filter the two files through "unexpand -t4 --first-only".
 *    Diff the result against the PostgreSQL versions. As of the last
 *    synchronization, changes were as follows:
 *
 *    - replace malloc(), realloc() and free() with px_ versions
 *    - redirect assert() to Assert()
 *    - #undef MIN, #undef MAX before defining them
 *    - remove includes covered by c.h
 *    - rename DEBUG to IMATH_DEBUG
 *    - replace stdint.h usage with c.h equivalents
 *    - suppress MSVC warning 4146
 *    - add required PG_USED_FOR_ASSERTS_ONLY
 *
 * 2. Download a newer imath.c and imath.h. Transform them like in step 1.
 *    Apply to these files the diff you saved in step 1. Look for new lines
 *    requiring the same kind of change, such as new malloc() calls.
 *
 * 3. Configure PostgreSQL using --without-openssl. Run "make -C
 *    contrib/pgcrypto check".
 *
 * 4. Update this header comment.
 *
 * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group
 *
 * IDENTIFICATION
 *	  contrib/pgcrypto/imath.c
 *
 * Upstream copyright terms follow.
 *-------------------------------------------------------------------------
 */

/*
  Name: imath.c
  Purpose: Arbitrary precision integer arithmetic routines.
  Author: M. J. Fromberger

  Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
  in the Software without restriction, including without limitation the rights
  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  SOFTWARE.
65 */ 66 67 #include "postgres.h" 68 69 #include "imath.h" 70 #include "px.h" 71 72 #undef assert 73 #define assert(TEST) Assert(TEST) 74 75 const mp_result MP_OK = 0; /* no error, all is well */ 76 const mp_result MP_FALSE = 0; /* boolean false */ 77 const mp_result MP_TRUE = -1; /* boolean true */ 78 const mp_result MP_MEMORY = -2; /* out of memory */ 79 const mp_result MP_RANGE = -3; /* argument out of range */ 80 const mp_result MP_UNDEF = -4; /* result undefined */ 81 const mp_result MP_TRUNC = -5; /* output truncated */ 82 const mp_result MP_BADARG = -6; /* invalid null argument */ 83 const mp_result MP_MINERR = -6; 84 85 const mp_sign MP_NEG = 1; /* value is strictly negative */ 86 const mp_sign MP_ZPOS = 0; /* value is non-negative */ 87 88 static const char *s_unknown_err = "unknown result code"; 89 static const char *s_error_msg[] = {"error code 0", "boolean true", 90 "out of memory", "argument out of range", 91 "result undefined", "output truncated", 92 "invalid argument", NULL}; 93 94 /* The ith entry of this table gives the value of log_i(2). 95 96 An integer value n requires ceil(log_i(n)) digits to be represented 97 in base i. Since it is easy to compute lg(n), by counting bits, we 98 can compute log_i(n) = lg(n) * log_i(2). 99 100 The use of this table eliminates a dependency upon linkage against 101 the standard math libraries. 102 103 If MP_MAX_RADIX is increased, this table should be expanded too. 
104 */ 105 static const double s_log2[] = { 106 0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2 3 */ 107 0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4 5 6 7 */ 108 0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8 9 10 11 */ 109 0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */ 110 0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */ 111 0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */ 112 0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */ 113 0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */ 114 0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */ 115 0.193426404, /* 36 */ 116 }; 117 118 /* Return the number of digits needed to represent a static value */ 119 #define MP_VALUE_DIGITS(V) \ 120 ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit)) 121 122 /* Round precision P to nearest word boundary */ 123 static inline mp_size 124 s_round_prec(mp_size P) 125 { 126 return 2 * ((P + 1) / 2); 127 } 128 129 /* Set array P of S digits to zero */ 130 static inline void 131 ZERO(mp_digit *P, mp_size S) 132 { 133 mp_size i__ = S * sizeof(mp_digit); 134 mp_digit *p__ = P; 135 136 memset(p__, 0, i__); 137 } 138 139 /* Copy S digits from array P to array Q */ 140 static inline void 141 COPY(mp_digit *P, mp_digit *Q, mp_size S) 142 { 143 mp_size i__ = S * sizeof(mp_digit); 144 mp_digit *p__ = P; 145 mp_digit *q__ = Q; 146 147 memcpy(q__, p__, i__); 148 } 149 150 /* Reverse N elements of unsigned char in A. */ 151 static inline void 152 REV(unsigned char *A, int N) 153 { 154 unsigned char *u_ = A; 155 unsigned char *v_ = u_ + N - 1; 156 157 while (u_ < v_) 158 { 159 unsigned char xch = *u_; 160 161 *u_++ = *v_; 162 *v_-- = xch; 163 } 164 } 165 166 /* Strip leading zeroes from z_ in-place. 
*/ 167 static inline void 168 CLAMP(mp_int z_) 169 { 170 mp_size uz_ = MP_USED(z_); 171 mp_digit *dz_ = MP_DIGITS(z_) + uz_ - 1; 172 173 while (uz_ > 1 && (*dz_-- == 0)) 174 --uz_; 175 z_->used = uz_; 176 } 177 178 /* Select min/max. */ 179 #undef MIN 180 #undef MAX 181 static inline int 182 MIN(int A, int B) 183 { 184 return (B < A ? B : A); 185 } 186 static inline mp_size 187 MAX(mp_size A, mp_size B) 188 { 189 return (B > A ? B : A); 190 } 191 192 /* Exchange lvalues A and B of type T, e.g. 193 SWAP(int, x, y) where x and y are variables of type int. */ 194 #define SWAP(T, A, B) \ 195 do { \ 196 T t_ = (A); \ 197 A = (B); \ 198 B = t_; \ 199 } while (0) 200 201 /* Declare a block of N temporary mpz_t values. 202 These values are initialized to zero. 203 You must add CLEANUP_TEMP() at the end of the function. 204 Use TEMP(i) to access a pointer to the ith value. 205 */ 206 #define DECLARE_TEMP(N) \ 207 struct { \ 208 mpz_t value[(N)]; \ 209 int len; \ 210 mp_result err; \ 211 } temp_ = { \ 212 .len = (N), \ 213 .err = MP_OK, \ 214 }; \ 215 do { \ 216 for (int i = 0; i < temp_.len; i++) { \ 217 mp_int_init(TEMP(i)); \ 218 } \ 219 } while (0) 220 221 /* Clear all allocated temp values. */ 222 #define CLEANUP_TEMP() \ 223 CLEANUP: \ 224 do { \ 225 for (int i = 0; i < temp_.len; i++) { \ 226 mp_int_clear(TEMP(i)); \ 227 } \ 228 if (temp_.err != MP_OK) { \ 229 return temp_.err; \ 230 } \ 231 } while (0) 232 233 /* A pointer to the kth temp value. */ 234 #define TEMP(K) (temp_.value + (K)) 235 236 /* Evaluate E, an expression of type mp_result expected to return MP_OK. If 237 the value is not MP_OK, the error is cached and control resumes at the 238 cleanup handler, which returns it. 239 */ 240 #define REQUIRE(E) \ 241 do { \ 242 temp_.err = (E); \ 243 if (temp_.err != MP_OK) goto CLEANUP; \ 244 } while (0) 245 246 /* Compare value to zero. 
*/ 247 static inline int 248 CMPZ(mp_int Z) 249 { 250 if (Z->used == 1 && Z->digits[0] == 0) 251 return 0; 252 return (Z->sign == MP_NEG) ? -1 : 1; 253 } 254 255 static inline mp_word 256 UPPER_HALF(mp_word W) 257 { 258 return (W >> MP_DIGIT_BIT); 259 } 260 static inline mp_digit 261 LOWER_HALF(mp_word W) 262 { 263 return (mp_digit) (W); 264 } 265 266 /* Report whether the highest-order bit of W is 1. */ 267 static inline bool 268 HIGH_BIT_SET(mp_word W) 269 { 270 return (W >> (MP_WORD_BIT - 1)) != 0; 271 } 272 273 /* Report whether adding W + V will carry out. */ 274 static inline bool 275 ADD_WILL_OVERFLOW(mp_word W, mp_word V) 276 { 277 return ((MP_WORD_MAX - V) < W); 278 } 279 280 /* Default number of digits allocated to a new mp_int */ 281 static mp_size default_precision = 8; 282 283 void 284 mp_int_default_precision(mp_size size) 285 { 286 assert(size > 0); 287 default_precision = size; 288 } 289 290 /* Minimum number of digits to invoke recursive multiply */ 291 static mp_size multiply_threshold = 32; 292 293 void 294 mp_int_multiply_threshold(mp_size thresh) 295 { 296 assert(thresh >= sizeof(mp_word)); 297 multiply_threshold = thresh; 298 } 299 300 /* Allocate a buffer of (at least) num digits, or return 301 NULL if that couldn't be done. */ 302 static mp_digit *s_alloc(mp_size num); 303 304 /* Release a buffer of digits allocated by s_alloc(). */ 305 static void s_free(void *ptr); 306 307 /* Insure that z has at least min digits allocated, resizing if 308 necessary. Returns true if successful, false if out of memory. */ 309 static bool s_pad(mp_int z, mp_size min); 310 311 /* Ensure Z has at least N digits allocated. */ 312 static inline mp_result 313 GROW(mp_int Z, mp_size N) 314 { 315 return s_pad(Z, N) ? 
MP_OK : MP_MEMORY; 316 } 317 318 /* Fill in a "fake" mp_int on the stack with a given value */ 319 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]); 320 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]); 321 322 /* Compare two runs of digits of given length, returns <0, 0, >0 */ 323 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len); 324 325 /* Pack the unsigned digits of v into array t */ 326 static int s_uvpack(mp_usmall v, mp_digit t[]); 327 328 /* Compare magnitudes of a and b, returns <0, 0, >0 */ 329 static int s_ucmp(mp_int a, mp_int b); 330 331 /* Compare magnitudes of a and v, returns <0, 0, >0 */ 332 static int s_vcmp(mp_int a, mp_small v); 333 static int s_uvcmp(mp_int a, mp_usmall uv); 334 335 /* Unsigned magnitude addition; assumes dc is big enough. 336 Carry out is returned (no memory allocated). */ 337 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 338 mp_size size_b); 339 340 /* Unsigned magnitude subtraction. Assumes dc is big enough. */ 341 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 342 mp_size size_b); 343 344 /* Unsigned recursive multiplication. Assumes dc is big enough. */ 345 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 346 mp_size size_b); 347 348 /* Unsigned magnitude multiplication. Assumes dc is big enough. */ 349 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 350 mp_size size_b); 351 352 /* Unsigned recursive squaring. Assumes dc is big enough. */ 353 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a); 354 355 /* Unsigned magnitude squaring. Assumes dc is big enough. */ 356 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a); 357 358 /* Single digit addition. Assumes a is big enough. */ 359 static void s_dadd(mp_int a, mp_digit b); 360 361 /* Single digit multiplication. Assumes a is big enough. 
*/ 362 static void s_dmul(mp_int a, mp_digit b); 363 364 /* Single digit multiplication on buffers; assumes dc is big enough. */ 365 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a); 366 367 /* Single digit division. Replaces a with the quotient, 368 returns the remainder. */ 369 static mp_digit s_ddiv(mp_int a, mp_digit b); 370 371 /* Quick division by a power of 2, replaces z (no allocation) */ 372 static void s_qdiv(mp_int z, mp_size p2); 373 374 /* Quick remainder by a power of 2, replaces z (no allocation) */ 375 static void s_qmod(mp_int z, mp_size p2); 376 377 /* Quick multiplication by a power of 2, replaces z. 378 Allocates if necessary; returns false in case this fails. */ 379 static int s_qmul(mp_int z, mp_size p2); 380 381 /* Quick subtraction from a power of 2, replaces z. 382 Allocates if necessary; returns false in case this fails. */ 383 static int s_qsub(mp_int z, mp_size p2); 384 385 /* Return maximum k such that 2^k divides z. */ 386 static int s_dp2k(mp_int z); 387 388 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */ 389 static int s_isp2(mp_int z); 390 391 /* Set z to 2^k. May allocate; returns false in case this fails. */ 392 static int s_2expt(mp_int z, mp_small k); 393 394 /* Normalize a and b for division, returns normalization constant */ 395 static int s_norm(mp_int a, mp_int b); 396 397 /* Compute constant mu for Barrett reduction, given modulus m, result 398 replaces z, m is untouched. */ 399 static mp_result s_brmu(mp_int z, mp_int m); 400 401 /* Reduce a modulo m, using Barrett's algorithm. */ 402 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2); 403 404 /* Modular exponentiation, using Barrett reduction */ 405 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c); 406 407 /* Unsigned magnitude division. Assumes |a| > |b|. Allocates temporaries; 408 overwrites a with quotient, b with remainder. 
*/ 409 static mp_result s_udiv_knuth(mp_int a, mp_int b); 410 411 /* Compute the number of digits in radix r required to represent the given 412 value. Does not account for sign flags, terminators, etc. */ 413 static int s_outlen(mp_int z, mp_size r); 414 415 /* Guess how many digits of precision will be needed to represent a radix r 416 value of the specified number of digits. Returns a value guaranteed to be 417 no smaller than the actual number required. */ 418 static mp_size s_inlen(int len, mp_size r); 419 420 /* Convert a character to a digit value in radix r, or 421 -1 if out of range */ 422 static int s_ch2val(char c, int r); 423 424 /* Convert a digit value to a character */ 425 static char s_val2ch(int v, int caps); 426 427 /* Take 2's complement of a buffer in place */ 428 static void s_2comp(unsigned char *buf, int len); 429 430 /* Convert a value to binary, ignoring sign. On input, *limpos is the bound on 431 how many bytes should be written to buf; on output, *limpos is set to the 432 number of bytes actually written. */ 433 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad); 434 435 /* Multiply X by Y into Z, ignoring signs. Requires that Z have enough storage 436 preallocated to hold the result. */ 437 static inline void 438 UMUL(mp_int X, mp_int Y, mp_int Z) 439 { 440 mp_size ua_ = MP_USED(X); 441 mp_size ub_ = MP_USED(Y); 442 mp_size o_ = ua_ + ub_; 443 444 ZERO(MP_DIGITS(Z), o_); 445 (void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_); 446 Z->used = o_; 447 CLAMP(Z); 448 } 449 450 /* Square X into Z. Requires that Z have enough storage to hold the result. 
*/ 451 static inline void 452 USQR(mp_int X, mp_int Z) 453 { 454 mp_size ua_ = MP_USED(X); 455 mp_size o_ = ua_ + ua_; 456 457 ZERO(MP_DIGITS(Z), o_); 458 (void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_); 459 Z->used = o_; 460 CLAMP(Z); 461 } 462 463 mp_result 464 mp_int_init(mp_int z) 465 { 466 if (z == NULL) 467 return MP_BADARG; 468 469 z->single = 0; 470 z->digits = &(z->single); 471 z->alloc = 1; 472 z->used = 1; 473 z->sign = MP_ZPOS; 474 475 return MP_OK; 476 } 477 478 mp_int 479 mp_int_alloc(void) 480 { 481 mp_int out = px_alloc(sizeof(mpz_t)); 482 483 if (out != NULL) 484 mp_int_init(out); 485 486 return out; 487 } 488 489 mp_result 490 mp_int_init_size(mp_int z, mp_size prec) 491 { 492 assert(z != NULL); 493 494 if (prec == 0) 495 { 496 prec = default_precision; 497 } 498 else if (prec == 1) 499 { 500 return mp_int_init(z); 501 } 502 else 503 { 504 prec = s_round_prec(prec); 505 } 506 507 z->digits = s_alloc(prec); 508 if (MP_DIGITS(z) == NULL) 509 return MP_MEMORY; 510 511 z->digits[0] = 0; 512 z->used = 1; 513 z->alloc = prec; 514 z->sign = MP_ZPOS; 515 516 return MP_OK; 517 } 518 519 mp_result 520 mp_int_init_copy(mp_int z, mp_int old) 521 { 522 assert(z != NULL && old != NULL); 523 524 mp_size uold = MP_USED(old); 525 526 if (uold == 1) 527 { 528 mp_int_init(z); 529 } 530 else 531 { 532 mp_size target = MAX(uold, default_precision); 533 mp_result res = mp_int_init_size(z, target); 534 535 if (res != MP_OK) 536 return res; 537 } 538 539 z->used = uold; 540 z->sign = old->sign; 541 COPY(MP_DIGITS(old), MP_DIGITS(z), uold); 542 543 return MP_OK; 544 } 545 546 mp_result 547 mp_int_init_value(mp_int z, mp_small value) 548 { 549 mpz_t vtmp; 550 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 551 552 s_fake(&vtmp, value, vbuf); 553 return mp_int_init_copy(z, &vtmp); 554 } 555 556 mp_result 557 mp_int_init_uvalue(mp_int z, mp_usmall uvalue) 558 { 559 mpz_t vtmp; 560 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)]; 561 562 s_ufake(&vtmp, uvalue, vbuf); 563 return 
mp_int_init_copy(z, &vtmp); 564 } 565 566 mp_result 567 mp_int_set_value(mp_int z, mp_small value) 568 { 569 mpz_t vtmp; 570 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 571 572 s_fake(&vtmp, value, vbuf); 573 return mp_int_copy(&vtmp, z); 574 } 575 576 mp_result 577 mp_int_set_uvalue(mp_int z, mp_usmall uvalue) 578 { 579 mpz_t vtmp; 580 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)]; 581 582 s_ufake(&vtmp, uvalue, vbuf); 583 return mp_int_copy(&vtmp, z); 584 } 585 586 void 587 mp_int_clear(mp_int z) 588 { 589 if (z == NULL) 590 return; 591 592 if (MP_DIGITS(z) != NULL) 593 { 594 if (MP_DIGITS(z) != &(z->single)) 595 s_free(MP_DIGITS(z)); 596 597 z->digits = NULL; 598 } 599 } 600 601 void 602 mp_int_free(mp_int z) 603 { 604 assert(z != NULL); 605 606 mp_int_clear(z); 607 px_free(z); /* note: NOT s_free() */ 608 } 609 610 mp_result 611 mp_int_copy(mp_int a, mp_int c) 612 { 613 assert(a != NULL && c != NULL); 614 615 if (a != c) 616 { 617 mp_size ua = MP_USED(a); 618 mp_digit *da, 619 *dc; 620 621 if (!s_pad(c, ua)) 622 return MP_MEMORY; 623 624 da = MP_DIGITS(a); 625 dc = MP_DIGITS(c); 626 COPY(da, dc, ua); 627 628 c->used = ua; 629 c->sign = a->sign; 630 } 631 632 return MP_OK; 633 } 634 635 void 636 mp_int_swap(mp_int a, mp_int c) 637 { 638 if (a != c) 639 { 640 mpz_t tmp = *a; 641 642 *a = *c; 643 *c = tmp; 644 645 if (MP_DIGITS(a) == &(c->single)) 646 a->digits = &(a->single); 647 if (MP_DIGITS(c) == &(a->single)) 648 c->digits = &(c->single); 649 } 650 } 651 652 void 653 mp_int_zero(mp_int z) 654 { 655 assert(z != NULL); 656 657 z->digits[0] = 0; 658 z->used = 1; 659 z->sign = MP_ZPOS; 660 } 661 662 mp_result 663 mp_int_abs(mp_int a, mp_int c) 664 { 665 assert(a != NULL && c != NULL); 666 667 mp_result res; 668 669 if ((res = mp_int_copy(a, c)) != MP_OK) 670 return res; 671 672 c->sign = MP_ZPOS; 673 return MP_OK; 674 } 675 676 mp_result 677 mp_int_neg(mp_int a, mp_int c) 678 { 679 assert(a != NULL && c != NULL); 680 681 mp_result res; 682 683 if ((res = mp_int_copy(a, c)) 
!= MP_OK) 684 return res; 685 686 if (CMPZ(c) != 0) 687 c->sign = 1 - MP_SIGN(a); 688 689 return MP_OK; 690 } 691 692 mp_result 693 mp_int_add(mp_int a, mp_int b, mp_int c) 694 { 695 assert(a != NULL && b != NULL && c != NULL); 696 697 mp_size ua = MP_USED(a); 698 mp_size ub = MP_USED(b); 699 mp_size max = MAX(ua, ub); 700 701 if (MP_SIGN(a) == MP_SIGN(b)) 702 { 703 /* Same sign -- add magnitudes, preserve sign of addends */ 704 if (!s_pad(c, max)) 705 return MP_MEMORY; 706 707 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub); 708 mp_size uc = max; 709 710 if (carry) 711 { 712 if (!s_pad(c, max + 1)) 713 return MP_MEMORY; 714 715 c->digits[max] = carry; 716 ++uc; 717 } 718 719 c->used = uc; 720 c->sign = a->sign; 721 722 } 723 else 724 { 725 /* Different signs -- subtract magnitudes, preserve sign of greater */ 726 int cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */ 727 728 /* 729 * Set x to max(a, b), y to min(a, b) to simplify later code. A 730 * special case yields zero for equal magnitudes. 
731 */ 732 mp_int x, 733 y; 734 735 if (cmp == 0) 736 { 737 mp_int_zero(c); 738 return MP_OK; 739 } 740 else if (cmp < 0) 741 { 742 x = b; 743 y = a; 744 } 745 else 746 { 747 x = a; 748 y = b; 749 } 750 751 if (!s_pad(c, MP_USED(x))) 752 return MP_MEMORY; 753 754 /* Subtract smaller from larger */ 755 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y)); 756 c->used = x->used; 757 CLAMP(c); 758 759 /* Give result the sign of the larger */ 760 c->sign = x->sign; 761 } 762 763 return MP_OK; 764 } 765 766 mp_result 767 mp_int_add_value(mp_int a, mp_small value, mp_int c) 768 { 769 mpz_t vtmp; 770 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 771 772 s_fake(&vtmp, value, vbuf); 773 774 return mp_int_add(a, &vtmp, c); 775 } 776 777 mp_result 778 mp_int_sub(mp_int a, mp_int b, mp_int c) 779 { 780 assert(a != NULL && b != NULL && c != NULL); 781 782 mp_size ua = MP_USED(a); 783 mp_size ub = MP_USED(b); 784 mp_size max = MAX(ua, ub); 785 786 if (MP_SIGN(a) != MP_SIGN(b)) 787 { 788 /* Different signs -- add magnitudes and keep sign of a */ 789 if (!s_pad(c, max)) 790 return MP_MEMORY; 791 792 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub); 793 mp_size uc = max; 794 795 if (carry) 796 { 797 if (!s_pad(c, max + 1)) 798 return MP_MEMORY; 799 800 c->digits[max] = carry; 801 ++uc; 802 } 803 804 c->used = uc; 805 c->sign = a->sign; 806 807 } 808 else 809 { 810 /* Same signs -- subtract magnitudes */ 811 if (!s_pad(c, max)) 812 return MP_MEMORY; 813 mp_int x, 814 y; 815 mp_sign osign; 816 817 int cmp = s_ucmp(a, b); 818 819 if (cmp >= 0) 820 { 821 x = a; 822 y = b; 823 osign = MP_ZPOS; 824 } 825 else 826 { 827 x = b; 828 y = a; 829 osign = MP_NEG; 830 } 831 832 if (MP_SIGN(a) == MP_NEG && cmp != 0) 833 osign = 1 - osign; 834 835 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y)); 836 c->used = x->used; 837 CLAMP(c); 838 839 c->sign = osign; 840 } 841 842 return MP_OK; 843 } 844 845 mp_result 846 mp_int_sub_value(mp_int 
a, mp_small value, mp_int c) 847 { 848 mpz_t vtmp; 849 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 850 851 s_fake(&vtmp, value, vbuf); 852 853 return mp_int_sub(a, &vtmp, c); 854 } 855 856 mp_result 857 mp_int_mul(mp_int a, mp_int b, mp_int c) 858 { 859 assert(a != NULL && b != NULL && c != NULL); 860 861 /* If either input is zero, we can shortcut multiplication */ 862 if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) 863 { 864 mp_int_zero(c); 865 return MP_OK; 866 } 867 868 /* Output is positive if inputs have same sign, otherwise negative */ 869 mp_sign osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG; 870 871 /* 872 * If the output is not identical to any of the inputs, we'll write the 873 * results directly; otherwise, allocate a temporary space. 874 */ 875 mp_size ua = MP_USED(a); 876 mp_size ub = MP_USED(b); 877 mp_size osize = MAX(ua, ub); 878 879 osize = 4 * ((osize + 1) / 2); 880 881 mp_digit *out; 882 mp_size p = 0; 883 884 if (c == a || c == b) 885 { 886 p = MAX(s_round_prec(osize), default_precision); 887 888 if ((out = s_alloc(p)) == NULL) 889 return MP_MEMORY; 890 } 891 else 892 { 893 if (!s_pad(c, osize)) 894 return MP_MEMORY; 895 896 out = MP_DIGITS(c); 897 } 898 ZERO(out, osize); 899 900 if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub)) 901 return MP_MEMORY; 902 903 /* 904 * If we allocated a new buffer, get rid of whatever memory c was already 905 * using, and fix up its fields to reflect that. 906 */ 907 if (out != MP_DIGITS(c)) 908 { 909 if ((void *) MP_DIGITS(c) != (void *) c) 910 s_free(MP_DIGITS(c)); 911 c->digits = out; 912 c->alloc = p; 913 } 914 915 c->used = osize; /* might not be true, but we'll fix it ... */ 916 CLAMP(c); /* ... 
right here */ 917 c->sign = osign; 918 919 return MP_OK; 920 } 921 922 mp_result 923 mp_int_mul_value(mp_int a, mp_small value, mp_int c) 924 { 925 mpz_t vtmp; 926 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 927 928 s_fake(&vtmp, value, vbuf); 929 930 return mp_int_mul(a, &vtmp, c); 931 } 932 933 mp_result 934 mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c) 935 { 936 assert(a != NULL && c != NULL && p2 >= 0); 937 938 mp_result res = mp_int_copy(a, c); 939 940 if (res != MP_OK) 941 return res; 942 943 if (s_qmul(c, (mp_size) p2)) 944 { 945 return MP_OK; 946 } 947 else 948 { 949 return MP_MEMORY; 950 } 951 } 952 953 mp_result 954 mp_int_sqr(mp_int a, mp_int c) 955 { 956 assert(a != NULL && c != NULL); 957 958 /* Get a temporary buffer big enough to hold the result */ 959 mp_size osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2); 960 mp_size p = 0; 961 mp_digit *out; 962 963 if (a == c) 964 { 965 p = s_round_prec(osize); 966 p = MAX(p, default_precision); 967 968 if ((out = s_alloc(p)) == NULL) 969 return MP_MEMORY; 970 } 971 else 972 { 973 if (!s_pad(c, osize)) 974 return MP_MEMORY; 975 976 out = MP_DIGITS(c); 977 } 978 ZERO(out, osize); 979 980 s_ksqr(MP_DIGITS(a), out, MP_USED(a)); 981 982 /* 983 * Get rid of whatever memory c was already using, and fix up its fields 984 * to reflect the new digit array it's using 985 */ 986 if (out != MP_DIGITS(c)) 987 { 988 if ((void *) MP_DIGITS(c) != (void *) c) 989 s_free(MP_DIGITS(c)); 990 c->digits = out; 991 c->alloc = p; 992 } 993 994 c->used = osize; /* might not be true, but we'll fix it ... */ 995 CLAMP(c); /* ... 
right here */ 996 c->sign = MP_ZPOS; 997 998 return MP_OK; 999 } 1000 1001 mp_result 1002 mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r) 1003 { 1004 assert(a != NULL && b != NULL && q != r); 1005 1006 int cmp; 1007 mp_result res = MP_OK; 1008 mp_int qout, 1009 rout; 1010 mp_sign sa = MP_SIGN(a); 1011 mp_sign sb = MP_SIGN(b); 1012 1013 if (CMPZ(b) == 0) 1014 { 1015 return MP_UNDEF; 1016 } 1017 else if ((cmp = s_ucmp(a, b)) < 0) 1018 { 1019 /* 1020 * If |a| < |b|, no division is required: q = 0, r = a 1021 */ 1022 if (r && (res = mp_int_copy(a, r)) != MP_OK) 1023 return res; 1024 1025 if (q) 1026 mp_int_zero(q); 1027 1028 return MP_OK; 1029 } 1030 else if (cmp == 0) 1031 { 1032 /* 1033 * If |a| = |b|, no division is required: q = 1 or -1, r = 0 1034 */ 1035 if (r) 1036 mp_int_zero(r); 1037 1038 if (q) 1039 { 1040 mp_int_zero(q); 1041 q->digits[0] = 1; 1042 1043 if (sa != sb) 1044 q->sign = MP_NEG; 1045 } 1046 1047 return MP_OK; 1048 } 1049 1050 /* 1051 * When |a| > |b|, real division is required. We need someplace to store 1052 * quotient and remainder, but q and r are allowed to be NULL or to 1053 * overlap with the inputs. 
1054 */ 1055 DECLARE_TEMP(2); 1056 int lg; 1057 1058 if ((lg = s_isp2(b)) < 0) 1059 { 1060 if (q && b != q) 1061 { 1062 REQUIRE(mp_int_copy(a, q)); 1063 qout = q; 1064 } 1065 else 1066 { 1067 REQUIRE(mp_int_copy(a, TEMP(0))); 1068 qout = TEMP(0); 1069 } 1070 1071 if (r && a != r) 1072 { 1073 REQUIRE(mp_int_copy(b, r)); 1074 rout = r; 1075 } 1076 else 1077 { 1078 REQUIRE(mp_int_copy(b, TEMP(1))); 1079 rout = TEMP(1); 1080 } 1081 1082 REQUIRE(s_udiv_knuth(qout, rout)); 1083 } 1084 else 1085 { 1086 if (q) 1087 REQUIRE(mp_int_copy(a, q)); 1088 if (r) 1089 REQUIRE(mp_int_copy(a, r)); 1090 1091 if (q) 1092 s_qdiv(q, (mp_size) lg); 1093 qout = q; 1094 if (r) 1095 s_qmod(r, (mp_size) lg); 1096 rout = r; 1097 } 1098 1099 /* Recompute signs for output */ 1100 if (rout) 1101 { 1102 rout->sign = sa; 1103 if (CMPZ(rout) == 0) 1104 rout->sign = MP_ZPOS; 1105 } 1106 if (qout) 1107 { 1108 qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG; 1109 if (CMPZ(qout) == 0) 1110 qout->sign = MP_ZPOS; 1111 } 1112 1113 if (q) 1114 REQUIRE(mp_int_copy(qout, q)); 1115 if (r) 1116 REQUIRE(mp_int_copy(rout, r)); 1117 CLEANUP_TEMP(); 1118 return res; 1119 } 1120 1121 mp_result 1122 mp_int_mod(mp_int a, mp_int m, mp_int c) 1123 { 1124 DECLARE_TEMP(1); 1125 mp_int out = (m == c) ? 
TEMP(0) : c; 1126 1127 REQUIRE(mp_int_div(a, m, NULL, out)); 1128 if (CMPZ(out) < 0) 1129 { 1130 REQUIRE(mp_int_add(out, m, c)); 1131 } 1132 else 1133 { 1134 REQUIRE(mp_int_copy(out, c)); 1135 } 1136 CLEANUP_TEMP(); 1137 return MP_OK; 1138 } 1139 1140 mp_result 1141 mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r) 1142 { 1143 mpz_t vtmp; 1144 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 1145 1146 s_fake(&vtmp, value, vbuf); 1147 1148 DECLARE_TEMP(1); 1149 REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0))); 1150 1151 if (r) 1152 (void) mp_int_to_int(TEMP(0), r); /* can't fail */ 1153 1154 CLEANUP_TEMP(); 1155 return MP_OK; 1156 } 1157 1158 mp_result 1159 mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r) 1160 { 1161 assert(a != NULL && p2 >= 0 && q != r); 1162 1163 mp_result res = MP_OK; 1164 1165 if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK) 1166 { 1167 s_qdiv(q, (mp_size) p2); 1168 } 1169 1170 if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK) 1171 { 1172 s_qmod(r, (mp_size) p2); 1173 } 1174 1175 return res; 1176 } 1177 1178 mp_result 1179 mp_int_expt(mp_int a, mp_small b, mp_int c) 1180 { 1181 assert(c != NULL); 1182 if (b < 0) 1183 return MP_RANGE; 1184 1185 DECLARE_TEMP(1); 1186 REQUIRE(mp_int_copy(a, TEMP(0))); 1187 1188 (void) mp_int_set_value(c, 1); 1189 unsigned int v = labs(b); 1190 1191 while (v != 0) 1192 { 1193 if (v & 1) 1194 { 1195 REQUIRE(mp_int_mul(c, TEMP(0), c)); 1196 } 1197 1198 v >>= 1; 1199 if (v == 0) 1200 break; 1201 1202 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 1203 } 1204 1205 CLEANUP_TEMP(); 1206 return MP_OK; 1207 } 1208 1209 mp_result 1210 mp_int_expt_value(mp_small a, mp_small b, mp_int c) 1211 { 1212 assert(c != NULL); 1213 if (b < 0) 1214 return MP_RANGE; 1215 1216 DECLARE_TEMP(1); 1217 REQUIRE(mp_int_set_value(TEMP(0), a)); 1218 1219 (void) mp_int_set_value(c, 1); 1220 unsigned int v = labs(b); 1221 1222 while (v != 0) 1223 { 1224 if (v & 1) 1225 { 1226 REQUIRE(mp_int_mul(c, TEMP(0), c)); 1227 } 1228 
1229 v >>= 1; 1230 if (v == 0) 1231 break; 1232 1233 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 1234 } 1235 1236 CLEANUP_TEMP(); 1237 return MP_OK; 1238 } 1239 1240 mp_result 1241 mp_int_expt_full(mp_int a, mp_int b, mp_int c) 1242 { 1243 assert(a != NULL && b != NULL && c != NULL); 1244 if (MP_SIGN(b) == MP_NEG) 1245 return MP_RANGE; 1246 1247 DECLARE_TEMP(1); 1248 REQUIRE(mp_int_copy(a, TEMP(0))); 1249 1250 (void) mp_int_set_value(c, 1); 1251 for (unsigned ix = 0; ix < MP_USED(b); ++ix) 1252 { 1253 mp_digit d = b->digits[ix]; 1254 1255 for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx) 1256 { 1257 if (d & 1) 1258 { 1259 REQUIRE(mp_int_mul(c, TEMP(0), c)); 1260 } 1261 1262 d >>= 1; 1263 if (d == 0 && ix + 1 == MP_USED(b)) 1264 break; 1265 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0))); 1266 } 1267 } 1268 1269 CLEANUP_TEMP(); 1270 return MP_OK; 1271 } 1272 1273 int 1274 mp_int_compare(mp_int a, mp_int b) 1275 { 1276 assert(a != NULL && b != NULL); 1277 1278 mp_sign sa = MP_SIGN(a); 1279 1280 if (sa == MP_SIGN(b)) 1281 { 1282 int cmp = s_ucmp(a, b); 1283 1284 /* 1285 * If they're both zero or positive, the normal comparison applies; if 1286 * both negative, the sense is reversed. 1287 */ 1288 if (sa == MP_ZPOS) 1289 { 1290 return cmp; 1291 } 1292 else 1293 { 1294 return -cmp; 1295 } 1296 } 1297 else if (sa == MP_ZPOS) 1298 { 1299 return 1; 1300 } 1301 else 1302 { 1303 return -1; 1304 } 1305 } 1306 1307 int 1308 mp_int_compare_unsigned(mp_int a, mp_int b) 1309 { 1310 assert(a != NULL && b != NULL); 1311 1312 return s_ucmp(a, b); 1313 } 1314 1315 int 1316 mp_int_compare_zero(mp_int z) 1317 { 1318 assert(z != NULL); 1319 1320 if (MP_USED(z) == 1 && z->digits[0] == 0) 1321 { 1322 return 0; 1323 } 1324 else if (MP_SIGN(z) == MP_ZPOS) 1325 { 1326 return 1; 1327 } 1328 else 1329 { 1330 return -1; 1331 } 1332 } 1333 1334 int 1335 mp_int_compare_value(mp_int z, mp_small value) 1336 { 1337 assert(z != NULL); 1338 1339 mp_sign vsign = (value < 0) ? 
MP_NEG : MP_ZPOS; 1340 1341 if (vsign == MP_SIGN(z)) 1342 { 1343 int cmp = s_vcmp(z, value); 1344 1345 return (vsign == MP_ZPOS) ? cmp : -cmp; 1346 } 1347 else 1348 { 1349 return (value < 0) ? 1 : -1; 1350 } 1351 } 1352 1353 int 1354 mp_int_compare_uvalue(mp_int z, mp_usmall uv) 1355 { 1356 assert(z != NULL); 1357 1358 if (MP_SIGN(z) == MP_NEG) 1359 { 1360 return -1; 1361 } 1362 else 1363 { 1364 return s_uvcmp(z, uv); 1365 } 1366 } 1367 1368 mp_result 1369 mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c) 1370 { 1371 assert(a != NULL && b != NULL && c != NULL && m != NULL); 1372 1373 /* Zero moduli and negative exponents are not considered. */ 1374 if (CMPZ(m) == 0) 1375 return MP_UNDEF; 1376 if (CMPZ(b) < 0) 1377 return MP_RANGE; 1378 1379 mp_size um = MP_USED(m); 1380 1381 DECLARE_TEMP(3); 1382 REQUIRE(GROW(TEMP(0), 2 * um)); 1383 REQUIRE(GROW(TEMP(1), 2 * um)); 1384 1385 mp_int s; 1386 1387 if (c == b || c == m) 1388 { 1389 REQUIRE(GROW(TEMP(2), 2 * um)); 1390 s = TEMP(2); 1391 } 1392 else 1393 { 1394 s = c; 1395 } 1396 1397 REQUIRE(mp_int_mod(a, m, TEMP(0))); 1398 REQUIRE(s_brmu(TEMP(1), m)); 1399 REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s)); 1400 REQUIRE(mp_int_copy(s, c)); 1401 1402 CLEANUP_TEMP(); 1403 return MP_OK; 1404 } 1405 1406 mp_result 1407 mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c) 1408 { 1409 mpz_t vtmp; 1410 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 1411 1412 s_fake(&vtmp, value, vbuf); 1413 1414 return mp_int_exptmod(a, &vtmp, m, c); 1415 } 1416 1417 mp_result 1418 mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c) 1419 { 1420 mpz_t vtmp; 1421 mp_digit vbuf[MP_VALUE_DIGITS(value)]; 1422 1423 s_fake(&vtmp, value, vbuf); 1424 1425 return mp_int_exptmod(&vtmp, b, m, c); 1426 } 1427 1428 mp_result 1429 mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, 1430 mp_int c) 1431 { 1432 assert(a && b && m && c); 1433 1434 /* Zero moduli and negative exponents are not considered. 
*/ 1435 if (CMPZ(m) == 0) 1436 return MP_UNDEF; 1437 if (CMPZ(b) < 0) 1438 return MP_RANGE; 1439 1440 DECLARE_TEMP(2); 1441 mp_size um = MP_USED(m); 1442 1443 REQUIRE(GROW(TEMP(0), 2 * um)); 1444 1445 mp_int s; 1446 1447 if (c == b || c == m) 1448 { 1449 REQUIRE(GROW(TEMP(1), 2 * um)); 1450 s = TEMP(1); 1451 } 1452 else 1453 { 1454 s = c; 1455 } 1456 1457 REQUIRE(mp_int_mod(a, m, TEMP(0))); 1458 REQUIRE(s_embar(TEMP(0), b, m, mu, s)); 1459 REQUIRE(mp_int_copy(s, c)); 1460 1461 CLEANUP_TEMP(); 1462 return MP_OK; 1463 } 1464 1465 mp_result 1466 mp_int_redux_const(mp_int m, mp_int c) 1467 { 1468 assert(m != NULL && c != NULL && m != c); 1469 1470 return s_brmu(c, m); 1471 } 1472 1473 mp_result 1474 mp_int_invmod(mp_int a, mp_int m, mp_int c) 1475 { 1476 assert(a != NULL && m != NULL && c != NULL); 1477 1478 if (CMPZ(a) == 0 || CMPZ(m) <= 0) 1479 return MP_RANGE; 1480 1481 DECLARE_TEMP(2); 1482 1483 REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)); 1484 1485 if (mp_int_compare_value(TEMP(0), 1) != 0) 1486 { 1487 REQUIRE(MP_UNDEF); 1488 } 1489 1490 /* It is first necessary to constrain the value to the proper range */ 1491 REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1))); 1492 1493 /* 1494 * Now, if 'a' was originally negative, the value we have is actually the 1495 * magnitude of the negative representative; to get the positive value we 1496 * have to subtract from the modulus. Otherwise, the value is okay as it 1497 * stands. 
1498 */ 1499 if (MP_SIGN(a) == MP_NEG) 1500 { 1501 REQUIRE(mp_int_sub(m, TEMP(1), c)); 1502 } 1503 else 1504 { 1505 REQUIRE(mp_int_copy(TEMP(1), c)); 1506 } 1507 1508 CLEANUP_TEMP(); 1509 return MP_OK; 1510 } 1511 1512 /* Binary GCD algorithm due to Josef Stein, 1961 */ 1513 mp_result 1514 mp_int_gcd(mp_int a, mp_int b, mp_int c) 1515 { 1516 assert(a != NULL && b != NULL && c != NULL); 1517 1518 int ca = CMPZ(a); 1519 int cb = CMPZ(b); 1520 1521 if (ca == 0 && cb == 0) 1522 { 1523 return MP_UNDEF; 1524 } 1525 else if (ca == 0) 1526 { 1527 return mp_int_abs(b, c); 1528 } 1529 else if (cb == 0) 1530 { 1531 return mp_int_abs(a, c); 1532 } 1533 1534 DECLARE_TEMP(3); 1535 REQUIRE(mp_int_copy(a, TEMP(0))); 1536 REQUIRE(mp_int_copy(b, TEMP(1))); 1537 1538 TEMP(0)->sign = MP_ZPOS; 1539 TEMP(1)->sign = MP_ZPOS; 1540 1541 int k = 0; 1542 1543 { /* Divide out common factors of 2 from u and v */ 1544 int div2_u = s_dp2k(TEMP(0)); 1545 int div2_v = s_dp2k(TEMP(1)); 1546 1547 k = MIN(div2_u, div2_v); 1548 s_qdiv(TEMP(0), (mp_size) k); 1549 s_qdiv(TEMP(1), (mp_size) k); 1550 } 1551 1552 if (mp_int_is_odd(TEMP(0))) 1553 { 1554 REQUIRE(mp_int_neg(TEMP(1), TEMP(2))); 1555 } 1556 else 1557 { 1558 REQUIRE(mp_int_copy(TEMP(0), TEMP(2))); 1559 } 1560 1561 for (;;) 1562 { 1563 s_qdiv(TEMP(2), s_dp2k(TEMP(2))); 1564 1565 if (CMPZ(TEMP(2)) > 0) 1566 { 1567 REQUIRE(mp_int_copy(TEMP(2), TEMP(0))); 1568 } 1569 else 1570 { 1571 REQUIRE(mp_int_neg(TEMP(2), TEMP(1))); 1572 } 1573 1574 REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2))); 1575 1576 if (CMPZ(TEMP(2)) == 0) 1577 break; 1578 } 1579 1580 REQUIRE(mp_int_abs(TEMP(0), c)); 1581 if (!s_qmul(c, (mp_size) k)) 1582 REQUIRE(MP_MEMORY); 1583 1584 CLEANUP_TEMP(); 1585 return MP_OK; 1586 } 1587 1588 /* This is the binary GCD algorithm again, but this time we keep track of the 1589 elementary matrix operations as we go, so we can get values x and y 1590 satisfying c = ax + by. 
1591 */ 1592 mp_result 1593 mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y) 1594 { 1595 assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL)); 1596 1597 mp_result res = MP_OK; 1598 int ca = CMPZ(a); 1599 int cb = CMPZ(b); 1600 1601 if (ca == 0 && cb == 0) 1602 { 1603 return MP_UNDEF; 1604 } 1605 else if (ca == 0) 1606 { 1607 if ((res = mp_int_abs(b, c)) != MP_OK) 1608 return res; 1609 mp_int_zero(x); 1610 (void) mp_int_set_value(y, 1); 1611 return MP_OK; 1612 } 1613 else if (cb == 0) 1614 { 1615 if ((res = mp_int_abs(a, c)) != MP_OK) 1616 return res; 1617 (void) mp_int_set_value(x, 1); 1618 mp_int_zero(y); 1619 return MP_OK; 1620 } 1621 1622 /* 1623 * Initialize temporaries: A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 1624 */ 1625 DECLARE_TEMP(8); 1626 REQUIRE(mp_int_set_value(TEMP(0), 1)); 1627 REQUIRE(mp_int_set_value(TEMP(3), 1)); 1628 REQUIRE(mp_int_copy(a, TEMP(4))); 1629 REQUIRE(mp_int_copy(b, TEMP(5))); 1630 1631 /* We will work with absolute values here */ 1632 TEMP(4)->sign = MP_ZPOS; 1633 TEMP(5)->sign = MP_ZPOS; 1634 1635 int k = 0; 1636 1637 { /* Divide out common factors of 2 from u and v */ 1638 int div2_u = s_dp2k(TEMP(4)), 1639 div2_v = s_dp2k(TEMP(5)); 1640 1641 k = MIN(div2_u, div2_v); 1642 s_qdiv(TEMP(4), k); 1643 s_qdiv(TEMP(5), k); 1644 } 1645 1646 REQUIRE(mp_int_copy(TEMP(4), TEMP(6))); 1647 REQUIRE(mp_int_copy(TEMP(5), TEMP(7))); 1648 1649 for (;;) 1650 { 1651 while (mp_int_is_even(TEMP(4))) 1652 { 1653 s_qdiv(TEMP(4), 1); 1654 1655 if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) 1656 { 1657 REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0))); 1658 REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1))); 1659 } 1660 1661 s_qdiv(TEMP(0), 1); 1662 s_qdiv(TEMP(1), 1); 1663 } 1664 1665 while (mp_int_is_even(TEMP(5))) 1666 { 1667 s_qdiv(TEMP(5), 1); 1668 1669 if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) 1670 { 1671 REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2))); 1672 REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3))); 
1673 } 1674 1675 s_qdiv(TEMP(2), 1); 1676 s_qdiv(TEMP(3), 1); 1677 } 1678 1679 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0) 1680 { 1681 REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4))); 1682 REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0))); 1683 REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1))); 1684 } 1685 else 1686 { 1687 REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5))); 1688 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2))); 1689 REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3))); 1690 } 1691 1692 if (CMPZ(TEMP(4)) == 0) 1693 { 1694 if (x) 1695 REQUIRE(mp_int_copy(TEMP(2), x)); 1696 if (y) 1697 REQUIRE(mp_int_copy(TEMP(3), y)); 1698 if (c) 1699 { 1700 if (!s_qmul(TEMP(5), k)) 1701 { 1702 REQUIRE(MP_MEMORY); 1703 } 1704 REQUIRE(mp_int_copy(TEMP(5), c)); 1705 } 1706 1707 break; 1708 } 1709 } 1710 1711 CLEANUP_TEMP(); 1712 return MP_OK; 1713 } 1714 1715 mp_result 1716 mp_int_lcm(mp_int a, mp_int b, mp_int c) 1717 { 1718 assert(a != NULL && b != NULL && c != NULL); 1719 1720 /* 1721 * Since a * b = gcd(a, b) * lcm(a, b), we can compute lcm(a, b) = (a / 1722 * gcd(a, b)) * b. 1723 * 1724 * This formulation insures everything works even if the input variables 1725 * share space. 1726 */ 1727 DECLARE_TEMP(1); 1728 REQUIRE(mp_int_gcd(a, b, TEMP(0))); 1729 REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL)); 1730 REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0))); 1731 REQUIRE(mp_int_copy(TEMP(0), c)); 1732 1733 CLEANUP_TEMP(); 1734 return MP_OK; 1735 } 1736 1737 bool 1738 mp_int_divisible_value(mp_int a, mp_small v) 1739 { 1740 mp_small rem = 0; 1741 1742 if (mp_int_div_value(a, v, NULL, &rem) != MP_OK) 1743 { 1744 return false; 1745 } 1746 return rem == 0; 1747 } 1748 1749 int 1750 mp_int_is_pow2(mp_int z) 1751 { 1752 assert(z != NULL); 1753 1754 return s_isp2(z); 1755 } 1756 1757 /* Implementation of Newton's root finding method, based loosely on a patch 1758 contributed by Hal Finkel <half@halssoftware.com> 1759 modified by M. J. Fromberger. 
1760 */ 1761 mp_result 1762 mp_int_root(mp_int a, mp_small b, mp_int c) 1763 { 1764 assert(a != NULL && c != NULL && b > 0); 1765 1766 if (b == 1) 1767 { 1768 return mp_int_copy(a, c); 1769 } 1770 bool flips = false; 1771 1772 if (MP_SIGN(a) == MP_NEG) 1773 { 1774 if (b % 2 == 0) 1775 { 1776 return MP_UNDEF; /* root does not exist for negative a with 1777 * even b */ 1778 } 1779 else 1780 { 1781 flips = true; 1782 } 1783 } 1784 1785 DECLARE_TEMP(5); 1786 REQUIRE(mp_int_copy(a, TEMP(0))); 1787 REQUIRE(mp_int_copy(a, TEMP(1))); 1788 TEMP(0)->sign = MP_ZPOS; 1789 TEMP(1)->sign = MP_ZPOS; 1790 1791 for (;;) 1792 { 1793 REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2))); 1794 1795 if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0) 1796 break; 1797 1798 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2))); 1799 REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3))); 1800 REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3))); 1801 REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL)); 1802 REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4))); 1803 1804 if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) 1805 { 1806 REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4))); 1807 } 1808 REQUIRE(mp_int_copy(TEMP(4), TEMP(1))); 1809 } 1810 1811 REQUIRE(mp_int_copy(TEMP(1), c)); 1812 1813 /* If the original value of a was negative, flip the output sign. 
*/ 1814 if (flips) 1815 (void) mp_int_neg(c, c); /* cannot fail */ 1816 1817 CLEANUP_TEMP(); 1818 return MP_OK; 1819 } 1820 1821 mp_result 1822 mp_int_to_int(mp_int z, mp_small *out) 1823 { 1824 assert(z != NULL); 1825 1826 /* Make sure the value is representable as a small integer */ 1827 mp_sign sz = MP_SIGN(z); 1828 1829 if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) || 1830 mp_int_compare_value(z, MP_SMALL_MIN) < 0) 1831 { 1832 return MP_RANGE; 1833 } 1834 1835 mp_usmall uz = MP_USED(z); 1836 mp_digit *dz = MP_DIGITS(z) + uz - 1; 1837 mp_small uv = 0; 1838 1839 while (uz > 0) 1840 { 1841 uv <<= MP_DIGIT_BIT / 2; 1842 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--; 1843 --uz; 1844 } 1845 1846 if (out) 1847 *out = (mp_small) ((sz == MP_NEG) ? -uv : uv); 1848 1849 return MP_OK; 1850 } 1851 1852 mp_result 1853 mp_int_to_uint(mp_int z, mp_usmall *out) 1854 { 1855 assert(z != NULL); 1856 1857 /* Make sure the value is representable as an unsigned small integer */ 1858 mp_size sz = MP_SIGN(z); 1859 1860 if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0) 1861 { 1862 return MP_RANGE; 1863 } 1864 1865 mp_size uz = MP_USED(z); 1866 mp_digit *dz = MP_DIGITS(z) + uz - 1; 1867 mp_usmall uv = 0; 1868 1869 while (uz > 0) 1870 { 1871 uv <<= MP_DIGIT_BIT / 2; 1872 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--; 1873 --uz; 1874 } 1875 1876 if (out) 1877 *out = uv; 1878 1879 return MP_OK; 1880 } 1881 1882 mp_result 1883 mp_int_to_string(mp_int z, mp_size radix, char *str, int limit) 1884 { 1885 assert(z != NULL && str != NULL && limit >= 2); 1886 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1887 1888 int cmp = 0; 1889 1890 if (CMPZ(z) == 0) 1891 { 1892 *str++ = s_val2ch(0, 1); 1893 } 1894 else 1895 { 1896 mp_result res; 1897 mpz_t tmp; 1898 char *h, 1899 *t; 1900 1901 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK) 1902 return res; 1903 1904 if (MP_SIGN(z) == MP_NEG) 1905 { 1906 *str++ = '-'; 1907 --limit; 1908 } 1909 h = str; 1910 1911 /* Generate 
digits in reverse order until finished or limit reached */ 1912 for ( /* */ ; limit > 0; --limit) 1913 { 1914 mp_digit d; 1915 1916 if ((cmp = CMPZ(&tmp)) == 0) 1917 break; 1918 1919 d = s_ddiv(&tmp, (mp_digit) radix); 1920 *str++ = s_val2ch(d, 1); 1921 } 1922 t = str - 1; 1923 1924 /* Put digits back in correct output order */ 1925 while (h < t) 1926 { 1927 char tc = *h; 1928 1929 *h++ = *t; 1930 *t-- = tc; 1931 } 1932 1933 mp_int_clear(&tmp); 1934 } 1935 1936 *str = '\0'; 1937 if (cmp == 0) 1938 { 1939 return MP_OK; 1940 } 1941 else 1942 { 1943 return MP_TRUNC; 1944 } 1945 } 1946 1947 mp_result 1948 mp_int_string_len(mp_int z, mp_size radix) 1949 { 1950 assert(z != NULL); 1951 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1952 1953 int len = s_outlen(z, radix) + 1; /* for terminator */ 1954 1955 /* Allow for sign marker on negatives */ 1956 if (MP_SIGN(z) == MP_NEG) 1957 len += 1; 1958 1959 return len; 1960 } 1961 1962 /* Read zero-terminated string into z */ 1963 mp_result 1964 mp_int_read_string(mp_int z, mp_size radix, const char *str) 1965 { 1966 return mp_int_read_cstring(z, radix, str, NULL); 1967 } 1968 1969 mp_result 1970 mp_int_read_cstring(mp_int z, mp_size radix, const char *str, 1971 char **end) 1972 { 1973 assert(z != NULL && str != NULL); 1974 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX); 1975 1976 /* Skip leading whitespace */ 1977 while (isspace((unsigned char) *str)) 1978 ++str; 1979 1980 /* Handle leading sign tag (+/-, positive default) */ 1981 switch (*str) 1982 { 1983 case '-': 1984 z->sign = MP_NEG; 1985 ++str; 1986 break; 1987 case '+': 1988 ++str; /* fallthrough */ 1989 default: 1990 z->sign = MP_ZPOS; 1991 break; 1992 } 1993 1994 /* Skip leading zeroes */ 1995 int ch; 1996 1997 while ((ch = s_ch2val(*str, radix)) == 0) 1998 ++str; 1999 2000 /* Make sure there is enough space for the value */ 2001 if (!s_pad(z, s_inlen(strlen(str), radix))) 2002 return MP_MEMORY; 2003 2004 z->used = 1; 2005 z->digits[0] = 0; 2006 2007 
while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) 2008 { 2009 s_dmul(z, (mp_digit) radix); 2010 s_dadd(z, (mp_digit) ch); 2011 ++str; 2012 } 2013 2014 CLAMP(z); 2015 2016 /* Override sign for zero, even if negative specified. */ 2017 if (CMPZ(z) == 0) 2018 z->sign = MP_ZPOS; 2019 2020 if (end != NULL) 2021 *end = unconstify(char *, str); 2022 2023 /* 2024 * Return a truncation error if the string has unprocessed characters 2025 * remaining, so the caller can tell if the whole string was done 2026 */ 2027 if (*str != '\0') 2028 { 2029 return MP_TRUNC; 2030 } 2031 else 2032 { 2033 return MP_OK; 2034 } 2035 } 2036 2037 mp_result 2038 mp_int_count_bits(mp_int z) 2039 { 2040 assert(z != NULL); 2041 2042 mp_size uz = MP_USED(z); 2043 2044 if (uz == 1 && z->digits[0] == 0) 2045 return 1; 2046 2047 --uz; 2048 mp_size nbits = uz * MP_DIGIT_BIT; 2049 mp_digit d = z->digits[uz]; 2050 2051 while (d != 0) 2052 { 2053 d >>= 1; 2054 ++nbits; 2055 } 2056 2057 return nbits; 2058 } 2059 2060 mp_result 2061 mp_int_to_binary(mp_int z, unsigned char *buf, int limit) 2062 { 2063 static const int PAD_FOR_2C = 1; 2064 2065 assert(z != NULL && buf != NULL); 2066 2067 int limpos = limit; 2068 mp_result res = s_tobin(z, buf, &limpos, PAD_FOR_2C); 2069 2070 if (MP_SIGN(z) == MP_NEG) 2071 s_2comp(buf, limpos); 2072 2073 return res; 2074 } 2075 2076 mp_result 2077 mp_int_read_binary(mp_int z, unsigned char *buf, int len) 2078 { 2079 assert(z != NULL && buf != NULL && len > 0); 2080 2081 /* Figure out how many digits are needed to represent this value */ 2082 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT; 2083 2084 if (!s_pad(z, need)) 2085 return MP_MEMORY; 2086 2087 mp_int_zero(z); 2088 2089 /* 2090 * If the high-order bit is set, take the 2's complement before reading 2091 * the value (it will be restored afterward) 2092 */ 2093 if (buf[0] >> (CHAR_BIT - 1)) 2094 { 2095 z->sign = MP_NEG; 2096 s_2comp(buf, len); 2097 } 2098 2099 mp_digit *dz = MP_DIGITS(z); 
2100 unsigned char *tmp = buf; 2101 2102 for (int i = len; i > 0; --i, ++tmp) 2103 { 2104 s_qmul(z, (mp_size) CHAR_BIT); 2105 *dz |= *tmp; 2106 } 2107 2108 /* Restore 2's complement if we took it before */ 2109 if (MP_SIGN(z) == MP_NEG) 2110 s_2comp(buf, len); 2111 2112 return MP_OK; 2113 } 2114 2115 mp_result 2116 mp_int_binary_len(mp_int z) 2117 { 2118 mp_result res = mp_int_count_bits(z); 2119 2120 if (res <= 0) 2121 return res; 2122 2123 int bytes = mp_int_unsigned_len(z); 2124 2125 /* 2126 * If the highest-order bit falls exactly on a byte boundary, we need to 2127 * pad with an extra byte so that the sign will be read correctly when 2128 * reading it back in. 2129 */ 2130 if (bytes * CHAR_BIT == res) 2131 ++bytes; 2132 2133 return bytes; 2134 } 2135 2136 mp_result 2137 mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit) 2138 { 2139 static const int NO_PADDING = 0; 2140 2141 assert(z != NULL && buf != NULL); 2142 2143 return s_tobin(z, buf, &limit, NO_PADDING); 2144 } 2145 2146 mp_result 2147 mp_int_read_unsigned(mp_int z, unsigned char *buf, int len) 2148 { 2149 assert(z != NULL && buf != NULL && len > 0); 2150 2151 /* Figure out how many digits are needed to represent this value */ 2152 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT; 2153 2154 if (!s_pad(z, need)) 2155 return MP_MEMORY; 2156 2157 mp_int_zero(z); 2158 2159 unsigned char *tmp = buf; 2160 2161 for (int i = len; i > 0; --i, ++tmp) 2162 { 2163 (void) s_qmul(z, CHAR_BIT); 2164 *MP_DIGITS(z) |= *tmp; 2165 } 2166 2167 return MP_OK; 2168 } 2169 2170 mp_result 2171 mp_int_unsigned_len(mp_int z) 2172 { 2173 mp_result res = mp_int_count_bits(z); 2174 2175 if (res <= 0) 2176 return res; 2177 2178 int bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT; 2179 2180 return bytes; 2181 } 2182 2183 const char * 2184 mp_error_string(mp_result res) 2185 { 2186 if (res > 0) 2187 return s_unknown_err; 2188 2189 res = -res; 2190 int ix; 2191 2192 for (ix = 0; ix < res && s_error_msg[ix] != 
NULL; ++ix) 2193 ; 2194 2195 if (s_error_msg[ix] != NULL) 2196 { 2197 return s_error_msg[ix]; 2198 } 2199 else 2200 { 2201 return s_unknown_err; 2202 } 2203 } 2204 2205 /*------------------------------------------------------------------------*/ 2206 /* Private functions for internal use. These make assumptions. */ 2207 2208 #if IMATH_DEBUG 2209 static const mp_digit fill = (mp_digit) 0xdeadbeefabad1dea; 2210 #endif 2211 2212 static mp_digit * 2213 s_alloc(mp_size num) 2214 { 2215 mp_digit *out = px_alloc(num * sizeof(mp_digit)); 2216 2217 assert(out != NULL); 2218 2219 #if IMATH_DEBUG 2220 for (mp_size ix = 0; ix < num; ++ix) 2221 out[ix] = fill; 2222 #endif 2223 return out; 2224 } 2225 2226 static mp_digit * 2227 s_realloc(mp_digit *old, mp_size osize, mp_size nsize) 2228 { 2229 #if IMATH_DEBUG 2230 mp_digit *new = s_alloc(nsize); 2231 2232 assert(new != NULL); 2233 2234 for (mp_size ix = 0; ix < nsize; ++ix) 2235 new[ix] = fill; 2236 memcpy(new, old, osize * sizeof(mp_digit)); 2237 #else 2238 mp_digit *new = px_realloc(old, nsize * sizeof(mp_digit)); 2239 2240 assert(new != NULL); 2241 #endif 2242 2243 return new; 2244 } 2245 2246 static void 2247 s_free(void *ptr) 2248 { 2249 px_free(ptr); 2250 } 2251 2252 static bool 2253 s_pad(mp_int z, mp_size min) 2254 { 2255 if (MP_ALLOC(z) < min) 2256 { 2257 mp_size nsize = s_round_prec(min); 2258 mp_digit *tmp; 2259 2260 if (z->digits == &(z->single)) 2261 { 2262 if ((tmp = s_alloc(nsize)) == NULL) 2263 return false; 2264 tmp[0] = z->single; 2265 } 2266 else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL) 2267 { 2268 return false; 2269 } 2270 2271 z->digits = tmp; 2272 z->alloc = nsize; 2273 } 2274 2275 return true; 2276 } 2277 2278 /* Note: This will not work correctly when value == MP_SMALL_MIN */ 2279 static void 2280 s_fake(mp_int z, mp_small value, mp_digit vbuf[]) 2281 { 2282 mp_usmall uv = (mp_usmall) (value < 0) ? 
-value : value; 2283 2284 s_ufake(z, uv, vbuf); 2285 if (value < 0) 2286 z->sign = MP_NEG; 2287 } 2288 2289 static void 2290 s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]) 2291 { 2292 mp_size ndig = (mp_size) s_uvpack(value, vbuf); 2293 2294 z->used = ndig; 2295 z->alloc = MP_VALUE_DIGITS(value); 2296 z->sign = MP_ZPOS; 2297 z->digits = vbuf; 2298 } 2299 2300 static int 2301 s_cdig(mp_digit *da, mp_digit *db, mp_size len) 2302 { 2303 mp_digit *dat = da + len - 1, 2304 *dbt = db + len - 1; 2305 2306 for ( /* */ ; len != 0; --len, --dat, --dbt) 2307 { 2308 if (*dat > *dbt) 2309 { 2310 return 1; 2311 } 2312 else if (*dat < *dbt) 2313 { 2314 return -1; 2315 } 2316 } 2317 2318 return 0; 2319 } 2320 2321 static int 2322 s_uvpack(mp_usmall uv, mp_digit t[]) 2323 { 2324 int ndig = 0; 2325 2326 if (uv == 0) 2327 t[ndig++] = 0; 2328 else 2329 { 2330 while (uv != 0) 2331 { 2332 t[ndig++] = (mp_digit) uv; 2333 uv >>= MP_DIGIT_BIT / 2; 2334 uv >>= MP_DIGIT_BIT / 2; 2335 } 2336 } 2337 2338 return ndig; 2339 } 2340 2341 static int 2342 s_ucmp(mp_int a, mp_int b) 2343 { 2344 mp_size ua = MP_USED(a), 2345 ub = MP_USED(b); 2346 2347 if (ua > ub) 2348 { 2349 return 1; 2350 } 2351 else if (ub > ua) 2352 { 2353 return -1; 2354 } 2355 else 2356 { 2357 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua); 2358 } 2359 } 2360 2361 static int 2362 s_vcmp(mp_int a, mp_small v) 2363 { 2364 #ifdef _MSC_VER 2365 #pragma warning(push) 2366 #pragma warning(disable: 4146) 2367 #endif 2368 mp_usmall uv = (v < 0) ? 
-(mp_usmall) v : (mp_usmall) v; 2369 #ifdef _MSC_VER 2370 #pragma warning(pop) 2371 #endif 2372 2373 return s_uvcmp(a, uv); 2374 } 2375 2376 static int 2377 s_uvcmp(mp_int a, mp_usmall uv) 2378 { 2379 mpz_t vtmp; 2380 mp_digit vdig[MP_VALUE_DIGITS(uv)]; 2381 2382 s_ufake(&vtmp, uv, vdig); 2383 return s_ucmp(a, &vtmp); 2384 } 2385 2386 static mp_digit 2387 s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 2388 mp_size size_b) 2389 { 2390 mp_size pos; 2391 mp_word w = 0; 2392 2393 /* Insure that da is the longer of the two to simplify later code */ 2394 if (size_b > size_a) 2395 { 2396 SWAP(mp_digit *, da, db); 2397 SWAP(mp_size, size_a, size_b); 2398 } 2399 2400 /* Add corresponding digits until the shorter number runs out */ 2401 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) 2402 { 2403 w = w + (mp_word) *da + (mp_word) *db; 2404 *dc = LOWER_HALF(w); 2405 w = UPPER_HALF(w); 2406 } 2407 2408 /* Propagate carries as far as necessary */ 2409 for ( /* */ ; pos < size_a; ++pos, ++da, ++dc) 2410 { 2411 w = w + *da; 2412 2413 *dc = LOWER_HALF(w); 2414 w = UPPER_HALF(w); 2415 } 2416 2417 /* Return carry out */ 2418 return (mp_digit) w; 2419 } 2420 2421 static void 2422 s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 2423 mp_size size_b) 2424 { 2425 mp_size pos; 2426 mp_word w = 0; 2427 2428 /* We assume that |a| >= |b| so this should definitely hold */ 2429 assert(size_a >= size_b); 2430 2431 /* Subtract corresponding digits and propagate borrow */ 2432 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) 2433 { 2434 w = ((mp_word) MP_DIGIT_MAX + 1 + /* MP_RADIX */ 2435 (mp_word) *da) - 2436 w - (mp_word) *db; 2437 2438 *dc = LOWER_HALF(w); 2439 w = (UPPER_HALF(w) == 0); 2440 } 2441 2442 /* Finish the subtraction for remaining upper digits of da */ 2443 for ( /* */ ; pos < size_a; ++pos, ++da, ++dc) 2444 { 2445 w = ((mp_word) MP_DIGIT_MAX + 1 + /* MP_RADIX */ 2446 (mp_word) *da) - 2447 w; 2448 2449 *dc = LOWER_HALF(w); 2450 w = 
(UPPER_HALF(w) == 0); 2451 } 2452 2453 /* If there is a borrow out at the end, it violates the precondition */ 2454 assert(w == 0); 2455 } 2456 2457 static int 2458 s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 2459 mp_size size_b) 2460 { 2461 mp_size bot_size; 2462 2463 /* Make sure b is the smaller of the two input values */ 2464 if (size_b > size_a) 2465 { 2466 SWAP(mp_digit *, da, db); 2467 SWAP(mp_size, size_a, size_b); 2468 } 2469 2470 /* 2471 * Insure that the bottom is the larger half in an odd-length split; the 2472 * code below relies on this being true. 2473 */ 2474 bot_size = (size_a + 1) / 2; 2475 2476 /* 2477 * If the values are big enough to bother with recursion, use the 2478 * Karatsuba algorithm to compute the product; otherwise use the normal 2479 * multiplication algorithm 2480 */ 2481 if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size) 2482 { 2483 mp_digit *t1, 2484 *t2, 2485 *t3, 2486 carry; 2487 2488 mp_digit *a_top = da + bot_size; 2489 mp_digit *b_top = db + bot_size; 2490 2491 mp_size at_size = size_a - bot_size; 2492 mp_size bt_size = size_b - bot_size; 2493 mp_size buf_size = 2 * bot_size; 2494 2495 /* 2496 * Do a single allocation for all three temporary buffers needed; each 2497 * buffer must be big enough to hold the product of two bottom halves, 2498 * and one buffer needs space for the completed product; twice the 2499 * space is plenty. 
2500 */ 2501 if ((t1 = s_alloc(4 * buf_size)) == NULL) 2502 return 0; 2503 t2 = t1 + buf_size; 2504 t3 = t2 + buf_size; 2505 ZERO(t1, 4 * buf_size); 2506 2507 /* 2508 * t1 and t2 are initially used as temporaries to compute the inner 2509 * product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0 2510 */ 2511 carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */ 2512 t1[bot_size] = carry; 2513 2514 carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */ 2515 t2[bot_size] = carry; 2516 2517 (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */ 2518 2519 /* 2520 * Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so 2521 * that we're left with only the pieces we want: t3 = a1b0 + a0b1 2522 */ 2523 ZERO(t1, buf_size); 2524 ZERO(t2, buf_size); 2525 (void) s_kmul(da, db, t1, bot_size, bot_size); /* t1 = a0 * b0 */ 2526 (void) s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */ 2527 2528 /* Subtract out t1 and t2 to get the inner product */ 2529 s_usub(t3, t1, t3, buf_size + 2, buf_size); 2530 s_usub(t3, t2, t3, buf_size + 2, buf_size); 2531 2532 /* Assemble the output value */ 2533 COPY(t1, dc, buf_size); 2534 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size); 2535 assert(carry == 0); 2536 2537 carry = 2538 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size); 2539 assert(carry == 0); 2540 2541 s_free(t1); /* note t2 and t3 are just internal pointers 2542 * to t1 */ 2543 } 2544 else 2545 { 2546 s_umul(da, db, dc, size_a, size_b); 2547 } 2548 2549 return 1; 2550 } 2551 2552 static void 2553 s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a, 2554 mp_size size_b) 2555 { 2556 mp_size a, 2557 b; 2558 mp_word w; 2559 2560 for (a = 0; a < size_a; ++a, ++dc, ++da) 2561 { 2562 mp_digit *dct = dc; 2563 mp_digit *dbt = db; 2564 2565 if (*da == 0) 2566 continue; 2567 2568 w = 0; 2569 for (b = 0; b < size_b; ++b, ++dbt, ++dct) 2570 { 2571 w = (mp_word) *da * 
(mp_word) *dbt + w + (mp_word) *dct; 2572 2573 *dct = LOWER_HALF(w); 2574 w = UPPER_HALF(w); 2575 } 2576 2577 *dct = (mp_digit) w; 2578 } 2579 } 2580 2581 static int 2582 s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) 2583 { 2584 if (multiply_threshold && size_a > multiply_threshold) 2585 { 2586 mp_size bot_size = (size_a + 1) / 2; 2587 mp_digit *a_top = da + bot_size; 2588 mp_digit *t1, 2589 *t2, 2590 *t3, 2591 carry PG_USED_FOR_ASSERTS_ONLY; 2592 mp_size at_size = size_a - bot_size; 2593 mp_size buf_size = 2 * bot_size; 2594 2595 if ((t1 = s_alloc(4 * buf_size)) == NULL) 2596 return 0; 2597 t2 = t1 + buf_size; 2598 t3 = t2 + buf_size; 2599 ZERO(t1, 4 * buf_size); 2600 2601 (void) s_ksqr(da, t1, bot_size); /* t1 = a0 ^ 2 */ 2602 (void) s_ksqr(a_top, t2, at_size); /* t2 = a1 ^ 2 */ 2603 2604 (void) s_kmul(da, a_top, t3, bot_size, at_size); /* t3 = a0 * a1 */ 2605 2606 /* Quick multiply t3 by 2, shifting left (can't overflow) */ 2607 { 2608 int i, 2609 top = bot_size + at_size; 2610 mp_word w, 2611 save = 0; 2612 2613 for (i = 0; i < top; ++i) 2614 { 2615 w = t3[i]; 2616 w = (w << 1) | save; 2617 t3[i] = LOWER_HALF(w); 2618 save = UPPER_HALF(w); 2619 } 2620 t3[i] = LOWER_HALF(save); 2621 } 2622 2623 /* Assemble the output value */ 2624 COPY(t1, dc, 2 * bot_size); 2625 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size); 2626 assert(carry == 0); 2627 2628 carry = 2629 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size); 2630 assert(carry == 0); 2631 2632 s_free(t1); /* note that t2 and t2 are internal pointers 2633 * only */ 2634 2635 } 2636 else 2637 { 2638 s_usqr(da, dc, size_a); 2639 } 2640 2641 return 1; 2642 } 2643 2644 static void 2645 s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a) 2646 { 2647 mp_size i, 2648 j; 2649 mp_word w; 2650 2651 for (i = 0; i < size_a; ++i, dc += 2, ++da) 2652 { 2653 mp_digit *dct = dc, 2654 *dat = da; 2655 2656 if (*da == 0) 2657 continue; 2658 2659 /* Take care of the first digit, no 
rollover */ 2660 w = (mp_word) *dat * (mp_word) *dat + (mp_word) *dct; 2661 *dct = LOWER_HALF(w); 2662 w = UPPER_HALF(w); 2663 ++dat; 2664 ++dct; 2665 2666 for (j = i + 1; j < size_a; ++j, ++dat, ++dct) 2667 { 2668 mp_word t = (mp_word) *da * (mp_word) *dat; 2669 mp_word u = w + (mp_word) *dct, 2670 ov = 0; 2671 2672 /* Check if doubling t will overflow a word */ 2673 if (HIGH_BIT_SET(t)) 2674 ov = 1; 2675 2676 w = t + t; 2677 2678 /* Check if adding u to w will overflow a word */ 2679 if (ADD_WILL_OVERFLOW(w, u)) 2680 ov = 1; 2681 2682 w += u; 2683 2684 *dct = LOWER_HALF(w); 2685 w = UPPER_HALF(w); 2686 if (ov) 2687 { 2688 w += MP_DIGIT_MAX; /* MP_RADIX */ 2689 ++w; 2690 } 2691 } 2692 2693 w = w + *dct; 2694 *dct = (mp_digit) w; 2695 while ((w = UPPER_HALF(w)) != 0) 2696 { 2697 ++dct; 2698 w = w + *dct; 2699 *dct = LOWER_HALF(w); 2700 } 2701 2702 assert(w == 0); 2703 } 2704 } 2705 2706 static void 2707 s_dadd(mp_int a, mp_digit b) 2708 { 2709 mp_word w = 0; 2710 mp_digit *da = MP_DIGITS(a); 2711 mp_size ua = MP_USED(a); 2712 2713 w = (mp_word) *da + b; 2714 *da++ = LOWER_HALF(w); 2715 w = UPPER_HALF(w); 2716 2717 for (ua -= 1; ua > 0; --ua, ++da) 2718 { 2719 w = (mp_word) *da + w; 2720 2721 *da = LOWER_HALF(w); 2722 w = UPPER_HALF(w); 2723 } 2724 2725 if (w) 2726 { 2727 *da = (mp_digit) w; 2728 a->used += 1; 2729 } 2730 } 2731 2732 static void 2733 s_dmul(mp_int a, mp_digit b) 2734 { 2735 mp_word w = 0; 2736 mp_digit *da = MP_DIGITS(a); 2737 mp_size ua = MP_USED(a); 2738 2739 while (ua > 0) 2740 { 2741 w = (mp_word) *da * b + w; 2742 *da++ = LOWER_HALF(w); 2743 w = UPPER_HALF(w); 2744 --ua; 2745 } 2746 2747 if (w) 2748 { 2749 *da = (mp_digit) w; 2750 a->used += 1; 2751 } 2752 } 2753 2754 static void 2755 s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a) 2756 { 2757 mp_word w = 0; 2758 2759 while (size_a > 0) 2760 { 2761 w = (mp_word) *da++ * (mp_word) b + w; 2762 2763 *dc++ = LOWER_HALF(w); 2764 w = UPPER_HALF(w); 2765 --size_a; 2766 } 2767 2768 if 
(w) 2769 *dc = LOWER_HALF(w); 2770 } 2771 2772 static mp_digit 2773 s_ddiv(mp_int a, mp_digit b) 2774 { 2775 mp_word w = 0, 2776 qdigit; 2777 mp_size ua = MP_USED(a); 2778 mp_digit *da = MP_DIGITS(a) + ua - 1; 2779 2780 for ( /* */ ; ua > 0; --ua, --da) 2781 { 2782 w = (w << MP_DIGIT_BIT) | *da; 2783 2784 if (w >= b) 2785 { 2786 qdigit = w / b; 2787 w = w % b; 2788 } 2789 else 2790 { 2791 qdigit = 0; 2792 } 2793 2794 *da = (mp_digit) qdigit; 2795 } 2796 2797 CLAMP(a); 2798 return (mp_digit) w; 2799 } 2800 2801 static void 2802 s_qdiv(mp_int z, mp_size p2) 2803 { 2804 mp_size ndig = p2 / MP_DIGIT_BIT, 2805 nbits = p2 % MP_DIGIT_BIT; 2806 mp_size uz = MP_USED(z); 2807 2808 if (ndig) 2809 { 2810 mp_size mark; 2811 mp_digit *to, 2812 *from; 2813 2814 if (ndig >= uz) 2815 { 2816 mp_int_zero(z); 2817 return; 2818 } 2819 2820 to = MP_DIGITS(z); 2821 from = to + ndig; 2822 2823 for (mark = ndig; mark < uz; ++mark) 2824 { 2825 *to++ = *from++; 2826 } 2827 2828 z->used = uz - ndig; 2829 } 2830 2831 if (nbits) 2832 { 2833 mp_digit d = 0, 2834 *dz, 2835 save; 2836 mp_size up = MP_DIGIT_BIT - nbits; 2837 2838 uz = MP_USED(z); 2839 dz = MP_DIGITS(z) + uz - 1; 2840 2841 for ( /* */ ; uz > 0; --uz, --dz) 2842 { 2843 save = *dz; 2844 2845 *dz = (*dz >> nbits) | (d << up); 2846 d = save; 2847 } 2848 2849 CLAMP(z); 2850 } 2851 2852 if (MP_USED(z) == 1 && z->digits[0] == 0) 2853 z->sign = MP_ZPOS; 2854 } 2855 2856 static void 2857 s_qmod(mp_int z, mp_size p2) 2858 { 2859 mp_size start = p2 / MP_DIGIT_BIT + 1, 2860 rest = p2 % MP_DIGIT_BIT; 2861 mp_size uz = MP_USED(z); 2862 mp_digit mask = (1u << rest) - 1; 2863 2864 if (start <= uz) 2865 { 2866 z->used = start; 2867 z->digits[start - 1] &= mask; 2868 CLAMP(z); 2869 } 2870 } 2871 2872 static int 2873 s_qmul(mp_int z, mp_size p2) 2874 { 2875 mp_size uz, 2876 need, 2877 rest, 2878 extra, 2879 i; 2880 mp_digit *from, 2881 *to, 2882 d; 2883 2884 if (p2 == 0) 2885 return 1; 2886 2887 uz = MP_USED(z); 2888 need = p2 / MP_DIGIT_BIT; 2889 
rest = p2 % MP_DIGIT_BIT; 2890 2891 /* 2892 * Figure out if we need an extra digit at the top end; this occurs if the 2893 * topmost `rest' bits of the high-order digit of z are not zero, meaning 2894 * they will be shifted off the end if not preserved 2895 */ 2896 extra = 0; 2897 if (rest != 0) 2898 { 2899 mp_digit *dz = MP_DIGITS(z) + uz - 1; 2900 2901 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0) 2902 extra = 1; 2903 } 2904 2905 if (!s_pad(z, uz + need + extra)) 2906 return 0; 2907 2908 /* 2909 * If we need to shift by whole digits, do that in one pass, then to back 2910 * and shift by partial digits. 2911 */ 2912 if (need > 0) 2913 { 2914 from = MP_DIGITS(z) + uz - 1; 2915 to = from + need; 2916 2917 for (i = 0; i < uz; ++i) 2918 *to-- = *from--; 2919 2920 ZERO(MP_DIGITS(z), need); 2921 uz += need; 2922 } 2923 2924 if (rest) 2925 { 2926 d = 0; 2927 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from) 2928 { 2929 mp_digit save = *from; 2930 2931 *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest)); 2932 d = save; 2933 } 2934 2935 d >>= (MP_DIGIT_BIT - rest); 2936 if (d != 0) 2937 { 2938 *from = d; 2939 uz += extra; 2940 } 2941 } 2942 2943 z->used = uz; 2944 CLAMP(z); 2945 2946 return 1; 2947 } 2948 2949 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z| 2950 The sign of the result is always zero/positive. 2951 */ 2952 static int 2953 s_qsub(mp_int z, mp_size p2) 2954 { 2955 mp_digit hi = (1u << (p2 % MP_DIGIT_BIT)), 2956 *zp; 2957 mp_size tdig = (p2 / MP_DIGIT_BIT), 2958 pos; 2959 mp_word w = 0; 2960 2961 if (!s_pad(z, tdig + 1)) 2962 return 0; 2963 2964 for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) 2965 { 2966 w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word) *zp; 2967 2968 *zp = LOWER_HALF(w); 2969 w = UPPER_HALF(w) ? 
0 : 1; 2970 } 2971 2972 w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word) *zp; 2973 *zp = LOWER_HALF(w); 2974 2975 assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */ 2976 2977 z->sign = MP_ZPOS; 2978 CLAMP(z); 2979 2980 return 1; 2981 } 2982 2983 static int 2984 s_dp2k(mp_int z) 2985 { 2986 int k = 0; 2987 mp_digit *dp = MP_DIGITS(z), 2988 d; 2989 2990 if (MP_USED(z) == 1 && *dp == 0) 2991 return 1; 2992 2993 while (*dp == 0) 2994 { 2995 k += MP_DIGIT_BIT; 2996 ++dp; 2997 } 2998 2999 d = *dp; 3000 while ((d & 1) == 0) 3001 { 3002 d >>= 1; 3003 ++k; 3004 } 3005 3006 return k; 3007 } 3008 3009 static int 3010 s_isp2(mp_int z) 3011 { 3012 mp_size uz = MP_USED(z), 3013 k = 0; 3014 mp_digit *dz = MP_DIGITS(z), 3015 d; 3016 3017 while (uz > 1) 3018 { 3019 if (*dz++ != 0) 3020 return -1; 3021 k += MP_DIGIT_BIT; 3022 --uz; 3023 } 3024 3025 d = *dz; 3026 while (d > 1) 3027 { 3028 if (d & 1) 3029 return -1; 3030 ++k; 3031 d >>= 1; 3032 } 3033 3034 return (int) k; 3035 } 3036 3037 static int 3038 s_2expt(mp_int z, mp_small k) 3039 { 3040 mp_size ndig, 3041 rest; 3042 mp_digit *dz; 3043 3044 ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT; 3045 rest = k % MP_DIGIT_BIT; 3046 3047 if (!s_pad(z, ndig)) 3048 return 0; 3049 3050 dz = MP_DIGITS(z); 3051 ZERO(dz, ndig); 3052 *(dz + ndig - 1) = (1u << rest); 3053 z->used = ndig; 3054 3055 return 1; 3056 } 3057 3058 static int 3059 s_norm(mp_int a, mp_int b) 3060 { 3061 mp_digit d = b->digits[MP_USED(b) - 1]; 3062 int k = 0; 3063 3064 while (d < (1u << (mp_digit) (MP_DIGIT_BIT - 1))) 3065 { /* d < (MP_RADIX / 2) */ 3066 d <<= 1; 3067 ++k; 3068 } 3069 3070 /* These multiplications can't fail */ 3071 if (k != 0) 3072 { 3073 (void) s_qmul(a, (mp_size) k); 3074 (void) s_qmul(b, (mp_size) k); 3075 } 3076 3077 return k; 3078 } 3079 3080 static mp_result 3081 s_brmu(mp_int z, mp_int m) 3082 { 3083 mp_size um = MP_USED(m) * 2; 3084 3085 if (!s_pad(z, um)) 3086 return MP_MEMORY; 3087 3088 s_2expt(z, MP_DIGIT_BIT * um); 3089 return 
mp_int_div(z, m, z, NULL); 3090 } 3091 3092 static int 3093 s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2) 3094 { 3095 mp_size um = MP_USED(m), 3096 umb_p1, 3097 umb_m1; 3098 3099 umb_p1 = (um + 1) * MP_DIGIT_BIT; 3100 umb_m1 = (um - 1) * MP_DIGIT_BIT; 3101 3102 if (mp_int_copy(x, q1) != MP_OK) 3103 return 0; 3104 3105 /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */ 3106 s_qdiv(q1, umb_m1); 3107 UMUL(q1, mu, q2); 3108 s_qdiv(q2, umb_p1); 3109 3110 /* Set x = x mod b^(k+1) */ 3111 s_qmod(x, umb_p1); 3112 3113 /* 3114 * Now, q is a guess for the quotient a / m. Compute x - q * m mod 3115 * b^(k+1), replacing x. This may be off by a factor of 2m, but no more 3116 * than that. 3117 */ 3118 UMUL(q2, m, q1); 3119 s_qmod(q1, umb_p1); 3120 (void) mp_int_sub(x, q1, x); /* can't fail */ 3121 3122 /* 3123 * The result may be < 0; if it is, add b^(k+1) to pin it in the proper 3124 * range. 3125 */ 3126 if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1)) 3127 return 0; 3128 3129 /* 3130 * If x > m, we need to back it off until it is in range. This will be 3131 * required at most twice. 3132 */ 3133 if (mp_int_compare(x, m) >= 0) 3134 { 3135 (void) mp_int_sub(x, m, x); 3136 if (mp_int_compare(x, m) >= 0) 3137 { 3138 (void) mp_int_sub(x, m, x); 3139 } 3140 } 3141 3142 /* At this point, x has been properly reduced. */ 3143 return 1; 3144 } 3145 3146 /* Perform modular exponentiation using Barrett's method, where mu is the 3147 reduction constant for m. Assumes a < m, b > 0. 
*/ 3148 static mp_result 3149 s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c) 3150 { 3151 mp_digit umu = MP_USED(mu); 3152 mp_digit *db = MP_DIGITS(b); 3153 mp_digit *dbt = db + MP_USED(b) - 1; 3154 3155 DECLARE_TEMP(3); 3156 REQUIRE(GROW(TEMP(0), 4 * umu)); 3157 REQUIRE(GROW(TEMP(1), 4 * umu)); 3158 REQUIRE(GROW(TEMP(2), 4 * umu)); 3159 ZERO(TEMP(0)->digits, TEMP(0)->alloc); 3160 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 3161 ZERO(TEMP(2)->digits, TEMP(2)->alloc); 3162 3163 (void) mp_int_set_value(c, 1); 3164 3165 /* Take care of low-order digits */ 3166 while (db < dbt) 3167 { 3168 mp_digit d = *db; 3169 3170 for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) 3171 { 3172 if (d & 1) 3173 { 3174 /* The use of a second temporary avoids allocation */ 3175 UMUL(c, a, TEMP(0)); 3176 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) 3177 { 3178 REQUIRE(MP_MEMORY); 3179 } 3180 mp_int_copy(TEMP(0), c); 3181 } 3182 3183 USQR(a, TEMP(0)); 3184 assert(MP_SIGN(TEMP(0)) == MP_ZPOS); 3185 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) 3186 { 3187 REQUIRE(MP_MEMORY); 3188 } 3189 assert(MP_SIGN(TEMP(0)) == MP_ZPOS); 3190 mp_int_copy(TEMP(0), a); 3191 } 3192 3193 ++db; 3194 } 3195 3196 /* Take care of highest-order digit */ 3197 mp_digit d = *dbt; 3198 3199 for (;;) 3200 { 3201 if (d & 1) 3202 { 3203 UMUL(c, a, TEMP(0)); 3204 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) 3205 { 3206 REQUIRE(MP_MEMORY); 3207 } 3208 mp_int_copy(TEMP(0), c); 3209 } 3210 3211 d >>= 1; 3212 if (!d) 3213 break; 3214 3215 USQR(a, TEMP(0)); 3216 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) 3217 { 3218 REQUIRE(MP_MEMORY); 3219 } 3220 (void) mp_int_copy(TEMP(0), a); 3221 } 3222 3223 CLEANUP_TEMP(); 3224 return MP_OK; 3225 } 3226 3227 /* Division of nonnegative integers 3228 3229 This function implements division algorithm for unsigned multi-precision 3230 integers. The algorithm is based on Algorithm D from Knuth's "The Art of 3231 Computer Programming", 3rd ed. 1998, pg 272-273. 
3232 3233 We diverge from Knuth's algorithm in that we do not perform the subtraction 3234 from the remainder until we have determined that we have the correct 3235 quotient digit. This makes our algorithm less efficient that Knuth because 3236 we might have to perform multiple multiplication and comparison steps before 3237 the subtraction. The advantage is that it is easy to implement and ensure 3238 correctness without worrying about underflow from the subtraction. 3239 3240 inputs: u a n+m digit integer in base b (b is 2^MP_DIGIT_BIT) 3241 v a n digit integer in base b (b is 2^MP_DIGIT_BIT) 3242 n >= 1 3243 m >= 0 3244 outputs: u / v stored in u 3245 u % v stored in v 3246 */ 3247 static mp_result 3248 s_udiv_knuth(mp_int u, mp_int v) 3249 { 3250 /* Force signs to positive */ 3251 u->sign = MP_ZPOS; 3252 v->sign = MP_ZPOS; 3253 3254 /* Use simple division algorithm when v is only one digit long */ 3255 if (MP_USED(v) == 1) 3256 { 3257 mp_digit d, 3258 rem; 3259 3260 d = v->digits[0]; 3261 rem = s_ddiv(u, d); 3262 mp_int_set_value(v, rem); 3263 return MP_OK; 3264 } 3265 3266 /* 3267 * Algorithm D 3268 * 3269 * The n and m variables are defined as used by Knuth. u is an n digit 3270 * number with digits u_{n-1}..u_0. v is an n+m digit number with digits 3271 * from v_{m+n-1}..v_0. We require that n > 1 and m >= 0 3272 */ 3273 mp_size n = MP_USED(v); 3274 mp_size m = MP_USED(u) - n; 3275 3276 assert(n > 1); 3277 /* assert(m >= 0) follows because m is unsigned. */ 3278 3279 /* 3280 * D1: Normalize. The normalization step provides the necessary condition 3281 * for Theorem B, which states that the quotient estimate for q_j, call it 3282 * qhat 3283 * 3284 * qhat = u_{j+n}u_{j+n-1} / v_{n-1} 3285 * 3286 * is bounded by 3287 * 3288 * qhat - 2 <= q_j <= qhat. 3289 * 3290 * That is, qhat is always greater than the actual quotient digit q, and 3291 * it is never more than two larger than the actual quotient digit. 
3292 */ 3293 int k = s_norm(u, v); 3294 3295 /* 3296 * Extend size of u by one if needed. 3297 * 3298 * The algorithm begins with a value of u that has one more digit of 3299 * input. The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. 3300 * If the multiplication did not increase the number of digits of u, we 3301 * need to add a leading zero here. 3302 */ 3303 if (k == 0 || MP_USED(u) != m + n + 1) 3304 { 3305 if (!s_pad(u, m + n + 1)) 3306 return MP_MEMORY; 3307 u->digits[m + n] = 0; 3308 u->used = m + n + 1; 3309 } 3310 3311 /* 3312 * Add a leading 0 to v. 3313 * 3314 * The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0. We need 3315 * to add the leading zero to v here to ensure that the multiplication 3316 * will produce the full n+1 digit result. 3317 */ 3318 if (!s_pad(v, n + 1)) 3319 return MP_MEMORY; 3320 v->digits[n] = 0; 3321 3322 /* 3323 * Initialize temporary variables q and t. q allocates space for m+1 3324 * digits to store the quotient digits t allocates space for n+1 digits to 3325 * hold the result of q_j*v 3326 */ 3327 DECLARE_TEMP(2); 3328 REQUIRE(GROW(TEMP(0), m + 1)); 3329 REQUIRE(GROW(TEMP(1), n + 1)); 3330 3331 /* D2: Initialize j */ 3332 int j = m; 3333 mpz_t r; 3334 3335 r.digits = MP_DIGITS(u) + j; /* The contents of r are shared with u */ 3336 r.used = n + 1; 3337 r.sign = MP_ZPOS; 3338 r.alloc = MP_ALLOC(u); 3339 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 3340 3341 /* Calculate the m+1 digits of the quotient result */ 3342 for (; j >= 0; j--) 3343 { 3344 /* D3: Calculate q' */ 3345 /* r->digits is aligned to position j of the number u */ 3346 mp_word pfx, 3347 qhat; 3348 3349 pfx = r.digits[n]; 3350 pfx <<= MP_DIGIT_BIT / 2; 3351 pfx <<= MP_DIGIT_BIT / 2; 3352 pfx |= r.digits[n - 1]; /* pfx = u_{j+n}{j+n-1} */ 3353 3354 qhat = pfx / v->digits[n - 1]; 3355 3356 /* 3357 * Check to see if qhat > b, and decrease qhat if so. 
Theorem B 3358 * guarantess that qhat is at most 2 larger than the actual value, so 3359 * it is possible that qhat is greater than the maximum value that 3360 * will fit in a digit 3361 */ 3362 if (qhat > MP_DIGIT_MAX) 3363 qhat = MP_DIGIT_MAX; 3364 3365 /* 3366 * D4,D5,D6: Multiply qhat * v and test for a correct value of q 3367 * 3368 * We proceed a bit different than the way described by Knuth. This 3369 * way is simpler but less efficent. Instead of doing the multiply and 3370 * subtract then checking for underflow, we first do the multiply of 3371 * qhat * v and see if it is larger than the current remainder r. If 3372 * it is larger, we decrease qhat by one and try again. We may need to 3373 * decrease qhat one more time before we get a value that is smaller 3374 * than r. 3375 * 3376 * This way is less efficent than Knuth becuase we do more multiplies, 3377 * but we do not need to worry about underflow this way. 3378 */ 3379 /* t = qhat * v */ 3380 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1); 3381 TEMP(1)->used = n + 1; 3382 CLAMP(TEMP(1)); 3383 3384 /* Clamp r for the comparison. Comparisons do not like leading zeros. */ 3385 CLAMP(&r); 3386 if (s_ucmp(TEMP(1), &r) > 0) 3387 { /* would the remainder be negative? */ 3388 qhat -= 1; /* try a smaller q */ 3389 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1); 3390 TEMP(1)->used = n + 1; 3391 CLAMP(TEMP(1)); 3392 if (s_ucmp(TEMP(1), &r) > 0) 3393 { /* would the remainder be negative? */ 3394 assert(qhat > 0); 3395 qhat -= 1; /* try a smaller q */ 3396 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1); 3397 TEMP(1)->used = n + 1; 3398 CLAMP(TEMP(1)); 3399 } 3400 assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us."); 3401 } 3402 3403 /* 3404 * Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be 3405 * n+1 digits long. 
3406 */ 3407 r.used = n + 1; 3408 3409 /* 3410 * D4: Multiply and subtract 3411 * 3412 * Note: The multiply was completed above so we only need to subtract 3413 * here. 3414 */ 3415 s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used); 3416 3417 /* 3418 * D5: Test remainder 3419 * 3420 * Note: Not needed because we always check that qhat is the correct 3421 * value before performing the subtract. Value cast to mp_digit to 3422 * prevent warning, qhat has been clamped to MP_DIGIT_MAX 3423 */ 3424 TEMP(0)->digits[j] = (mp_digit) qhat; 3425 3426 /* 3427 * D6: Add back Note: Not needed because we always check that qhat is 3428 * the correct value before performing the subtract. 3429 */ 3430 3431 /* D7: Loop on j */ 3432 r.digits--; 3433 ZERO(TEMP(1)->digits, TEMP(1)->alloc); 3434 } 3435 3436 /* Get rid of leading zeros in q */ 3437 TEMP(0)->used = m + 1; 3438 CLAMP(TEMP(0)); 3439 3440 /* Denormalize the remainder */ 3441 CLAMP(u); /* use u here because the r.digits pointer is 3442 * off-by-one */ 3443 if (k != 0) 3444 s_qdiv(u, k); 3445 3446 mp_int_copy(u, v); /* ok: 0 <= r < v */ 3447 mp_int_copy(TEMP(0), u); /* ok: q <= u */ 3448 3449 CLEANUP_TEMP(); 3450 return MP_OK; 3451 } 3452 3453 static int 3454 s_outlen(mp_int z, mp_size r) 3455 { 3456 assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX); 3457 3458 mp_result bits = mp_int_count_bits(z); 3459 double raw = (double) bits * s_log2[r]; 3460 3461 return (int) (raw + 0.999999); 3462 } 3463 3464 static mp_size 3465 s_inlen(int len, mp_size r) 3466 { 3467 double raw = (double) len / s_log2[r]; 3468 mp_size bits = (mp_size) (raw + 0.5); 3469 3470 return (mp_size) ((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1; 3471 } 3472 3473 static int 3474 s_ch2val(char c, int r) 3475 { 3476 int out; 3477 3478 /* 3479 * In some locales, isalpha() accepts characters outside the range A-Z, 3480 * producing out<0 or out>=36. The "out >= r" check will always catch 3481 * out>=36. 
Though nothing explicitly catches out<0, our caller reacts 3482 * the same way to every negative return value. 3483 */ 3484 if (isdigit((unsigned char) c)) 3485 out = c - '0'; 3486 else if (r > 10 && isalpha((unsigned char) c)) 3487 out = toupper((unsigned char) c) - 'A' + 10; 3488 else 3489 return -1; 3490 3491 return (out >= r) ? -1 : out; 3492 } 3493 3494 static char 3495 s_val2ch(int v, int caps) 3496 { 3497 assert(v >= 0); 3498 3499 if (v < 10) 3500 { 3501 return v + '0'; 3502 } 3503 else 3504 { 3505 char out = (v - 10) + 'a'; 3506 3507 if (caps) 3508 { 3509 return toupper((unsigned char) out); 3510 } 3511 else 3512 { 3513 return out; 3514 } 3515 } 3516 } 3517 3518 static void 3519 s_2comp(unsigned char *buf, int len) 3520 { 3521 unsigned short s = 1; 3522 3523 for (int i = len - 1; i >= 0; --i) 3524 { 3525 unsigned char c = ~buf[i]; 3526 3527 s = c + s; 3528 c = s & UCHAR_MAX; 3529 s >>= CHAR_BIT; 3530 3531 buf[i] = c; 3532 } 3533 3534 /* last carry out is ignored */ 3535 } 3536 3537 static mp_result 3538 s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad) 3539 { 3540 int pos = 0, 3541 limit = *limpos; 3542 mp_size uz = MP_USED(z); 3543 mp_digit *dz = MP_DIGITS(z); 3544 3545 while (uz > 0 && pos < limit) 3546 { 3547 mp_digit d = *dz++; 3548 int i; 3549 3550 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) 3551 { 3552 buf[pos++] = (unsigned char) d; 3553 d >>= CHAR_BIT; 3554 3555 /* Don't write leading zeroes */ 3556 if (d == 0 && uz == 1) 3557 i = 0; /* exit loop without signaling truncation */ 3558 } 3559 3560 /* Detect truncation (loop exited with pos >= limit) */ 3561 if (i > 0) 3562 break; 3563 3564 --uz; 3565 } 3566 3567 if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) 3568 { 3569 if (pos < limit) 3570 { 3571 buf[pos++] = 0; 3572 } 3573 else 3574 { 3575 uz = 1; 3576 } 3577 } 3578 3579 /* Digits are in reverse order, fix that */ 3580 REV(buf, pos); 3581 3582 /* Return the number of bytes actually written */ 3583 *limpos = pos; 3584 3585 
return (uz == 0) ? MP_OK : MP_TRUNC; 3586 } 3587 3588 /* Here there be dragons */ 3589