1 /*-------------------------------------------------------------------------
2  *
3  * imath.c
4  *
5  * Last synchronized from https://github.com/creachadair/imath/tree/v1.29,
6  * using the following procedure:
7  *
8  * 1. Download imath.c and imath.h of the last synchronized version.  Remove
9  *    "#ifdef __cplusplus" blocks, which upset pgindent.  Run pgindent on the
10  *    two files.  Filter the two files through "unexpand -t4 --first-only".
11  *    Diff the result against the PostgreSQL versions.  As of the last
12  *    synchronization, changes were as follows:
13  *
14  *    - replace malloc(), realloc() and free() with px_ versions
15  *    - redirect assert() to Assert()
16  *    - #undef MIN, #undef MAX before defining them
17  *    - remove includes covered by c.h
18  *    - rename DEBUG to IMATH_DEBUG
19  *    - replace stdint.h usage with c.h equivalents
20  *    - suppress MSVC warning 4146
21  *    - add required PG_USED_FOR_ASSERTS_ONLY
22  *
23  * 2. Download a newer imath.c and imath.h.  Transform them like in step 1.
24  *    Apply to these files the diff you saved in step 1.  Look for new lines
25  *    requiring the same kind of change, such as new malloc() calls.
26  *
27  * 3. Configure PostgreSQL using --without-openssl.  Run "make -C
28  *    contrib/pgcrypto check".
29  *
30  * 4. Update this header comment.
31  *
32  * Portions Copyright (c) 1996-2020, PostgreSQL Global Development Group
33  *
34  * IDENTIFICATION
35  *	  contrib/pgcrypto/imath.c
36  *
37  * Upstream copyright terms follow.
38  *-------------------------------------------------------------------------
39  */
40 
41 /*
42   Name:		imath.c
43   Purpose:	Arbitrary precision integer arithmetic routines.
44   Author:   M. J. Fromberger
45 
46   Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
47 
48   Permission is hereby granted, free of charge, to any person obtaining a copy
49   of this software and associated documentation files (the "Software"), to deal
50   in the Software without restriction, including without limitation the rights
51   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
52   copies of the Software, and to permit persons to whom the Software is
53   furnished to do so, subject to the following conditions:
54 
55   The above copyright notice and this permission notice shall be included in
56   all copies or substantial portions of the Software.
57 
58   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
59   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
60   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
61   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
62   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
63   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
64   SOFTWARE.
65  */
66 
67 #include "postgres.h"
68 
69 #include "imath.h"
70 #include "px.h"
71 
72 #undef assert
73 #define assert(TEST) Assert(TEST)
74 
/*
 * Result codes returned by the mp_int routines.  MP_OK and MP_FALSE share
 * the value 0; every other code defined here is negative.
 */
const mp_result MP_OK = 0;		/* no error, all is well  */
const mp_result MP_FALSE = 0;	/* boolean false          */
const mp_result MP_TRUE = -1;	/* boolean true           */
const mp_result MP_MEMORY = -2; /* out of memory          */
const mp_result MP_RANGE = -3;	/* argument out of range  */
const mp_result MP_UNDEF = -4;	/* result undefined       */
const mp_result MP_TRUNC = -5;	/* output truncated       */
const mp_result MP_BADARG = -6; /* invalid null argument  */
const mp_result MP_MINERR = -6; /* lowest code defined above (equals MP_BADARG) */

/* Sign flags stored in the sign field of an mp_int. */
const mp_sign MP_NEG = 1;		/* value is strictly negative */
const mp_sign MP_ZPOS = 0;		/* value is non-negative      */
87 
/*
 * Text for result codes.  s_error_msg lists one description per code in
 * order of increasing magnitude (0, 1, ... 6) and ends with a NULL sentinel.
 * NOTE(review): the lookup routine is not visible in this part of the file;
 * presumably the table is indexed by the negated result code and
 * s_unknown_err is the fallback for codes without an entry -- confirm.
 */
static const char *s_unknown_err = "unknown result code";
static const char *s_error_msg[] = {"error code 0", "boolean true",
	"out of memory", "argument out of range",
	"result undefined", "output truncated",
"invalid argument", NULL};
93 
/* The ith entry of this table gives the value of log_i(2), for radix i
   from 2 through 36.  Entries 0 and 1 are 0.0 and are marked "(D)" below
   (presumably dummies, since those radixes are not meaningful).

   An integer value n requires ceil(log_i(n)) digits to be represented
   in base i.  Since it is easy to compute lg(n), by counting bits, we
   can compute log_i(n) = lg(n) * log_i(2).

   The use of this table eliminates a dependency upon linkage against
   the standard math libraries.

   If MP_MAX_RADIX is increased, this table should be expanded too.
 */
static const double s_log2[] = {
	0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2  3 */
	0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4  5  6  7 */
	0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8  9 10 11 */
	0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */
	0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */
	0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */
	0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */
	0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */
	0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */
	0.193426404,				/* 36          */
};
117 
/* Return the number of mp_digit elements needed to hold an object of the
   same size as V: sizeof(V) divided by sizeof(mp_digit), rounded up.  V is
   used only inside sizeof, so it is not evaluated. */
#define MP_VALUE_DIGITS(V) \
  ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit))
121 
122 /* Round precision P to nearest word boundary */
123 static inline mp_size
124 s_round_prec(mp_size P)
125 {
126 	return 2 * ((P + 1) / 2);
127 }
128 
129 /* Set array P of S digits to zero */
130 static inline void
131 ZERO(mp_digit *P, mp_size S)
132 {
133 	mp_size		i__ = S * sizeof(mp_digit);
134 	mp_digit   *p__ = P;
135 
136 	memset(p__, 0, i__);
137 }
138 
139 /* Copy S digits from array P to array Q */
140 static inline void
141 COPY(mp_digit *P, mp_digit *Q, mp_size S)
142 {
143 	mp_size		i__ = S * sizeof(mp_digit);
144 	mp_digit   *p__ = P;
145 	mp_digit   *q__ = Q;
146 
147 	memcpy(q__, p__, i__);
148 }
149 
/* Reverse, in place, the first N bytes of the buffer A. */
static inline void
REV(unsigned char *A, int N)
{
	unsigned char *lo;
	unsigned char *hi;

	/* Walk inward from both ends, exchanging bytes until the ends meet. */
	for (lo = A, hi = A + N - 1; lo < hi; ++lo, --hi)
	{
		unsigned char held = *hi;

		*hi = *lo;
		*lo = held;
	}
}
165 
166 /* Strip leading zeroes from z_ in-place. */
167 static inline void
168 CLAMP(mp_int z_)
169 {
170 	mp_size		uz_ = MP_USED(z_);
171 	mp_digit   *dz_ = MP_DIGITS(z_) + uz_ - 1;
172 
173 	while (uz_ > 1 && (*dz_-- == 0))
174 		--uz_;
175 	z_->used = uz_;
176 }
177 
178 /* Select min/max. */
179 #undef MIN
180 #undef MAX
/* Return the smaller of A and B; when they compare equal, A is returned. */
static inline int
MIN(int A, int B)
{
	if (B < A)
		return B;

	return A;
}
186 static inline mp_size
187 MAX(mp_size A, mp_size B)
188 {
189 	return (B > A ? B : A);
190 }
191 
/* Exchange lvalues A and B of type T, e.g.
   SWAP(int, x, y) where x and y are variables of type int.

   A and B are each expanded twice, so they should be free of side effects.
   The temporary is named t_; do not pass an argument with that name. */
#define SWAP(T, A, B) \
  do {                \
	T t_ = (A);       \
	A = (B);          \
	B = t_;           \
  } while (0)
200 
/* Declare a block of N temporary mpz_t values.

   This expands to the declaration of a local struct named temp_ (holding
   the N values, their count, and a cached error code that starts as MP_OK),
   followed by a loop that calls mp_int_init() on each value.

   You must add CLEANUP_TEMP() at the end of the function.
   Use TEMP(i) to access a pointer to the ith value.
 */
#define DECLARE_TEMP(N)                   \
  struct {                                \
	mpz_t value[(N)];                     \
	int len;                              \
	mp_result err;                        \
  } temp_ = {                             \
	  .len = (N),                         \
	  .err = MP_OK,                       \
  };                                      \
  do {                                    \
	for (int i = 0; i < temp_.len; i++) { \
	  mp_int_init(TEMP(i));               \
	}                                     \
  } while (0)
220 
/* Clear all allocated temp values.

   This defines the label CLEANUP, which is the jump target of REQUIRE(), so
   it may appear only once per function.  After clearing every temporary it
   returns the cached error from the enclosing function if one was recorded;
   otherwise control falls through to the statement that follows, which is
   expected to return the success value. */
#define CLEANUP_TEMP()                    \
  CLEANUP:                                \
  do {                                    \
	for (int i = 0; i < temp_.len; i++) { \
	  mp_int_clear(TEMP(i));              \
	}                                     \
	if (temp_.err != MP_OK) {             \
	  return temp_.err;                   \
	}                                     \
  } while (0)
232 
/* A pointer to the kth temp value (zero-based) of the block created by
   DECLARE_TEMP().  No bounds check is performed on K. */
#define TEMP(K) (temp_.value + (K))
235 
/* Evaluate E, an expression of type mp_result expected to return MP_OK.  If
   the value is not MP_OK, the error is cached in temp_.err and control
   resumes at the CLEANUP label, whose handler returns it.

   Usable only in a function that has DECLARE_TEMP() in scope and a
   CLEANUP_TEMP() later in its body.
*/
#define REQUIRE(E)                        \
  do {                                    \
	temp_.err = (E);                      \
	if (temp_.err != MP_OK) goto CLEANUP; \
  } while (0)
245 
246 /* Compare value to zero. */
247 static inline int
248 CMPZ(mp_int Z)
249 {
250 	if (Z->used == 1 && Z->digits[0] == 0)
251 		return 0;
252 	return (Z->sign == MP_NEG) ? -1 : 1;
253 }
254 
255 static inline mp_word
256 UPPER_HALF(mp_word W)
257 {
258 	return (W >> MP_DIGIT_BIT);
259 }
260 static inline mp_digit
261 LOWER_HALF(mp_word W)
262 {
263 	return (mp_digit) (W);
264 }
265 
266 /* Report whether the highest-order bit of W is 1. */
267 static inline bool
268 HIGH_BIT_SET(mp_word W)
269 {
270 	return (W >> (MP_WORD_BIT - 1)) != 0;
271 }
272 
273 /* Report whether adding W + V will carry out. */
274 static inline bool
275 ADD_WILL_OVERFLOW(mp_word W, mp_word V)
276 {
277 	return ((MP_WORD_MAX - V) < W);
278 }
279 
280 /* Default number of digits allocated to a new mp_int */
281 static mp_size default_precision = 8;
282 
283 void
284 mp_int_default_precision(mp_size size)
285 {
286 	assert(size > 0);
287 	default_precision = size;
288 }
289 
290 /* Minimum number of digits to invoke recursive multiply */
291 static mp_size multiply_threshold = 32;
292 
293 void
294 mp_int_multiply_threshold(mp_size thresh)
295 {
296 	assert(thresh >= sizeof(mp_word));
297 	multiply_threshold = thresh;
298 }
299 
300 /* Allocate a buffer of (at least) num digits, or return
301    NULL if that couldn't be done.  */
302 static mp_digit *s_alloc(mp_size num);
303 
304 /* Release a buffer of digits allocated by s_alloc(). */
305 static void s_free(void *ptr);
306 
/* Ensure that z has at least min digits allocated, resizing if
   necessary.  Returns true if successful, false if out of memory. */
309 static bool s_pad(mp_int z, mp_size min);
310 
311 /* Ensure Z has at least N digits allocated. */
312 static inline mp_result
313 GROW(mp_int Z, mp_size N)
314 {
315 	return s_pad(Z, N) ? MP_OK : MP_MEMORY;
316 }
317 
318 /* Fill in a "fake" mp_int on the stack with a given value */
319 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
320 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);
321 
322 /* Compare two runs of digits of given length, returns <0, 0, >0 */
323 static int	s_cdig(mp_digit *da, mp_digit *db, mp_size len);
324 
325 /* Pack the unsigned digits of v into array t */
326 static int	s_uvpack(mp_usmall v, mp_digit t[]);
327 
328 /* Compare magnitudes of a and b, returns <0, 0, >0 */
329 static int	s_ucmp(mp_int a, mp_int b);
330 
331 /* Compare magnitudes of a and v, returns <0, 0, >0 */
332 static int	s_vcmp(mp_int a, mp_small v);
333 static int	s_uvcmp(mp_int a, mp_usmall uv);
334 
335 /* Unsigned magnitude addition; assumes dc is big enough.
336    Carry out is returned (no memory allocated). */
337 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
338 					   mp_size size_b);
339 
340 /* Unsigned magnitude subtraction.  Assumes dc is big enough. */
341 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
342 				   mp_size size_b);
343 
344 /* Unsigned recursive multiplication.  Assumes dc is big enough. */
345 static int	s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
346 				   mp_size size_b);
347 
348 /* Unsigned magnitude multiplication.  Assumes dc is big enough. */
349 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
350 				   mp_size size_b);
351 
352 /* Unsigned recursive squaring.  Assumes dc is big enough. */
353 static int	s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);
354 
355 /* Unsigned magnitude squaring.  Assumes dc is big enough. */
356 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);
357 
358 /* Single digit addition.  Assumes a is big enough. */
359 static void s_dadd(mp_int a, mp_digit b);
360 
361 /* Single digit multiplication.  Assumes a is big enough. */
362 static void s_dmul(mp_int a, mp_digit b);
363 
364 /* Single digit multiplication on buffers; assumes dc is big enough. */
365 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a);
366 
367 /* Single digit division.  Replaces a with the quotient,
368    returns the remainder.  */
369 static mp_digit s_ddiv(mp_int a, mp_digit b);
370 
371 /* Quick division by a power of 2, replaces z (no allocation) */
372 static void s_qdiv(mp_int z, mp_size p2);
373 
374 /* Quick remainder by a power of 2, replaces z (no allocation) */
375 static void s_qmod(mp_int z, mp_size p2);
376 
377 /* Quick multiplication by a power of 2, replaces z.
378    Allocates if necessary; returns false in case this fails. */
379 static int	s_qmul(mp_int z, mp_size p2);
380 
381 /* Quick subtraction from a power of 2, replaces z.
382    Allocates if necessary; returns false in case this fails. */
383 static int	s_qsub(mp_int z, mp_size p2);
384 
385 /* Return maximum k such that 2^k divides z. */
386 static int	s_dp2k(mp_int z);
387 
388 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
389 static int	s_isp2(mp_int z);
390 
391 /* Set z to 2^k.  May allocate; returns false in case this fails. */
392 static int	s_2expt(mp_int z, mp_small k);
393 
394 /* Normalize a and b for division, returns normalization constant */
395 static int	s_norm(mp_int a, mp_int b);
396 
397 /* Compute constant mu for Barrett reduction, given modulus m, result
398    replaces z, m is untouched. */
399 static mp_result s_brmu(mp_int z, mp_int m);
400 
401 /* Reduce a modulo m, using Barrett's algorithm. */
402 static int	s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
403 
404 /* Modular exponentiation, using Barrett reduction */
405 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
406 
407 /* Unsigned magnitude division.  Assumes |a| > |b|.  Allocates temporaries;
408    overwrites a with quotient, b with remainder. */
409 static mp_result s_udiv_knuth(mp_int a, mp_int b);
410 
411 /* Compute the number of digits in radix r required to represent the given
412    value.  Does not account for sign flags, terminators, etc. */
413 static int	s_outlen(mp_int z, mp_size r);
414 
415 /* Guess how many digits of precision will be needed to represent a radix r
416    value of the specified number of digits.  Returns a value guaranteed to be
417    no smaller than the actual number required. */
418 static mp_size s_inlen(int len, mp_size r);
419 
420 /* Convert a character to a digit value in radix r, or
421    -1 if out of range */
422 static int	s_ch2val(char c, int r);
423 
424 /* Convert a digit value to a character */
425 static char s_val2ch(int v, int caps);
426 
427 /* Take 2's complement of a buffer in place */
428 static void s_2comp(unsigned char *buf, int len);
429 
430 /* Convert a value to binary, ignoring sign.  On input, *limpos is the bound on
431    how many bytes should be written to buf; on output, *limpos is set to the
432    number of bytes actually written. */
433 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
434 
435 /* Multiply X by Y into Z, ignoring signs.  Requires that Z have enough storage
436    preallocated to hold the result. */
437 static inline void
438 UMUL(mp_int X, mp_int Y, mp_int Z)
439 {
440 	mp_size		ua_ = MP_USED(X);
441 	mp_size		ub_ = MP_USED(Y);
442 	mp_size		o_ = ua_ + ub_;
443 
444 	ZERO(MP_DIGITS(Z), o_);
445 	(void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_);
446 	Z->used = o_;
447 	CLAMP(Z);
448 }
449 
450 /* Square X into Z.  Requires that Z have enough storage to hold the result. */
451 static inline void
452 USQR(mp_int X, mp_int Z)
453 {
454 	mp_size		ua_ = MP_USED(X);
455 	mp_size		o_ = ua_ + ua_;
456 
457 	ZERO(MP_DIGITS(Z), o_);
458 	(void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_);
459 	Z->used = o_;
460 	CLAMP(Z);
461 }
462 
463 mp_result
464 mp_int_init(mp_int z)
465 {
466 	if (z == NULL)
467 		return MP_BADARG;
468 
469 	z->single = 0;
470 	z->digits = &(z->single);
471 	z->alloc = 1;
472 	z->used = 1;
473 	z->sign = MP_ZPOS;
474 
475 	return MP_OK;
476 }
477 
478 mp_int
479 mp_int_alloc(void)
480 {
481 	mp_int		out = px_alloc(sizeof(mpz_t));
482 
483 	if (out != NULL)
484 		mp_int_init(out);
485 
486 	return out;
487 }
488 
489 mp_result
490 mp_int_init_size(mp_int z, mp_size prec)
491 {
492 	assert(z != NULL);
493 
494 	if (prec == 0)
495 	{
496 		prec = default_precision;
497 	}
498 	else if (prec == 1)
499 	{
500 		return mp_int_init(z);
501 	}
502 	else
503 	{
504 		prec = s_round_prec(prec);
505 	}
506 
507 	z->digits = s_alloc(prec);
508 	if (MP_DIGITS(z) == NULL)
509 		return MP_MEMORY;
510 
511 	z->digits[0] = 0;
512 	z->used = 1;
513 	z->alloc = prec;
514 	z->sign = MP_ZPOS;
515 
516 	return MP_OK;
517 }
518 
519 mp_result
520 mp_int_init_copy(mp_int z, mp_int old)
521 {
522 	assert(z != NULL && old != NULL);
523 
524 	mp_size		uold = MP_USED(old);
525 
526 	if (uold == 1)
527 	{
528 		mp_int_init(z);
529 	}
530 	else
531 	{
532 		mp_size		target = MAX(uold, default_precision);
533 		mp_result	res = mp_int_init_size(z, target);
534 
535 		if (res != MP_OK)
536 			return res;
537 	}
538 
539 	z->used = uold;
540 	z->sign = old->sign;
541 	COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
542 
543 	return MP_OK;
544 }
545 
546 mp_result
547 mp_int_init_value(mp_int z, mp_small value)
548 {
549 	mpz_t		vtmp;
550 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
551 
552 	s_fake(&vtmp, value, vbuf);
553 	return mp_int_init_copy(z, &vtmp);
554 }
555 
556 mp_result
557 mp_int_init_uvalue(mp_int z, mp_usmall uvalue)
558 {
559 	mpz_t		vtmp;
560 	mp_digit	vbuf[MP_VALUE_DIGITS(uvalue)];
561 
562 	s_ufake(&vtmp, uvalue, vbuf);
563 	return mp_int_init_copy(z, &vtmp);
564 }
565 
566 mp_result
567 mp_int_set_value(mp_int z, mp_small value)
568 {
569 	mpz_t		vtmp;
570 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
571 
572 	s_fake(&vtmp, value, vbuf);
573 	return mp_int_copy(&vtmp, z);
574 }
575 
576 mp_result
577 mp_int_set_uvalue(mp_int z, mp_usmall uvalue)
578 {
579 	mpz_t		vtmp;
580 	mp_digit	vbuf[MP_VALUE_DIGITS(uvalue)];
581 
582 	s_ufake(&vtmp, uvalue, vbuf);
583 	return mp_int_copy(&vtmp, z);
584 }
585 
586 void
587 mp_int_clear(mp_int z)
588 {
589 	if (z == NULL)
590 		return;
591 
592 	if (MP_DIGITS(z) != NULL)
593 	{
594 		if (MP_DIGITS(z) != &(z->single))
595 			s_free(MP_DIGITS(z));
596 
597 		z->digits = NULL;
598 	}
599 }
600 
601 void
602 mp_int_free(mp_int z)
603 {
604 	assert(z != NULL);
605 
606 	mp_int_clear(z);
607 	px_free(z);					/* note: NOT s_free() */
608 }
609 
610 mp_result
611 mp_int_copy(mp_int a, mp_int c)
612 {
613 	assert(a != NULL && c != NULL);
614 
615 	if (a != c)
616 	{
617 		mp_size		ua = MP_USED(a);
618 		mp_digit   *da,
619 				   *dc;
620 
621 		if (!s_pad(c, ua))
622 			return MP_MEMORY;
623 
624 		da = MP_DIGITS(a);
625 		dc = MP_DIGITS(c);
626 		COPY(da, dc, ua);
627 
628 		c->used = ua;
629 		c->sign = a->sign;
630 	}
631 
632 	return MP_OK;
633 }
634 
635 void
636 mp_int_swap(mp_int a, mp_int c)
637 {
638 	if (a != c)
639 	{
640 		mpz_t		tmp = *a;
641 
642 		*a = *c;
643 		*c = tmp;
644 
645 		if (MP_DIGITS(a) == &(c->single))
646 			a->digits = &(a->single);
647 		if (MP_DIGITS(c) == &(a->single))
648 			c->digits = &(c->single);
649 	}
650 }
651 
652 void
653 mp_int_zero(mp_int z)
654 {
655 	assert(z != NULL);
656 
657 	z->digits[0] = 0;
658 	z->used = 1;
659 	z->sign = MP_ZPOS;
660 }
661 
662 mp_result
663 mp_int_abs(mp_int a, mp_int c)
664 {
665 	assert(a != NULL && c != NULL);
666 
667 	mp_result	res;
668 
669 	if ((res = mp_int_copy(a, c)) != MP_OK)
670 		return res;
671 
672 	c->sign = MP_ZPOS;
673 	return MP_OK;
674 }
675 
676 mp_result
677 mp_int_neg(mp_int a, mp_int c)
678 {
679 	assert(a != NULL && c != NULL);
680 
681 	mp_result	res;
682 
683 	if ((res = mp_int_copy(a, c)) != MP_OK)
684 		return res;
685 
686 	if (CMPZ(c) != 0)
687 		c->sign = 1 - MP_SIGN(a);
688 
689 	return MP_OK;
690 }
691 
692 mp_result
693 mp_int_add(mp_int a, mp_int b, mp_int c)
694 {
695 	assert(a != NULL && b != NULL && c != NULL);
696 
697 	mp_size		ua = MP_USED(a);
698 	mp_size		ub = MP_USED(b);
699 	mp_size		max = MAX(ua, ub);
700 
701 	if (MP_SIGN(a) == MP_SIGN(b))
702 	{
703 		/* Same sign -- add magnitudes, preserve sign of addends */
704 		if (!s_pad(c, max))
705 			return MP_MEMORY;
706 
707 		mp_digit	carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
708 		mp_size		uc = max;
709 
710 		if (carry)
711 		{
712 			if (!s_pad(c, max + 1))
713 				return MP_MEMORY;
714 
715 			c->digits[max] = carry;
716 			++uc;
717 		}
718 
719 		c->used = uc;
720 		c->sign = a->sign;
721 
722 	}
723 	else
724 	{
725 		/* Different signs -- subtract magnitudes, preserve sign of greater */
726 		int			cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */
727 
728 		/*
729 		 * Set x to max(a, b), y to min(a, b) to simplify later code. A
730 		 * special case yields zero for equal magnitudes.
731 		 */
732 		mp_int		x,
733 					y;
734 
735 		if (cmp == 0)
736 		{
737 			mp_int_zero(c);
738 			return MP_OK;
739 		}
740 		else if (cmp < 0)
741 		{
742 			x = b;
743 			y = a;
744 		}
745 		else
746 		{
747 			x = a;
748 			y = b;
749 		}
750 
751 		if (!s_pad(c, MP_USED(x)))
752 			return MP_MEMORY;
753 
754 		/* Subtract smaller from larger */
755 		s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
756 		c->used = x->used;
757 		CLAMP(c);
758 
759 		/* Give result the sign of the larger */
760 		c->sign = x->sign;
761 	}
762 
763 	return MP_OK;
764 }
765 
766 mp_result
767 mp_int_add_value(mp_int a, mp_small value, mp_int c)
768 {
769 	mpz_t		vtmp;
770 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
771 
772 	s_fake(&vtmp, value, vbuf);
773 
774 	return mp_int_add(a, &vtmp, c);
775 }
776 
777 mp_result
778 mp_int_sub(mp_int a, mp_int b, mp_int c)
779 {
780 	assert(a != NULL && b != NULL && c != NULL);
781 
782 	mp_size		ua = MP_USED(a);
783 	mp_size		ub = MP_USED(b);
784 	mp_size		max = MAX(ua, ub);
785 
786 	if (MP_SIGN(a) != MP_SIGN(b))
787 	{
788 		/* Different signs -- add magnitudes and keep sign of a */
789 		if (!s_pad(c, max))
790 			return MP_MEMORY;
791 
792 		mp_digit	carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
793 		mp_size		uc = max;
794 
795 		if (carry)
796 		{
797 			if (!s_pad(c, max + 1))
798 				return MP_MEMORY;
799 
800 			c->digits[max] = carry;
801 			++uc;
802 		}
803 
804 		c->used = uc;
805 		c->sign = a->sign;
806 
807 	}
808 	else
809 	{
810 		/* Same signs -- subtract magnitudes */
811 		if (!s_pad(c, max))
812 			return MP_MEMORY;
813 		mp_int		x,
814 					y;
815 		mp_sign		osign;
816 
817 		int			cmp = s_ucmp(a, b);
818 
819 		if (cmp >= 0)
820 		{
821 			x = a;
822 			y = b;
823 			osign = MP_ZPOS;
824 		}
825 		else
826 		{
827 			x = b;
828 			y = a;
829 			osign = MP_NEG;
830 		}
831 
832 		if (MP_SIGN(a) == MP_NEG && cmp != 0)
833 			osign = 1 - osign;
834 
835 		s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
836 		c->used = x->used;
837 		CLAMP(c);
838 
839 		c->sign = osign;
840 	}
841 
842 	return MP_OK;
843 }
844 
845 mp_result
846 mp_int_sub_value(mp_int a, mp_small value, mp_int c)
847 {
848 	mpz_t		vtmp;
849 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
850 
851 	s_fake(&vtmp, value, vbuf);
852 
853 	return mp_int_sub(a, &vtmp, c);
854 }
855 
856 mp_result
857 mp_int_mul(mp_int a, mp_int b, mp_int c)
858 {
859 	assert(a != NULL && b != NULL && c != NULL);
860 
861 	/* If either input is zero, we can shortcut multiplication */
862 	if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0)
863 	{
864 		mp_int_zero(c);
865 		return MP_OK;
866 	}
867 
868 	/* Output is positive if inputs have same sign, otherwise negative */
869 	mp_sign		osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
870 
871 	/*
872 	 * If the output is not identical to any of the inputs, we'll write the
873 	 * results directly; otherwise, allocate a temporary space.
874 	 */
875 	mp_size		ua = MP_USED(a);
876 	mp_size		ub = MP_USED(b);
877 	mp_size		osize = MAX(ua, ub);
878 
879 	osize = 4 * ((osize + 1) / 2);
880 
881 	mp_digit   *out;
882 	mp_size		p = 0;
883 
884 	if (c == a || c == b)
885 	{
886 		p = MAX(s_round_prec(osize), default_precision);
887 
888 		if ((out = s_alloc(p)) == NULL)
889 			return MP_MEMORY;
890 	}
891 	else
892 	{
893 		if (!s_pad(c, osize))
894 			return MP_MEMORY;
895 
896 		out = MP_DIGITS(c);
897 	}
898 	ZERO(out, osize);
899 
900 	if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
901 		return MP_MEMORY;
902 
903 	/*
904 	 * If we allocated a new buffer, get rid of whatever memory c was already
905 	 * using, and fix up its fields to reflect that.
906 	 */
907 	if (out != MP_DIGITS(c))
908 	{
909 		if ((void *) MP_DIGITS(c) != (void *) c)
910 			s_free(MP_DIGITS(c));
911 		c->digits = out;
912 		c->alloc = p;
913 	}
914 
915 	c->used = osize;			/* might not be true, but we'll fix it ... */
916 	CLAMP(c);					/* ... right here */
917 	c->sign = osign;
918 
919 	return MP_OK;
920 }
921 
922 mp_result
923 mp_int_mul_value(mp_int a, mp_small value, mp_int c)
924 {
925 	mpz_t		vtmp;
926 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
927 
928 	s_fake(&vtmp, value, vbuf);
929 
930 	return mp_int_mul(a, &vtmp, c);
931 }
932 
933 mp_result
934 mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c)
935 {
936 	assert(a != NULL && c != NULL && p2 >= 0);
937 
938 	mp_result	res = mp_int_copy(a, c);
939 
940 	if (res != MP_OK)
941 		return res;
942 
943 	if (s_qmul(c, (mp_size) p2))
944 	{
945 		return MP_OK;
946 	}
947 	else
948 	{
949 		return MP_MEMORY;
950 	}
951 }
952 
953 mp_result
954 mp_int_sqr(mp_int a, mp_int c)
955 {
956 	assert(a != NULL && c != NULL);
957 
958 	/* Get a temporary buffer big enough to hold the result */
959 	mp_size		osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2);
960 	mp_size		p = 0;
961 	mp_digit   *out;
962 
963 	if (a == c)
964 	{
965 		p = s_round_prec(osize);
966 		p = MAX(p, default_precision);
967 
968 		if ((out = s_alloc(p)) == NULL)
969 			return MP_MEMORY;
970 	}
971 	else
972 	{
973 		if (!s_pad(c, osize))
974 			return MP_MEMORY;
975 
976 		out = MP_DIGITS(c);
977 	}
978 	ZERO(out, osize);
979 
980 	s_ksqr(MP_DIGITS(a), out, MP_USED(a));
981 
982 	/*
983 	 * Get rid of whatever memory c was already using, and fix up its fields
984 	 * to reflect the new digit array it's using
985 	 */
986 	if (out != MP_DIGITS(c))
987 	{
988 		if ((void *) MP_DIGITS(c) != (void *) c)
989 			s_free(MP_DIGITS(c));
990 		c->digits = out;
991 		c->alloc = p;
992 	}
993 
994 	c->used = osize;			/* might not be true, but we'll fix it ... */
995 	CLAMP(c);					/* ... right here */
996 	c->sign = MP_ZPOS;
997 
998 	return MP_OK;
999 }
1000 
/*
 * Divide a by b, storing the quotient in q and the remainder in r.
 *
 * Either q or r may be NULL if that output is not wanted, and each may be
 * the same object as one of the inputs, but q and r must not be the same
 * object (asserted).
 *
 * The remainder takes the sign of a, and the quotient is negative when the
 * signs of a and b differ; a zero result is always non-negative.
 *
 * Returns MP_UNDEF if b is zero, MP_OK on success, or the error reported
 * by a failing copy or by s_udiv_knuth().
 */
mp_result
mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
{
	assert(a != NULL && b != NULL && q != r);

	int			cmp;
	mp_result	res = MP_OK;
	mp_int		qout,
				rout;
	/* Capture the input signs now; the outputs may overwrite a or b. */
	mp_sign		sa = MP_SIGN(a);
	mp_sign		sb = MP_SIGN(b);

	if (CMPZ(b) == 0)
	{
		return MP_UNDEF;
	}
	else if ((cmp = s_ucmp(a, b)) < 0)
	{
		/*
		 * If |a| < |b|, no division is required: q = 0, r = a
		 */
		if (r && (res = mp_int_copy(a, r)) != MP_OK)
			return res;

		if (q)
			mp_int_zero(q);

		return MP_OK;
	}
	else if (cmp == 0)
	{
		/*
		 * If |a| = |b|, no division is required: q = 1 or -1, r = 0
		 */
		if (r)
			mp_int_zero(r);

		if (q)
		{
			mp_int_zero(q);
			q->digits[0] = 1;

			if (sa != sb)
				q->sign = MP_NEG;
		}

		return MP_OK;
	}

	/*
	 * When |a| > |b|, real division is required.  We need someplace to store
	 * quotient and remainder, but q and r are allowed to be NULL or to
	 * overlap with the inputs.
	 */
	DECLARE_TEMP(2);
	int			lg;

	if ((lg = s_isp2(b)) < 0)
	{
		/*
		 * General divisor.  s_udiv_knuth() works in place: it is handed a
		 * copy of a (replaced by the quotient) and a copy of b (replaced by
		 * the remainder).  A temporary is used for an output that is NULL
		 * or that would overwrite the other input before it has been copied.
		 */
		if (q && b != q)
		{
			REQUIRE(mp_int_copy(a, q));
			qout = q;
		}
		else
		{
			REQUIRE(mp_int_copy(a, TEMP(0)));
			qout = TEMP(0);
		}

		if (r && a != r)
		{
			REQUIRE(mp_int_copy(b, r));
			rout = r;
		}
		else
		{
			REQUIRE(mp_int_copy(b, TEMP(1)));
			rout = TEMP(1);
		}

		REQUIRE(s_udiv_knuth(qout, rout));
	}
	else
	{
		/*
		 * b is 2^lg: shift for the quotient and mask for the remainder.
		 * Both copies are taken before either output is modified, so an
		 * output that aliases a does not disturb the other one.
		 */
		if (q)
			REQUIRE(mp_int_copy(a, q));
		if (r)
			REQUIRE(mp_int_copy(a, r));

		if (q)
			s_qdiv(q, (mp_size) lg);
		qout = q;
		if (r)
			s_qmod(r, (mp_size) lg);
		rout = r;
	}

	/* Recompute signs for output */
	if (rout)
	{
		rout->sign = sa;
		if (CMPZ(rout) == 0)
			rout->sign = MP_ZPOS;
	}
	if (qout)
	{
		qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG;
		if (CMPZ(qout) == 0)
			qout->sign = MP_ZPOS;
	}

	/* Move results out of the temporaries (a no-op when already in place). */
	if (q)
		REQUIRE(mp_int_copy(qout, q));
	if (r)
		REQUIRE(mp_int_copy(rout, r));
	CLEANUP_TEMP();
	return res;
}
1120 
1121 mp_result
1122 mp_int_mod(mp_int a, mp_int m, mp_int c)
1123 {
1124 	DECLARE_TEMP(1);
1125 	mp_int		out = (m == c) ? TEMP(0) : c;
1126 
1127 	REQUIRE(mp_int_div(a, m, NULL, out));
1128 	if (CMPZ(out) < 0)
1129 	{
1130 		REQUIRE(mp_int_add(out, m, c));
1131 	}
1132 	else
1133 	{
1134 		REQUIRE(mp_int_copy(out, c));
1135 	}
1136 	CLEANUP_TEMP();
1137 	return MP_OK;
1138 }
1139 
1140 mp_result
1141 mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r)
1142 {
1143 	mpz_t		vtmp;
1144 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
1145 
1146 	s_fake(&vtmp, value, vbuf);
1147 
1148 	DECLARE_TEMP(1);
1149 	REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0)));
1150 
1151 	if (r)
1152 		(void) mp_int_to_int(TEMP(0), r);	/* can't fail */
1153 
1154 	CLEANUP_TEMP();
1155 	return MP_OK;
1156 }
1157 
1158 mp_result
1159 mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r)
1160 {
1161 	assert(a != NULL && p2 >= 0 && q != r);
1162 
1163 	mp_result	res = MP_OK;
1164 
1165 	if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
1166 	{
1167 		s_qdiv(q, (mp_size) p2);
1168 	}
1169 
1170 	if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
1171 	{
1172 		s_qmod(r, (mp_size) p2);
1173 	}
1174 
1175 	return res;
1176 }
1177 
1178 mp_result
1179 mp_int_expt(mp_int a, mp_small b, mp_int c)
1180 {
1181 	assert(c != NULL);
1182 	if (b < 0)
1183 		return MP_RANGE;
1184 
1185 	DECLARE_TEMP(1);
1186 	REQUIRE(mp_int_copy(a, TEMP(0)));
1187 
1188 	(void) mp_int_set_value(c, 1);
1189 	unsigned int v = labs(b);
1190 
1191 	while (v != 0)
1192 	{
1193 		if (v & 1)
1194 		{
1195 			REQUIRE(mp_int_mul(c, TEMP(0), c));
1196 		}
1197 
1198 		v >>= 1;
1199 		if (v == 0)
1200 			break;
1201 
1202 		REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1203 	}
1204 
1205 	CLEANUP_TEMP();
1206 	return MP_OK;
1207 }
1208 
1209 mp_result
1210 mp_int_expt_value(mp_small a, mp_small b, mp_int c)
1211 {
1212 	assert(c != NULL);
1213 	if (b < 0)
1214 		return MP_RANGE;
1215 
1216 	DECLARE_TEMP(1);
1217 	REQUIRE(mp_int_set_value(TEMP(0), a));
1218 
1219 	(void) mp_int_set_value(c, 1);
1220 	unsigned int v = labs(b);
1221 
1222 	while (v != 0)
1223 	{
1224 		if (v & 1)
1225 		{
1226 			REQUIRE(mp_int_mul(c, TEMP(0), c));
1227 		}
1228 
1229 		v >>= 1;
1230 		if (v == 0)
1231 			break;
1232 
1233 		REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1234 	}
1235 
1236 	CLEANUP_TEMP();
1237 	return MP_OK;
1238 }
1239 
1240 mp_result
1241 mp_int_expt_full(mp_int a, mp_int b, mp_int c)
1242 {
1243 	assert(a != NULL && b != NULL && c != NULL);
1244 	if (MP_SIGN(b) == MP_NEG)
1245 		return MP_RANGE;
1246 
1247 	DECLARE_TEMP(1);
1248 	REQUIRE(mp_int_copy(a, TEMP(0)));
1249 
1250 	(void) mp_int_set_value(c, 1);
1251 	for (unsigned ix = 0; ix < MP_USED(b); ++ix)
1252 	{
1253 		mp_digit	d = b->digits[ix];
1254 
1255 		for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx)
1256 		{
1257 			if (d & 1)
1258 			{
1259 				REQUIRE(mp_int_mul(c, TEMP(0), c));
1260 			}
1261 
1262 			d >>= 1;
1263 			if (d == 0 && ix + 1 == MP_USED(b))
1264 				break;
1265 			REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
1266 		}
1267 	}
1268 
1269 	CLEANUP_TEMP();
1270 	return MP_OK;
1271 }
1272 
1273 int
1274 mp_int_compare(mp_int a, mp_int b)
1275 {
1276 	assert(a != NULL && b != NULL);
1277 
1278 	mp_sign		sa = MP_SIGN(a);
1279 
1280 	if (sa == MP_SIGN(b))
1281 	{
1282 		int			cmp = s_ucmp(a, b);
1283 
1284 		/*
1285 		 * If they're both zero or positive, the normal comparison applies; if
1286 		 * both negative, the sense is reversed.
1287 		 */
1288 		if (sa == MP_ZPOS)
1289 		{
1290 			return cmp;
1291 		}
1292 		else
1293 		{
1294 			return -cmp;
1295 		}
1296 	}
1297 	else if (sa == MP_ZPOS)
1298 	{
1299 		return 1;
1300 	}
1301 	else
1302 	{
1303 		return -1;
1304 	}
1305 }
1306 
/*
 * Compare a and b through s_ucmp, which looks only at the digit counts and
 * digit values (not the sign fields).  Returns the s_ucmp result unchanged.
 */
int
mp_int_compare_unsigned(mp_int a, mp_int b)
{
	assert(a != NULL && b != NULL);

	return s_ucmp(a, b);
}
1314 
1315 int
1316 mp_int_compare_zero(mp_int z)
1317 {
1318 	assert(z != NULL);
1319 
1320 	if (MP_USED(z) == 1 && z->digits[0] == 0)
1321 	{
1322 		return 0;
1323 	}
1324 	else if (MP_SIGN(z) == MP_ZPOS)
1325 	{
1326 		return 1;
1327 	}
1328 	else
1329 	{
1330 		return -1;
1331 	}
1332 }
1333 
1334 int
1335 mp_int_compare_value(mp_int z, mp_small value)
1336 {
1337 	assert(z != NULL);
1338 
1339 	mp_sign		vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1340 
1341 	if (vsign == MP_SIGN(z))
1342 	{
1343 		int			cmp = s_vcmp(z, value);
1344 
1345 		return (vsign == MP_ZPOS) ? cmp : -cmp;
1346 	}
1347 	else
1348 	{
1349 		return (value < 0) ? 1 : -1;
1350 	}
1351 }
1352 
1353 int
1354 mp_int_compare_uvalue(mp_int z, mp_usmall uv)
1355 {
1356 	assert(z != NULL);
1357 
1358 	if (MP_SIGN(z) == MP_NEG)
1359 	{
1360 		return -1;
1361 	}
1362 	else
1363 	{
1364 		return s_uvcmp(z, uv);
1365 	}
1366 }
1367 
1368 mp_result
1369 mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1370 {
1371 	assert(a != NULL && b != NULL && c != NULL && m != NULL);
1372 
1373 	/* Zero moduli and negative exponents are not considered. */
1374 	if (CMPZ(m) == 0)
1375 		return MP_UNDEF;
1376 	if (CMPZ(b) < 0)
1377 		return MP_RANGE;
1378 
1379 	mp_size		um = MP_USED(m);
1380 
1381 	DECLARE_TEMP(3);
1382 	REQUIRE(GROW(TEMP(0), 2 * um));
1383 	REQUIRE(GROW(TEMP(1), 2 * um));
1384 
1385 	mp_int		s;
1386 
1387 	if (c == b || c == m)
1388 	{
1389 		REQUIRE(GROW(TEMP(2), 2 * um));
1390 		s = TEMP(2);
1391 	}
1392 	else
1393 	{
1394 		s = c;
1395 	}
1396 
1397 	REQUIRE(mp_int_mod(a, m, TEMP(0)));
1398 	REQUIRE(s_brmu(TEMP(1), m));
1399 	REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s));
1400 	REQUIRE(mp_int_copy(s, c));
1401 
1402 	CLEANUP_TEMP();
1403 	return MP_OK;
1404 }
1405 
1406 mp_result
1407 mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c)
1408 {
1409 	mpz_t		vtmp;
1410 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
1411 
1412 	s_fake(&vtmp, value, vbuf);
1413 
1414 	return mp_int_exptmod(a, &vtmp, m, c);
1415 }
1416 
1417 mp_result
1418 mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c)
1419 {
1420 	mpz_t		vtmp;
1421 	mp_digit	vbuf[MP_VALUE_DIGITS(value)];
1422 
1423 	s_fake(&vtmp, value, vbuf);
1424 
1425 	return mp_int_exptmod(&vtmp, b, m, c);
1426 }
1427 
1428 mp_result
1429 mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu,
1430 					 mp_int c)
1431 {
1432 	assert(a && b && m && c);
1433 
1434 	/* Zero moduli and negative exponents are not considered. */
1435 	if (CMPZ(m) == 0)
1436 		return MP_UNDEF;
1437 	if (CMPZ(b) < 0)
1438 		return MP_RANGE;
1439 
1440 	DECLARE_TEMP(2);
1441 	mp_size		um = MP_USED(m);
1442 
1443 	REQUIRE(GROW(TEMP(0), 2 * um));
1444 
1445 	mp_int		s;
1446 
1447 	if (c == b || c == m)
1448 	{
1449 		REQUIRE(GROW(TEMP(1), 2 * um));
1450 		s = TEMP(1);
1451 	}
1452 	else
1453 	{
1454 		s = c;
1455 	}
1456 
1457 	REQUIRE(mp_int_mod(a, m, TEMP(0)));
1458 	REQUIRE(s_embar(TEMP(0), b, m, mu, s));
1459 	REQUIRE(mp_int_copy(s, c));
1460 
1461 	CLEANUP_TEMP();
1462 	return MP_OK;
1463 }
1464 
/*
 * Compute the reduction constant for modulus m into c by delegating to
 * s_brmu (the same helper mp_int_exptmod uses for its own constant).  m and
 * c must be distinct objects.  Returns the s_brmu result.
 */
mp_result
mp_int_redux_const(mp_int m, mp_int c)
{
	assert(m != NULL && c != NULL && m != c);

	return s_brmu(c, m);
}
1472 
1473 mp_result
1474 mp_int_invmod(mp_int a, mp_int m, mp_int c)
1475 {
1476 	assert(a != NULL && m != NULL && c != NULL);
1477 
1478 	if (CMPZ(a) == 0 || CMPZ(m) <= 0)
1479 		return MP_RANGE;
1480 
1481 	DECLARE_TEMP(2);
1482 
1483 	REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL));
1484 
1485 	if (mp_int_compare_value(TEMP(0), 1) != 0)
1486 	{
1487 		REQUIRE(MP_UNDEF);
1488 	}
1489 
1490 	/* It is first necessary to constrain the value to the proper range */
1491 	REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1)));
1492 
1493 	/*
1494 	 * Now, if 'a' was originally negative, the value we have is actually the
1495 	 * magnitude of the negative representative; to get the positive value we
1496 	 * have to subtract from the modulus.  Otherwise, the value is okay as it
1497 	 * stands.
1498 	 */
1499 	if (MP_SIGN(a) == MP_NEG)
1500 	{
1501 		REQUIRE(mp_int_sub(m, TEMP(1), c));
1502 	}
1503 	else
1504 	{
1505 		REQUIRE(mp_int_copy(TEMP(1), c));
1506 	}
1507 
1508 	CLEANUP_TEMP();
1509 	return MP_OK;
1510 }
1511 
1512 /* Binary GCD algorithm due to Josef Stein, 1961 */
1513 mp_result
1514 mp_int_gcd(mp_int a, mp_int b, mp_int c)
1515 {
1516 	assert(a != NULL && b != NULL && c != NULL);
1517 
1518 	int			ca = CMPZ(a);
1519 	int			cb = CMPZ(b);
1520 
1521 	if (ca == 0 && cb == 0)
1522 	{
1523 		return MP_UNDEF;
1524 	}
1525 	else if (ca == 0)
1526 	{
1527 		return mp_int_abs(b, c);
1528 	}
1529 	else if (cb == 0)
1530 	{
1531 		return mp_int_abs(a, c);
1532 	}
1533 
1534 	DECLARE_TEMP(3);
1535 	REQUIRE(mp_int_copy(a, TEMP(0)));
1536 	REQUIRE(mp_int_copy(b, TEMP(1)));
1537 
1538 	TEMP(0)->sign = MP_ZPOS;
1539 	TEMP(1)->sign = MP_ZPOS;
1540 
1541 	int			k = 0;
1542 
1543 	{							/* Divide out common factors of 2 from u and v */
1544 		int			div2_u = s_dp2k(TEMP(0));
1545 		int			div2_v = s_dp2k(TEMP(1));
1546 
1547 		k = MIN(div2_u, div2_v);
1548 		s_qdiv(TEMP(0), (mp_size) k);
1549 		s_qdiv(TEMP(1), (mp_size) k);
1550 	}
1551 
1552 	if (mp_int_is_odd(TEMP(0)))
1553 	{
1554 		REQUIRE(mp_int_neg(TEMP(1), TEMP(2)));
1555 	}
1556 	else
1557 	{
1558 		REQUIRE(mp_int_copy(TEMP(0), TEMP(2)));
1559 	}
1560 
1561 	for (;;)
1562 	{
1563 		s_qdiv(TEMP(2), s_dp2k(TEMP(2)));
1564 
1565 		if (CMPZ(TEMP(2)) > 0)
1566 		{
1567 			REQUIRE(mp_int_copy(TEMP(2), TEMP(0)));
1568 		}
1569 		else
1570 		{
1571 			REQUIRE(mp_int_neg(TEMP(2), TEMP(1)));
1572 		}
1573 
1574 		REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2)));
1575 
1576 		if (CMPZ(TEMP(2)) == 0)
1577 			break;
1578 	}
1579 
1580 	REQUIRE(mp_int_abs(TEMP(0), c));
1581 	if (!s_qmul(c, (mp_size) k))
1582 		REQUIRE(MP_MEMORY);
1583 
1584 	CLEANUP_TEMP();
1585 	return MP_OK;
1586 }
1587 
1588 /* This is the binary GCD algorithm again, but this time we keep track of the
1589    elementary matrix operations as we go, so we can get values x and y
1590    satisfying c = ax + by.
1591  */
1592 mp_result
1593 mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y)
1594 {
1595 	assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL));
1596 
1597 	mp_result	res = MP_OK;
1598 	int			ca = CMPZ(a);
1599 	int			cb = CMPZ(b);
1600 
1601 	if (ca == 0 && cb == 0)
1602 	{
1603 		return MP_UNDEF;
1604 	}
1605 	else if (ca == 0)
1606 	{
1607 		if ((res = mp_int_abs(b, c)) != MP_OK)
1608 			return res;
1609 		mp_int_zero(x);
1610 		(void) mp_int_set_value(y, 1);
1611 		return MP_OK;
1612 	}
1613 	else if (cb == 0)
1614 	{
1615 		if ((res = mp_int_abs(a, c)) != MP_OK)
1616 			return res;
1617 		(void) mp_int_set_value(x, 1);
1618 		mp_int_zero(y);
1619 		return MP_OK;
1620 	}
1621 
1622 	/*
1623 	 * Initialize temporaries: A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7
1624 	 */
1625 	DECLARE_TEMP(8);
1626 	REQUIRE(mp_int_set_value(TEMP(0), 1));
1627 	REQUIRE(mp_int_set_value(TEMP(3), 1));
1628 	REQUIRE(mp_int_copy(a, TEMP(4)));
1629 	REQUIRE(mp_int_copy(b, TEMP(5)));
1630 
1631 	/* We will work with absolute values here */
1632 	TEMP(4)->sign = MP_ZPOS;
1633 	TEMP(5)->sign = MP_ZPOS;
1634 
1635 	int			k = 0;
1636 
1637 	{							/* Divide out common factors of 2 from u and v */
1638 		int			div2_u = s_dp2k(TEMP(4)),
1639 					div2_v = s_dp2k(TEMP(5));
1640 
1641 		k = MIN(div2_u, div2_v);
1642 		s_qdiv(TEMP(4), k);
1643 		s_qdiv(TEMP(5), k);
1644 	}
1645 
1646 	REQUIRE(mp_int_copy(TEMP(4), TEMP(6)));
1647 	REQUIRE(mp_int_copy(TEMP(5), TEMP(7)));
1648 
1649 	for (;;)
1650 	{
1651 		while (mp_int_is_even(TEMP(4)))
1652 		{
1653 			s_qdiv(TEMP(4), 1);
1654 
1655 			if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1)))
1656 			{
1657 				REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0)));
1658 				REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1)));
1659 			}
1660 
1661 			s_qdiv(TEMP(0), 1);
1662 			s_qdiv(TEMP(1), 1);
1663 		}
1664 
1665 		while (mp_int_is_even(TEMP(5)))
1666 		{
1667 			s_qdiv(TEMP(5), 1);
1668 
1669 			if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3)))
1670 			{
1671 				REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2)));
1672 				REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3)));
1673 			}
1674 
1675 			s_qdiv(TEMP(2), 1);
1676 			s_qdiv(TEMP(3), 1);
1677 		}
1678 
1679 		if (mp_int_compare(TEMP(4), TEMP(5)) >= 0)
1680 		{
1681 			REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4)));
1682 			REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0)));
1683 			REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1)));
1684 		}
1685 		else
1686 		{
1687 			REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5)));
1688 			REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1689 			REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3)));
1690 		}
1691 
1692 		if (CMPZ(TEMP(4)) == 0)
1693 		{
1694 			if (x)
1695 				REQUIRE(mp_int_copy(TEMP(2), x));
1696 			if (y)
1697 				REQUIRE(mp_int_copy(TEMP(3), y));
1698 			if (c)
1699 			{
1700 				if (!s_qmul(TEMP(5), k))
1701 				{
1702 					REQUIRE(MP_MEMORY);
1703 				}
1704 				REQUIRE(mp_int_copy(TEMP(5), c));
1705 			}
1706 
1707 			break;
1708 		}
1709 	}
1710 
1711 	CLEANUP_TEMP();
1712 	return MP_OK;
1713 }
1714 
1715 mp_result
1716 mp_int_lcm(mp_int a, mp_int b, mp_int c)
1717 {
1718 	assert(a != NULL && b != NULL && c != NULL);
1719 
1720 	/*
1721 	 * Since a * b = gcd(a, b) * lcm(a, b), we can compute lcm(a, b) = (a /
1722 	 * gcd(a, b)) * b.
1723 	 *
1724 	 * This formulation insures everything works even if the input variables
1725 	 * share space.
1726 	 */
1727 	DECLARE_TEMP(1);
1728 	REQUIRE(mp_int_gcd(a, b, TEMP(0)));
1729 	REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL));
1730 	REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0)));
1731 	REQUIRE(mp_int_copy(TEMP(0), c));
1732 
1733 	CLEANUP_TEMP();
1734 	return MP_OK;
1735 }
1736 
1737 bool
1738 mp_int_divisible_value(mp_int a, mp_small v)
1739 {
1740 	mp_small	rem = 0;
1741 
1742 	if (mp_int_div_value(a, v, NULL, &rem) != MP_OK)
1743 	{
1744 		return false;
1745 	}
1746 	return rem == 0;
1747 }
1748 
/*
 * Public wrapper around s_isp2: returns whatever s_isp2 reports for z.
 * NOTE(review): the meaning of the returned int is defined by s_isp2, which
 * is outside this part of the file -- confirm before relying on it.
 */
int
mp_int_is_pow2(mp_int z)
{
	assert(z != NULL);

	return s_isp2(z);
}
1756 
1757 /* Implementation of Newton's root finding method, based loosely on a patch
1758    contributed by Hal Finkel <half@halssoftware.com>
1759    modified by M. J. Fromberger.
1760  */
1761 mp_result
1762 mp_int_root(mp_int a, mp_small b, mp_int c)
1763 {
1764 	assert(a != NULL && c != NULL && b > 0);
1765 
1766 	if (b == 1)
1767 	{
1768 		return mp_int_copy(a, c);
1769 	}
1770 	bool		flips = false;
1771 
1772 	if (MP_SIGN(a) == MP_NEG)
1773 	{
1774 		if (b % 2 == 0)
1775 		{
1776 			return MP_UNDEF;	/* root does not exist for negative a with
1777 								 * even b */
1778 		}
1779 		else
1780 		{
1781 			flips = true;
1782 		}
1783 	}
1784 
1785 	DECLARE_TEMP(5);
1786 	REQUIRE(mp_int_copy(a, TEMP(0)));
1787 	REQUIRE(mp_int_copy(a, TEMP(1)));
1788 	TEMP(0)->sign = MP_ZPOS;
1789 	TEMP(1)->sign = MP_ZPOS;
1790 
1791 	for (;;)
1792 	{
1793 		REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2)));
1794 
1795 		if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0)
1796 			break;
1797 
1798 		REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1799 		REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3)));
1800 		REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3)));
1801 		REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL));
1802 		REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4)));
1803 
1804 		if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0)
1805 		{
1806 			REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4)));
1807 		}
1808 		REQUIRE(mp_int_copy(TEMP(4), TEMP(1)));
1809 	}
1810 
1811 	REQUIRE(mp_int_copy(TEMP(1), c));
1812 
1813 	/* If the original value of a was negative, flip the output sign. */
1814 	if (flips)
1815 		(void) mp_int_neg(c, c);	/* cannot fail */
1816 
1817 	CLEANUP_TEMP();
1818 	return MP_OK;
1819 }
1820 
1821 mp_result
1822 mp_int_to_int(mp_int z, mp_small *out)
1823 {
1824 	assert(z != NULL);
1825 
1826 	/* Make sure the value is representable as a small integer */
1827 	mp_sign		sz = MP_SIGN(z);
1828 
1829 	if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1830 		mp_int_compare_value(z, MP_SMALL_MIN) < 0)
1831 	{
1832 		return MP_RANGE;
1833 	}
1834 
1835 	mp_usmall	uz = MP_USED(z);
1836 	mp_digit   *dz = MP_DIGITS(z) + uz - 1;
1837 	mp_small	uv = 0;
1838 
1839 	while (uz > 0)
1840 	{
1841 		uv <<= MP_DIGIT_BIT / 2;
1842 		uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1843 		--uz;
1844 	}
1845 
1846 	if (out)
1847 		*out = (mp_small) ((sz == MP_NEG) ? -uv : uv);
1848 
1849 	return MP_OK;
1850 }
1851 
1852 mp_result
1853 mp_int_to_uint(mp_int z, mp_usmall *out)
1854 {
1855 	assert(z != NULL);
1856 
1857 	/* Make sure the value is representable as an unsigned small integer */
1858 	mp_size		sz = MP_SIGN(z);
1859 
1860 	if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0)
1861 	{
1862 		return MP_RANGE;
1863 	}
1864 
1865 	mp_size		uz = MP_USED(z);
1866 	mp_digit   *dz = MP_DIGITS(z) + uz - 1;
1867 	mp_usmall	uv = 0;
1868 
1869 	while (uz > 0)
1870 	{
1871 		uv <<= MP_DIGIT_BIT / 2;
1872 		uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1873 		--uz;
1874 	}
1875 
1876 	if (out)
1877 		*out = uv;
1878 
1879 	return MP_OK;
1880 }
1881 
1882 mp_result
1883 mp_int_to_string(mp_int z, mp_size radix, char *str, int limit)
1884 {
1885 	assert(z != NULL && str != NULL && limit >= 2);
1886 	assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1887 
1888 	int			cmp = 0;
1889 
1890 	if (CMPZ(z) == 0)
1891 	{
1892 		*str++ = s_val2ch(0, 1);
1893 	}
1894 	else
1895 	{
1896 		mp_result	res;
1897 		mpz_t		tmp;
1898 		char	   *h,
1899 				   *t;
1900 
1901 		if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1902 			return res;
1903 
1904 		if (MP_SIGN(z) == MP_NEG)
1905 		{
1906 			*str++ = '-';
1907 			--limit;
1908 		}
1909 		h = str;
1910 
1911 		/* Generate digits in reverse order until finished or limit reached */
1912 		for ( /* */ ; limit > 0; --limit)
1913 		{
1914 			mp_digit	d;
1915 
1916 			if ((cmp = CMPZ(&tmp)) == 0)
1917 				break;
1918 
1919 			d = s_ddiv(&tmp, (mp_digit) radix);
1920 			*str++ = s_val2ch(d, 1);
1921 		}
1922 		t = str - 1;
1923 
1924 		/* Put digits back in correct output order */
1925 		while (h < t)
1926 		{
1927 			char		tc = *h;
1928 
1929 			*h++ = *t;
1930 			*t-- = tc;
1931 		}
1932 
1933 		mp_int_clear(&tmp);
1934 	}
1935 
1936 	*str = '\0';
1937 	if (cmp == 0)
1938 	{
1939 		return MP_OK;
1940 	}
1941 	else
1942 	{
1943 		return MP_TRUNC;
1944 	}
1945 }
1946 
1947 mp_result
1948 mp_int_string_len(mp_int z, mp_size radix)
1949 {
1950 	assert(z != NULL);
1951 	assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1952 
1953 	int			len = s_outlen(z, radix) + 1;	/* for terminator */
1954 
1955 	/* Allow for sign marker on negatives */
1956 	if (MP_SIGN(z) == MP_NEG)
1957 		len += 1;
1958 
1959 	return len;
1960 }
1961 
/*
 * Read the zero-terminated string str into z, interpreting it in the given
 * radix.  Equivalent to mp_int_read_cstring with no end pointer requested;
 * the result of that call is returned unchanged.
 */
mp_result
mp_int_read_string(mp_int z, mp_size radix, const char *str)
{
	return mp_int_read_cstring(z, radix, str, NULL);
}
1968 
1969 mp_result
1970 mp_int_read_cstring(mp_int z, mp_size radix, const char *str,
1971 					char **end)
1972 {
1973 	assert(z != NULL && str != NULL);
1974 	assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1975 
1976 	/* Skip leading whitespace */
1977 	while (isspace((unsigned char) *str))
1978 		++str;
1979 
1980 	/* Handle leading sign tag (+/-, positive default) */
1981 	switch (*str)
1982 	{
1983 		case '-':
1984 			z->sign = MP_NEG;
1985 			++str;
1986 			break;
1987 		case '+':
1988 			++str;				/* fallthrough */
1989 		default:
1990 			z->sign = MP_ZPOS;
1991 			break;
1992 	}
1993 
1994 	/* Skip leading zeroes */
1995 	int			ch;
1996 
1997 	while ((ch = s_ch2val(*str, radix)) == 0)
1998 		++str;
1999 
2000 	/* Make sure there is enough space for the value */
2001 	if (!s_pad(z, s_inlen(strlen(str), radix)))
2002 		return MP_MEMORY;
2003 
2004 	z->used = 1;
2005 	z->digits[0] = 0;
2006 
2007 	while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0))
2008 	{
2009 		s_dmul(z, (mp_digit) radix);
2010 		s_dadd(z, (mp_digit) ch);
2011 		++str;
2012 	}
2013 
2014 	CLAMP(z);
2015 
2016 	/* Override sign for zero, even if negative specified. */
2017 	if (CMPZ(z) == 0)
2018 		z->sign = MP_ZPOS;
2019 
2020 	if (end != NULL)
2021 		*end = unconstify(char *, str);
2022 
2023 	/*
2024 	 * Return a truncation error if the string has unprocessed characters
2025 	 * remaining, so the caller can tell if the whole string was done
2026 	 */
2027 	if (*str != '\0')
2028 	{
2029 		return MP_TRUNC;
2030 	}
2031 	else
2032 	{
2033 		return MP_OK;
2034 	}
2035 }
2036 
2037 mp_result
2038 mp_int_count_bits(mp_int z)
2039 {
2040 	assert(z != NULL);
2041 
2042 	mp_size		uz = MP_USED(z);
2043 
2044 	if (uz == 1 && z->digits[0] == 0)
2045 		return 1;
2046 
2047 	--uz;
2048 	mp_size		nbits = uz * MP_DIGIT_BIT;
2049 	mp_digit	d = z->digits[uz];
2050 
2051 	while (d != 0)
2052 	{
2053 		d >>= 1;
2054 		++nbits;
2055 	}
2056 
2057 	return nbits;
2058 }
2059 
2060 mp_result
2061 mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
2062 {
2063 	static const int PAD_FOR_2C = 1;
2064 
2065 	assert(z != NULL && buf != NULL);
2066 
2067 	int			limpos = limit;
2068 	mp_result	res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
2069 
2070 	if (MP_SIGN(z) == MP_NEG)
2071 		s_2comp(buf, limpos);
2072 
2073 	return res;
2074 }
2075 
2076 mp_result
2077 mp_int_read_binary(mp_int z, unsigned char *buf, int len)
2078 {
2079 	assert(z != NULL && buf != NULL && len > 0);
2080 
2081 	/* Figure out how many digits are needed to represent this value */
2082 	mp_size		need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2083 
2084 	if (!s_pad(z, need))
2085 		return MP_MEMORY;
2086 
2087 	mp_int_zero(z);
2088 
2089 	/*
2090 	 * If the high-order bit is set, take the 2's complement before reading
2091 	 * the value (it will be restored afterward)
2092 	 */
2093 	if (buf[0] >> (CHAR_BIT - 1))
2094 	{
2095 		z->sign = MP_NEG;
2096 		s_2comp(buf, len);
2097 	}
2098 
2099 	mp_digit   *dz = MP_DIGITS(z);
2100 	unsigned char *tmp = buf;
2101 
2102 	for (int i = len; i > 0; --i, ++tmp)
2103 	{
2104 		s_qmul(z, (mp_size) CHAR_BIT);
2105 		*dz |= *tmp;
2106 	}
2107 
2108 	/* Restore 2's complement if we took it before */
2109 	if (MP_SIGN(z) == MP_NEG)
2110 		s_2comp(buf, len);
2111 
2112 	return MP_OK;
2113 }
2114 
2115 mp_result
2116 mp_int_binary_len(mp_int z)
2117 {
2118 	mp_result	res = mp_int_count_bits(z);
2119 
2120 	if (res <= 0)
2121 		return res;
2122 
2123 	int			bytes = mp_int_unsigned_len(z);
2124 
2125 	/*
2126 	 * If the highest-order bit falls exactly on a byte boundary, we need to
2127 	 * pad with an extra byte so that the sign will be read correctly when
2128 	 * reading it back in.
2129 	 */
2130 	if (bytes * CHAR_BIT == res)
2131 		++bytes;
2132 
2133 	return bytes;
2134 }
2135 
/*
 * Write the magnitude of z to buf (at most limit bytes) by delegating to
 * s_tobin with padding disabled.  Returns the s_tobin result.
 */
mp_result
mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
{
	/* Padding flag handed to s_tobin; mp_int_to_binary passes 1 instead. */
	static const int NO_PADDING = 0;

	assert(z != NULL && buf != NULL);

	return s_tobin(z, buf, &limit, NO_PADDING);
}
2145 
2146 mp_result
2147 mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
2148 {
2149 	assert(z != NULL && buf != NULL && len > 0);
2150 
2151 	/* Figure out how many digits are needed to represent this value */
2152 	mp_size		need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
2153 
2154 	if (!s_pad(z, need))
2155 		return MP_MEMORY;
2156 
2157 	mp_int_zero(z);
2158 
2159 	unsigned char *tmp = buf;
2160 
2161 	for (int i = len; i > 0; --i, ++tmp)
2162 	{
2163 		(void) s_qmul(z, CHAR_BIT);
2164 		*MP_DIGITS(z) |= *tmp;
2165 	}
2166 
2167 	return MP_OK;
2168 }
2169 
2170 mp_result
2171 mp_int_unsigned_len(mp_int z)
2172 {
2173 	mp_result	res = mp_int_count_bits(z);
2174 
2175 	if (res <= 0)
2176 		return res;
2177 
2178 	int			bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
2179 
2180 	return bytes;
2181 }
2182 
2183 const char *
2184 mp_error_string(mp_result res)
2185 {
2186 	if (res > 0)
2187 		return s_unknown_err;
2188 
2189 	res = -res;
2190 	int			ix;
2191 
2192 	for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
2193 		;
2194 
2195 	if (s_error_msg[ix] != NULL)
2196 	{
2197 		return s_error_msg[ix];
2198 	}
2199 	else
2200 	{
2201 		return s_unknown_err;
2202 	}
2203 }
2204 
2205 /*------------------------------------------------------------------------*/
2206 /* Private functions for internal use.  These make assumptions.           */
2207 
#if IMATH_DEBUG
/*
 * Marker pattern stored into newly allocated digits when IMATH_DEBUG is
 * enabled.  The cast narrows the constant to the width of mp_digit.
 */
static const mp_digit fill = (mp_digit) 0xdeadbeefabad1dea;
#endif
2211 
2212 static mp_digit *
2213 s_alloc(mp_size num)
2214 {
2215 	mp_digit   *out = px_alloc(num * sizeof(mp_digit));
2216 
2217 	assert(out != NULL);
2218 
2219 #if IMATH_DEBUG
2220 	for (mp_size ix = 0; ix < num; ++ix)
2221 		out[ix] = fill;
2222 #endif
2223 	return out;
2224 }
2225 
2226 static mp_digit *
2227 s_realloc(mp_digit *old, mp_size osize, mp_size nsize)
2228 {
2229 #if IMATH_DEBUG
2230 	mp_digit   *new = s_alloc(nsize);
2231 
2232 	assert(new != NULL);
2233 
2234 	for (mp_size ix = 0; ix < nsize; ++ix)
2235 		new[ix] = fill;
2236 	memcpy(new, old, osize * sizeof(mp_digit));
2237 #else
2238 	mp_digit   *new = px_realloc(old, nsize * sizeof(mp_digit));
2239 
2240 	assert(new != NULL);
2241 #endif
2242 
2243 	return new;
2244 }
2245 
/*
 * Release a buffer by handing it to px_free.
 */
static void
s_free(void *ptr)
{
	px_free(ptr);
}
2251 
2252 static bool
2253 s_pad(mp_int z, mp_size min)
2254 {
2255 	if (MP_ALLOC(z) < min)
2256 	{
2257 		mp_size		nsize = s_round_prec(min);
2258 		mp_digit   *tmp;
2259 
2260 		if (z->digits == &(z->single))
2261 		{
2262 			if ((tmp = s_alloc(nsize)) == NULL)
2263 				return false;
2264 			tmp[0] = z->single;
2265 		}
2266 		else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL)
2267 		{
2268 			return false;
2269 		}
2270 
2271 		z->digits = tmp;
2272 		z->alloc = nsize;
2273 	}
2274 
2275 	return true;
2276 }
2277 
2278 /* Note: This will not work correctly when value == MP_SMALL_MIN */
2279 static void
2280 s_fake(mp_int z, mp_small value, mp_digit vbuf[])
2281 {
2282 	mp_usmall	uv = (mp_usmall) (value < 0) ? -value : value;
2283 
2284 	s_ufake(z, uv, vbuf);
2285 	if (value < 0)
2286 		z->sign = MP_NEG;
2287 }
2288 
2289 static void
2290 s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[])
2291 {
2292 	mp_size		ndig = (mp_size) s_uvpack(value, vbuf);
2293 
2294 	z->used = ndig;
2295 	z->alloc = MP_VALUE_DIGITS(value);
2296 	z->sign = MP_ZPOS;
2297 	z->digits = vbuf;
2298 }
2299 
2300 static int
2301 s_cdig(mp_digit *da, mp_digit *db, mp_size len)
2302 {
2303 	mp_digit   *dat = da + len - 1,
2304 			   *dbt = db + len - 1;
2305 
2306 	for ( /* */ ; len != 0; --len, --dat, --dbt)
2307 	{
2308 		if (*dat > *dbt)
2309 		{
2310 			return 1;
2311 		}
2312 		else if (*dat < *dbt)
2313 		{
2314 			return -1;
2315 		}
2316 	}
2317 
2318 	return 0;
2319 }
2320 
2321 static int
2322 s_uvpack(mp_usmall uv, mp_digit t[])
2323 {
2324 	int			ndig = 0;
2325 
2326 	if (uv == 0)
2327 		t[ndig++] = 0;
2328 	else
2329 	{
2330 		while (uv != 0)
2331 		{
2332 			t[ndig++] = (mp_digit) uv;
2333 			uv >>= MP_DIGIT_BIT / 2;
2334 			uv >>= MP_DIGIT_BIT / 2;
2335 		}
2336 	}
2337 
2338 	return ndig;
2339 }
2340 
2341 static int
2342 s_ucmp(mp_int a, mp_int b)
2343 {
2344 	mp_size		ua = MP_USED(a),
2345 				ub = MP_USED(b);
2346 
2347 	if (ua > ub)
2348 	{
2349 		return 1;
2350 	}
2351 	else if (ub > ua)
2352 	{
2353 		return -1;
2354 	}
2355 	else
2356 	{
2357 		return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2358 	}
2359 }
2360 
2361 static int
2362 s_vcmp(mp_int a, mp_small v)
2363 {
2364 #ifdef _MSC_VER
2365 #pragma warning(push)
2366 #pragma warning(disable: 4146)
2367 #endif
2368 	mp_usmall	uv = (v < 0) ? -(mp_usmall) v : (mp_usmall) v;
2369 #ifdef _MSC_VER
2370 #pragma warning(pop)
2371 #endif
2372 
2373 	return s_uvcmp(a, uv);
2374 }
2375 
2376 static int
2377 s_uvcmp(mp_int a, mp_usmall uv)
2378 {
2379 	mpz_t		vtmp;
2380 	mp_digit	vdig[MP_VALUE_DIGITS(uv)];
2381 
2382 	s_ufake(&vtmp, uv, vdig);
2383 	return s_ucmp(a, &vtmp);
2384 }
2385 
2386 static mp_digit
2387 s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2388 	   mp_size size_b)
2389 {
2390 	mp_size		pos;
2391 	mp_word		w = 0;
2392 
2393 	/* Insure that da is the longer of the two to simplify later code */
2394 	if (size_b > size_a)
2395 	{
2396 		SWAP(mp_digit *, da, db);
2397 		SWAP(mp_size, size_a, size_b);
2398 	}
2399 
2400 	/* Add corresponding digits until the shorter number runs out */
2401 	for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2402 	{
2403 		w = w + (mp_word) *da + (mp_word) *db;
2404 		*dc = LOWER_HALF(w);
2405 		w = UPPER_HALF(w);
2406 	}
2407 
2408 	/* Propagate carries as far as necessary */
2409 	for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2410 	{
2411 		w = w + *da;
2412 
2413 		*dc = LOWER_HALF(w);
2414 		w = UPPER_HALF(w);
2415 	}
2416 
2417 	/* Return carry out */
2418 	return (mp_digit) w;
2419 }
2420 
2421 static void
2422 s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2423 	   mp_size size_b)
2424 {
2425 	mp_size		pos;
2426 	mp_word		w = 0;
2427 
2428 	/* We assume that |a| >= |b| so this should definitely hold */
2429 	assert(size_a >= size_b);
2430 
2431 	/* Subtract corresponding digits and propagate borrow */
2432 	for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc)
2433 	{
2434 		w = ((mp_word) MP_DIGIT_MAX + 1 +	/* MP_RADIX */
2435 			 (mp_word) *da) -
2436 			w - (mp_word) *db;
2437 
2438 		*dc = LOWER_HALF(w);
2439 		w = (UPPER_HALF(w) == 0);
2440 	}
2441 
2442 	/* Finish the subtraction for remaining upper digits of da */
2443 	for ( /* */ ; pos < size_a; ++pos, ++da, ++dc)
2444 	{
2445 		w = ((mp_word) MP_DIGIT_MAX + 1 +	/* MP_RADIX */
2446 			 (mp_word) *da) -
2447 			w;
2448 
2449 		*dc = LOWER_HALF(w);
2450 		w = (UPPER_HALF(w) == 0);
2451 	}
2452 
2453 	/* If there is a borrow out at the end, it violates the precondition */
2454 	assert(w == 0);
2455 }
2456 
/*
 * Multiply the digit arrays da (size_a digits) and db (size_b digits) into
 * dc, recursing with the Karatsuba algorithm when the operands are large
 * enough and falling back to s_umul otherwise.
 *
 * Returns 1 on success, or 0 if the temporary buffer for the Karatsuba step
 * could not be allocated.
 *
 * NOTE(review): the products are accumulated into dc (s_umul adds to the
 * digits already there), so callers presumably pass a zeroed dc -- confirm
 * against the call sites.
 */
static int
s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
	   mp_size size_b)
{
	mp_size		bot_size;

	/* Make sure b is the smaller of the two input values */
	if (size_b > size_a)
	{
		SWAP(mp_digit *, da, db);
		SWAP(mp_size, size_a, size_b);
	}

	/*
	 * Insure that the bottom is the larger half in an odd-length split; the
	 * code below relies on this being true.
	 */
	bot_size = (size_a + 1) / 2;

	/*
	 * If the values are big enough to bother with recursion, use the
	 * Karatsuba algorithm to compute the product; otherwise use the normal
	 * multiplication algorithm
	 */
	if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size)
	{
		mp_digit   *t1,
				   *t2,
				   *t3,
					carry;

		/* Upper halves of each operand: a = a1:a0, b = b1:b0 */
		mp_digit   *a_top = da + bot_size;
		mp_digit   *b_top = db + bot_size;

		mp_size		at_size = size_a - bot_size;
		mp_size		bt_size = size_b - bot_size;
		mp_size		buf_size = 2 * bot_size;

		/*
		 * Do a single allocation for all three temporary buffers needed; each
		 * buffer must be big enough to hold the product of two bottom halves,
		 * and one buffer needs space for the completed product; twice the
		 * space is plenty.
		 */
		if ((t1 = s_alloc(4 * buf_size)) == NULL)
			return 0;
		t2 = t1 + buf_size;
		t3 = t2 + buf_size;
		ZERO(t1, 4 * buf_size);

		/*
		 * t1 and t2 are initially used as temporaries to compute the inner
		 * product (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
		 */
		carry = s_uadd(da, a_top, t1, bot_size, at_size);	/* t1 = a1 + a0 */
		t1[bot_size] = carry;

		carry = s_uadd(db, b_top, t2, bot_size, bt_size);	/* t2 = b1 + b0 */
		t2[bot_size] = carry;

		(void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1);	/* t3 = t1 * t2 */

		/*
		 * Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so
		 * that we're left with only the pieces we want:  t3 = a1b0 + a0b1
		 */
		ZERO(t1, buf_size);
		ZERO(t2, buf_size);
		(void) s_kmul(da, db, t1, bot_size, bot_size);	/* t1 = a0 * b0 */
		(void) s_kmul(a_top, b_top, t2, at_size, bt_size);	/* t2 = a1 * b1 */

		/* Subtract out t1 and t2 to get the inner product */
		s_usub(t3, t1, t3, buf_size + 2, buf_size);
		s_usub(t3, t2, t3, buf_size + 2, buf_size);

		/*
		 * Assemble the output value: a0b0 at offset 0, the inner product at
		 * offset bot_size, and a1b1 at offset 2 * bot_size.
		 */
		COPY(t1, dc, buf_size);
		carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
		assert(carry == 0);

		carry =
			s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
		assert(carry == 0);

		s_free(t1);				/* note t2 and t3 are just internal pointers
								 * to t1 */
	}
	else
	{
		s_umul(da, db, dc, size_a, size_b);
	}

	return 1;
}
2551 
2552 static void
2553 s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
2554 	   mp_size size_b)
2555 {
2556 	mp_size		a,
2557 				b;
2558 	mp_word		w;
2559 
2560 	for (a = 0; a < size_a; ++a, ++dc, ++da)
2561 	{
2562 		mp_digit   *dct = dc;
2563 		mp_digit   *dbt = db;
2564 
2565 		if (*da == 0)
2566 			continue;
2567 
2568 		w = 0;
2569 		for (b = 0; b < size_b; ++b, ++dbt, ++dct)
2570 		{
2571 			w = (mp_word) *da * (mp_word) *dbt + w + (mp_word) *dct;
2572 
2573 			*dct = LOWER_HALF(w);
2574 			w = UPPER_HALF(w);
2575 		}
2576 
2577 		*dct = (mp_digit) w;
2578 	}
2579 }
2580 
/*
 * Square the digit array da (size_a digits) into dc, recursing with a
 * Karatsuba-style split when the operand is larger than multiply_threshold
 * and falling back to s_usqr otherwise.
 *
 * Returns 1 on success, or 0 if the temporary buffer for the recursive step
 * could not be allocated.
 */
static int
s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a)
{
	if (multiply_threshold && size_a > multiply_threshold)
	{
		/* Split a = a1:a0 with the bottom half the larger on odd lengths */
		mp_size		bot_size = (size_a + 1) / 2;
		mp_digit   *a_top = da + bot_size;
		mp_digit   *t1,
				   *t2,
				   *t3,
					carry PG_USED_FOR_ASSERTS_ONLY;
		mp_size		at_size = size_a - bot_size;
		mp_size		buf_size = 2 * bot_size;

		/* One allocation serves all three temporaries (t2, t3 point into it) */
		if ((t1 = s_alloc(4 * buf_size)) == NULL)
			return 0;
		t2 = t1 + buf_size;
		t3 = t2 + buf_size;
		ZERO(t1, 4 * buf_size);

		(void) s_ksqr(da, t1, bot_size);	/* t1 = a0 ^ 2 */
		(void) s_ksqr(a_top, t2, at_size);	/* t2 = a1 ^ 2 */

		(void) s_kmul(da, a_top, t3, bot_size, at_size);	/* t3 = a0 * a1 */

		/* Quick multiply t3 by 2, shifting left (can't overflow) */
		{
			int			i,
						top = bot_size + at_size;
			mp_word		w,
						save = 0;

			for (i = 0; i < top; ++i)
			{
				w = t3[i];
				w = (w << 1) | save;
				t3[i] = LOWER_HALF(w);
				save = UPPER_HALF(w);
			}
			t3[i] = LOWER_HALF(save);
		}

		/*
		 * Assemble the output value: a0^2 at offset 0, 2*a0*a1 at offset
		 * bot_size, and a1^2 at offset 2 * bot_size.
		 */
		COPY(t1, dc, 2 * bot_size);
		carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
		assert(carry == 0);

		carry =
			s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
		assert(carry == 0);

		s_free(t1);				/* note that t2 and t3 are internal pointers
								 * only */

	}
	else
	{
		s_usqr(da, dc, size_a);
	}

	return 1;
}
2643 
2644 static void
2645 s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2646 {
2647 	mp_size		i,
2648 				j;
2649 	mp_word		w;
2650 
2651 	for (i = 0; i < size_a; ++i, dc += 2, ++da)
2652 	{
2653 		mp_digit   *dct = dc,
2654 				   *dat = da;
2655 
2656 		if (*da == 0)
2657 			continue;
2658 
2659 		/* Take care of the first digit, no rollover */
2660 		w = (mp_word) *dat * (mp_word) *dat + (mp_word) *dct;
2661 		*dct = LOWER_HALF(w);
2662 		w = UPPER_HALF(w);
2663 		++dat;
2664 		++dct;
2665 
2666 		for (j = i + 1; j < size_a; ++j, ++dat, ++dct)
2667 		{
2668 			mp_word		t = (mp_word) *da * (mp_word) *dat;
2669 			mp_word		u = w + (mp_word) *dct,
2670 						ov = 0;
2671 
2672 			/* Check if doubling t will overflow a word */
2673 			if (HIGH_BIT_SET(t))
2674 				ov = 1;
2675 
2676 			w = t + t;
2677 
2678 			/* Check if adding u to w will overflow a word */
2679 			if (ADD_WILL_OVERFLOW(w, u))
2680 				ov = 1;
2681 
2682 			w += u;
2683 
2684 			*dct = LOWER_HALF(w);
2685 			w = UPPER_HALF(w);
2686 			if (ov)
2687 			{
2688 				w += MP_DIGIT_MAX;	/* MP_RADIX */
2689 				++w;
2690 			}
2691 		}
2692 
2693 		w = w + *dct;
2694 		*dct = (mp_digit) w;
2695 		while ((w = UPPER_HALF(w)) != 0)
2696 		{
2697 			++dct;
2698 			w = w + *dct;
2699 			*dct = LOWER_HALF(w);
2700 		}
2701 
2702 		assert(w == 0);
2703 	}
2704 }
2705 
2706 static void
2707 s_dadd(mp_int a, mp_digit b)
2708 {
2709 	mp_word		w = 0;
2710 	mp_digit   *da = MP_DIGITS(a);
2711 	mp_size		ua = MP_USED(a);
2712 
2713 	w = (mp_word) *da + b;
2714 	*da++ = LOWER_HALF(w);
2715 	w = UPPER_HALF(w);
2716 
2717 	for (ua -= 1; ua > 0; --ua, ++da)
2718 	{
2719 		w = (mp_word) *da + w;
2720 
2721 		*da = LOWER_HALF(w);
2722 		w = UPPER_HALF(w);
2723 	}
2724 
2725 	if (w)
2726 	{
2727 		*da = (mp_digit) w;
2728 		a->used += 1;
2729 	}
2730 }
2731 
2732 static void
2733 s_dmul(mp_int a, mp_digit b)
2734 {
2735 	mp_word		w = 0;
2736 	mp_digit   *da = MP_DIGITS(a);
2737 	mp_size		ua = MP_USED(a);
2738 
2739 	while (ua > 0)
2740 	{
2741 		w = (mp_word) *da * b + w;
2742 		*da++ = LOWER_HALF(w);
2743 		w = UPPER_HALF(w);
2744 		--ua;
2745 	}
2746 
2747 	if (w)
2748 	{
2749 		*da = (mp_digit) w;
2750 		a->used += 1;
2751 	}
2752 }
2753 
2754 static void
2755 s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a)
2756 {
2757 	mp_word		w = 0;
2758 
2759 	while (size_a > 0)
2760 	{
2761 		w = (mp_word) *da++ * (mp_word) b + w;
2762 
2763 		*dc++ = LOWER_HALF(w);
2764 		w = UPPER_HALF(w);
2765 		--size_a;
2766 	}
2767 
2768 	if (w)
2769 		*dc = LOWER_HALF(w);
2770 }
2771 
2772 static mp_digit
2773 s_ddiv(mp_int a, mp_digit b)
2774 {
2775 	mp_word		w = 0,
2776 				qdigit;
2777 	mp_size		ua = MP_USED(a);
2778 	mp_digit   *da = MP_DIGITS(a) + ua - 1;
2779 
2780 	for ( /* */ ; ua > 0; --ua, --da)
2781 	{
2782 		w = (w << MP_DIGIT_BIT) | *da;
2783 
2784 		if (w >= b)
2785 		{
2786 			qdigit = w / b;
2787 			w = w % b;
2788 		}
2789 		else
2790 		{
2791 			qdigit = 0;
2792 		}
2793 
2794 		*da = (mp_digit) qdigit;
2795 	}
2796 
2797 	CLAMP(a);
2798 	return (mp_digit) w;
2799 }
2800 
2801 static void
2802 s_qdiv(mp_int z, mp_size p2)
2803 {
2804 	mp_size		ndig = p2 / MP_DIGIT_BIT,
2805 				nbits = p2 % MP_DIGIT_BIT;
2806 	mp_size		uz = MP_USED(z);
2807 
2808 	if (ndig)
2809 	{
2810 		mp_size		mark;
2811 		mp_digit   *to,
2812 				   *from;
2813 
2814 		if (ndig >= uz)
2815 		{
2816 			mp_int_zero(z);
2817 			return;
2818 		}
2819 
2820 		to = MP_DIGITS(z);
2821 		from = to + ndig;
2822 
2823 		for (mark = ndig; mark < uz; ++mark)
2824 		{
2825 			*to++ = *from++;
2826 		}
2827 
2828 		z->used = uz - ndig;
2829 	}
2830 
2831 	if (nbits)
2832 	{
2833 		mp_digit	d = 0,
2834 				   *dz,
2835 					save;
2836 		mp_size		up = MP_DIGIT_BIT - nbits;
2837 
2838 		uz = MP_USED(z);
2839 		dz = MP_DIGITS(z) + uz - 1;
2840 
2841 		for ( /* */ ; uz > 0; --uz, --dz)
2842 		{
2843 			save = *dz;
2844 
2845 			*dz = (*dz >> nbits) | (d << up);
2846 			d = save;
2847 		}
2848 
2849 		CLAMP(z);
2850 	}
2851 
2852 	if (MP_USED(z) == 1 && z->digits[0] == 0)
2853 		z->sign = MP_ZPOS;
2854 }
2855 
2856 static void
2857 s_qmod(mp_int z, mp_size p2)
2858 {
2859 	mp_size		start = p2 / MP_DIGIT_BIT + 1,
2860 				rest = p2 % MP_DIGIT_BIT;
2861 	mp_size		uz = MP_USED(z);
2862 	mp_digit	mask = (1u << rest) - 1;
2863 
2864 	if (start <= uz)
2865 	{
2866 		z->used = start;
2867 		z->digits[start - 1] &= mask;
2868 		CLAMP(z);
2869 	}
2870 }
2871 
2872 static int
2873 s_qmul(mp_int z, mp_size p2)
2874 {
2875 	mp_size		uz,
2876 				need,
2877 				rest,
2878 				extra,
2879 				i;
2880 	mp_digit   *from,
2881 			   *to,
2882 				d;
2883 
2884 	if (p2 == 0)
2885 		return 1;
2886 
2887 	uz = MP_USED(z);
2888 	need = p2 / MP_DIGIT_BIT;
2889 	rest = p2 % MP_DIGIT_BIT;
2890 
2891 	/*
2892 	 * Figure out if we need an extra digit at the top end; this occurs if the
2893 	 * topmost `rest' bits of the high-order digit of z are not zero, meaning
2894 	 * they will be shifted off the end if not preserved
2895 	 */
2896 	extra = 0;
2897 	if (rest != 0)
2898 	{
2899 		mp_digit   *dz = MP_DIGITS(z) + uz - 1;
2900 
2901 		if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
2902 			extra = 1;
2903 	}
2904 
2905 	if (!s_pad(z, uz + need + extra))
2906 		return 0;
2907 
2908 	/*
2909 	 * If we need to shift by whole digits, do that in one pass, then to back
2910 	 * and shift by partial digits.
2911 	 */
2912 	if (need > 0)
2913 	{
2914 		from = MP_DIGITS(z) + uz - 1;
2915 		to = from + need;
2916 
2917 		for (i = 0; i < uz; ++i)
2918 			*to-- = *from--;
2919 
2920 		ZERO(MP_DIGITS(z), need);
2921 		uz += need;
2922 	}
2923 
2924 	if (rest)
2925 	{
2926 		d = 0;
2927 		for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from)
2928 		{
2929 			mp_digit	save = *from;
2930 
2931 			*from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
2932 			d = save;
2933 		}
2934 
2935 		d >>= (MP_DIGIT_BIT - rest);
2936 		if (d != 0)
2937 		{
2938 			*from = d;
2939 			uz += extra;
2940 		}
2941 	}
2942 
2943 	z->used = uz;
2944 	CLAMP(z);
2945 
2946 	return 1;
2947 }
2948 
2949 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2950    The sign of the result is always zero/positive.
2951  */
2952 static int
2953 s_qsub(mp_int z, mp_size p2)
2954 {
2955 	mp_digit	hi = (1u << (p2 % MP_DIGIT_BIT)),
2956 			   *zp;
2957 	mp_size		tdig = (p2 / MP_DIGIT_BIT),
2958 				pos;
2959 	mp_word		w = 0;
2960 
2961 	if (!s_pad(z, tdig + 1))
2962 		return 0;
2963 
2964 	for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp)
2965 	{
2966 		w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word) *zp;
2967 
2968 		*zp = LOWER_HALF(w);
2969 		w = UPPER_HALF(w) ? 0 : 1;
2970 	}
2971 
2972 	w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word) *zp;
2973 	*zp = LOWER_HALF(w);
2974 
2975 	assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2976 
2977 	z->sign = MP_ZPOS;
2978 	CLAMP(z);
2979 
2980 	return 1;
2981 }
2982 
2983 static int
2984 s_dp2k(mp_int z)
2985 {
2986 	int			k = 0;
2987 	mp_digit   *dp = MP_DIGITS(z),
2988 				d;
2989 
2990 	if (MP_USED(z) == 1 && *dp == 0)
2991 		return 1;
2992 
2993 	while (*dp == 0)
2994 	{
2995 		k += MP_DIGIT_BIT;
2996 		++dp;
2997 	}
2998 
2999 	d = *dp;
3000 	while ((d & 1) == 0)
3001 	{
3002 		d >>= 1;
3003 		++k;
3004 	}
3005 
3006 	return k;
3007 }
3008 
3009 static int
3010 s_isp2(mp_int z)
3011 {
3012 	mp_size		uz = MP_USED(z),
3013 				k = 0;
3014 	mp_digit   *dz = MP_DIGITS(z),
3015 				d;
3016 
3017 	while (uz > 1)
3018 	{
3019 		if (*dz++ != 0)
3020 			return -1;
3021 		k += MP_DIGIT_BIT;
3022 		--uz;
3023 	}
3024 
3025 	d = *dz;
3026 	while (d > 1)
3027 	{
3028 		if (d & 1)
3029 			return -1;
3030 		++k;
3031 		d >>= 1;
3032 	}
3033 
3034 	return (int) k;
3035 }
3036 
3037 static int
3038 s_2expt(mp_int z, mp_small k)
3039 {
3040 	mp_size		ndig,
3041 				rest;
3042 	mp_digit   *dz;
3043 
3044 	ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
3045 	rest = k % MP_DIGIT_BIT;
3046 
3047 	if (!s_pad(z, ndig))
3048 		return 0;
3049 
3050 	dz = MP_DIGITS(z);
3051 	ZERO(dz, ndig);
3052 	*(dz + ndig - 1) = (1u << rest);
3053 	z->used = ndig;
3054 
3055 	return 1;
3056 }
3057 
3058 static int
3059 s_norm(mp_int a, mp_int b)
3060 {
3061 	mp_digit	d = b->digits[MP_USED(b) - 1];
3062 	int			k = 0;
3063 
3064 	while (d < (1u << (mp_digit) (MP_DIGIT_BIT - 1)))
3065 	{							/* d < (MP_RADIX / 2) */
3066 		d <<= 1;
3067 		++k;
3068 	}
3069 
3070 	/* These multiplications can't fail */
3071 	if (k != 0)
3072 	{
3073 		(void) s_qmul(a, (mp_size) k);
3074 		(void) s_qmul(b, (mp_size) k);
3075 	}
3076 
3077 	return k;
3078 }
3079 
3080 static mp_result
3081 s_brmu(mp_int z, mp_int m)
3082 {
3083 	mp_size		um = MP_USED(m) * 2;
3084 
3085 	if (!s_pad(z, um))
3086 		return MP_MEMORY;
3087 
3088 	s_2expt(z, MP_DIGIT_BIT * um);
3089 	return mp_int_div(z, m, z, NULL);
3090 }
3091 
3092 static int
3093 s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
3094 {
3095 	mp_size		um = MP_USED(m),
3096 				umb_p1,
3097 				umb_m1;
3098 
3099 	umb_p1 = (um + 1) * MP_DIGIT_BIT;
3100 	umb_m1 = (um - 1) * MP_DIGIT_BIT;
3101 
3102 	if (mp_int_copy(x, q1) != MP_OK)
3103 		return 0;
3104 
3105 	/* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
3106 	s_qdiv(q1, umb_m1);
3107 	UMUL(q1, mu, q2);
3108 	s_qdiv(q2, umb_p1);
3109 
3110 	/* Set x = x mod b^(k+1) */
3111 	s_qmod(x, umb_p1);
3112 
3113 	/*
3114 	 * Now, q is a guess for the quotient a / m. Compute x - q * m mod
3115 	 * b^(k+1), replacing x.  This may be off by a factor of 2m, but no more
3116 	 * than that.
3117 	 */
3118 	UMUL(q2, m, q1);
3119 	s_qmod(q1, umb_p1);
3120 	(void) mp_int_sub(x, q1, x);	/* can't fail */
3121 
3122 	/*
3123 	 * The result may be < 0; if it is, add b^(k+1) to pin it in the proper
3124 	 * range.
3125 	 */
3126 	if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
3127 		return 0;
3128 
3129 	/*
3130 	 * If x > m, we need to back it off until it is in range.  This will be
3131 	 * required at most twice.
3132 	 */
3133 	if (mp_int_compare(x, m) >= 0)
3134 	{
3135 		(void) mp_int_sub(x, m, x);
3136 		if (mp_int_compare(x, m) >= 0)
3137 		{
3138 			(void) mp_int_sub(x, m, x);
3139 		}
3140 	}
3141 
3142 	/* At this point, x has been properly reduced. */
3143 	return 1;
3144 }
3145 
3146 /* Perform modular exponentiation using Barrett's method, where mu is the
3147    reduction constant for m.  Assumes a < m, b > 0. */
3148 static mp_result
3149 s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
3150 {
3151 	mp_digit	umu = MP_USED(mu);
3152 	mp_digit   *db = MP_DIGITS(b);
3153 	mp_digit   *dbt = db + MP_USED(b) - 1;
3154 
3155 	DECLARE_TEMP(3);
3156 	REQUIRE(GROW(TEMP(0), 4 * umu));
3157 	REQUIRE(GROW(TEMP(1), 4 * umu));
3158 	REQUIRE(GROW(TEMP(2), 4 * umu));
3159 	ZERO(TEMP(0)->digits, TEMP(0)->alloc);
3160 	ZERO(TEMP(1)->digits, TEMP(1)->alloc);
3161 	ZERO(TEMP(2)->digits, TEMP(2)->alloc);
3162 
3163 	(void) mp_int_set_value(c, 1);
3164 
3165 	/* Take care of low-order digits */
3166 	while (db < dbt)
3167 	{
3168 		mp_digit	d = *db;
3169 
3170 		for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1)
3171 		{
3172 			if (d & 1)
3173 			{
3174 				/* The use of a second temporary avoids allocation */
3175 				UMUL(c, a, TEMP(0));
3176 				if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3177 				{
3178 					REQUIRE(MP_MEMORY);
3179 				}
3180 				mp_int_copy(TEMP(0), c);
3181 			}
3182 
3183 			USQR(a, TEMP(0));
3184 			assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3185 			if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3186 			{
3187 				REQUIRE(MP_MEMORY);
3188 			}
3189 			assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
3190 			mp_int_copy(TEMP(0), a);
3191 		}
3192 
3193 		++db;
3194 	}
3195 
3196 	/* Take care of highest-order digit */
3197 	mp_digit	d = *dbt;
3198 
3199 	for (;;)
3200 	{
3201 		if (d & 1)
3202 		{
3203 			UMUL(c, a, TEMP(0));
3204 			if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3205 			{
3206 				REQUIRE(MP_MEMORY);
3207 			}
3208 			mp_int_copy(TEMP(0), c);
3209 		}
3210 
3211 		d >>= 1;
3212 		if (!d)
3213 			break;
3214 
3215 		USQR(a, TEMP(0));
3216 		if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2)))
3217 		{
3218 			REQUIRE(MP_MEMORY);
3219 		}
3220 		(void) mp_int_copy(TEMP(0), a);
3221 	}
3222 
3223 	CLEANUP_TEMP();
3224 	return MP_OK;
3225 }
3226 
/* Division of nonnegative integers

   This function implements the division algorithm for unsigned
   multi-precision integers. The algorithm is based on Algorithm D from
   Knuth's "The Art of Computer Programming", 3rd ed. 1998, pg 272-273.

   We diverge from Knuth's algorithm in that we do not perform the subtraction
   from the remainder until we have determined that we have the correct
   quotient digit. This makes our algorithm less efficient than Knuth's
   because we might have to perform multiple multiplication and comparison
   steps before the subtraction. The advantage is that it is easy to implement
   and ensure correctness without worrying about underflow from the
   subtraction.

   inputs: u   a n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
		   v   a n   digit integer in base b (b is 2^MP_DIGIT_BIT)
		   n >= 1
		   m >= 0
  outputs: u / v stored in u
		   u % v stored in v

   Both inputs are overwritten and have their signs forced to positive.
   Returns MP_OK on success, or an error (MP_MEMORY) if space for the
   operands or temporaries could not be obtained.
 */
static mp_result
s_udiv_knuth(mp_int u, mp_int v)
{
	/* Force signs to positive */
	u->sign = MP_ZPOS;
	v->sign = MP_ZPOS;

	/* Use simple division algorithm when v is only one digit long */
	if (MP_USED(v) == 1)
	{
		mp_digit	d,
					rem;

		d = v->digits[0];
		rem = s_ddiv(u, d);
		mp_int_set_value(v, rem);
		return MP_OK;
	}

	/*
	 * Algorithm D
	 *
	 * The n and m variables are defined as used by Knuth. v is an n digit
	 * number with digits v_{n-1}..v_0. u is an n+m digit number with digits
	 * from u_{m+n-1}..u_0. We require that n > 1 and m >= 0
	 */
	mp_size		n = MP_USED(v);
	mp_size		m = MP_USED(u) - n;

	assert(n > 1);

	/*
	 * assert(m >= 0) follows because m is unsigned.  (The subtraction above
	 * would wrap around if u had fewer digits than v; presumably callers
	 * rule that out -- TODO confirm.)
	 */

	/*
	 * D1: Normalize. The normalization step provides the necessary condition
	 * for Theorem B, which states that the quotient estimate for q_j, call it
	 * qhat
	 *
	 * qhat = u_{j+n}u_{j+n-1} / v_{n-1}
	 *
	 * is bounded by
	 *
	 * qhat - 2 <= q_j <= qhat.
	 *
	 * That is, qhat is never smaller than the actual quotient digit q, and
	 * it is never more than two larger than the actual quotient digit.
	 */
	int			k = s_norm(u, v);

	/*
	 * Extend size of u by one if needed.
	 *
	 * The algorithm begins with a value of u that has one more digit of
	 * input. The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0.
	 * If the multiplication did not increase the number of digits of u, we
	 * need to add a leading zero here.
	 */
	if (k == 0 || MP_USED(u) != m + n + 1)
	{
		if (!s_pad(u, m + n + 1))
			return MP_MEMORY;
		u->digits[m + n] = 0;
		u->used = m + n + 1;
	}

	/*
	 * Add a leading 0 to v.
	 *
	 * The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0.  We need
	 * to add the leading zero to v here to ensure that the multiplication
	 * will produce the full n+1 digit result.
	 */
	if (!s_pad(v, n + 1))
		return MP_MEMORY;
	v->digits[n] = 0;

	/*
	 * Initialize temporary variables q and t. q (TEMP(0)) allocates space
	 * for m+1 digits to store the quotient digits; t (TEMP(1)) allocates
	 * space for n+1 digits to hold the result of q_j*v.
	 */
	DECLARE_TEMP(2);
	REQUIRE(GROW(TEMP(0), m + 1));
	REQUIRE(GROW(TEMP(1), n + 1));

	/* D2: Initialize j */
	int			j = m;
	mpz_t		r;

	/*
	 * r is a window of n+1 digits onto u, starting at digit j; it owns no
	 * storage of its own.
	 */
	r.digits = MP_DIGITS(u) + j;	/* The contents of r are shared with u */
	r.used = n + 1;
	r.sign = MP_ZPOS;
	r.alloc = MP_ALLOC(u);
	ZERO(TEMP(1)->digits, TEMP(1)->alloc);

	/* Calculate the m+1 digits of the quotient result */
	for (; j >= 0; j--)
	{
		/* D3: Calculate q' */
		/* r->digits is aligned to position j of the number u */
		mp_word		pfx,
					qhat;

		/* The shift by a full digit is done in two half-digit steps */
		pfx = r.digits[n];
		pfx <<= MP_DIGIT_BIT / 2;
		pfx <<= MP_DIGIT_BIT / 2;
		pfx |= r.digits[n - 1]; /* pfx = u_{j+n}{j+n-1} */

		qhat = pfx / v->digits[n - 1];

		/*
		 * Check to see if qhat >= b, and decrease qhat if so. Theorem B
		 * guarantees that qhat is at most 2 larger than the actual value, so
		 * it is possible that qhat is greater than the maximum value that
		 * will fit in a digit
		 */
		if (qhat > MP_DIGIT_MAX)
			qhat = MP_DIGIT_MAX;

		/*
		 * D4,D5,D6: Multiply qhat * v and test for a correct value of q
		 *
		 * We proceed a bit differently than the way described by Knuth. This
		 * way is simpler but less efficient. Instead of doing the multiply
		 * and subtract then checking for underflow, we first do the multiply
		 * of qhat * v and see if it is larger than the current remainder r.
		 * If it is larger, we decrease qhat by one and try again. We may need
		 * to decrease qhat one more time before we get a value that is no
		 * larger than r.
		 *
		 * This way is less efficient than Knuth's because we do more
		 * multiplies, but we do not need to worry about underflow this way.
		 */
		/* t = qhat * v */
		s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
		TEMP(1)->used = n + 1;
		CLAMP(TEMP(1));

		/* Clamp r for the comparison. Comparisons do not like leading zeros. */
		CLAMP(&r);
		if (s_ucmp(TEMP(1), &r) > 0)
		{						/* would the remainder be negative? */
			qhat -= 1;			/* try a smaller q */
			s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
			TEMP(1)->used = n + 1;
			CLAMP(TEMP(1));
			if (s_ucmp(TEMP(1), &r) > 0)
			{					/* would the remainder be negative? */
				assert(qhat > 0);
				qhat -= 1;		/* try a smaller q */
				s_dbmul(MP_DIGITS(v), (mp_digit) qhat, TEMP(1)->digits, n + 1);
				TEMP(1)->used = n + 1;
				CLAMP(TEMP(1));
			}
			assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us.");
		}

		/*
		 * Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be
		 * n+1 digits long.
		 */
		r.used = n + 1;

		/*
		 * D4: Multiply and subtract
		 *
		 * Note: The multiply was completed above so we only need to subtract
		 * here.
		 */
		s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used);

		/*
		 * D5: Test remainder
		 *
		 * Note: Not needed because we always check that qhat is the correct
		 * value before performing the subtract.  Value cast to mp_digit to
		 * prevent warning, qhat has been clamped to MP_DIGIT_MAX
		 */
		TEMP(0)->digits[j] = (mp_digit) qhat;

		/*
		 * D6: Add back
		 *
		 * Note: Not needed because we always check that qhat is the correct
		 * value before performing the subtract.
		 */

		/* D7: Loop on j; slide the window one digit lower and reset t */
		r.digits--;
		ZERO(TEMP(1)->digits, TEMP(1)->alloc);
	}

	/* Get rid of leading zeros in q */
	TEMP(0)->used = m + 1;
	CLAMP(TEMP(0));

	/* Denormalize the remainder */
	CLAMP(u);					/* use u here because the r.digits pointer is
								 * off-by-one */
	if (k != 0)
		s_qdiv(u, k);

	mp_int_copy(u, v);			/* ok:  0 <= r < v */
	mp_int_copy(TEMP(0), u);	/* ok:  q <= u     */

	CLEANUP_TEMP();
	return MP_OK;
}
3452 
3453 static int
3454 s_outlen(mp_int z, mp_size r)
3455 {
3456 	assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);
3457 
3458 	mp_result	bits = mp_int_count_bits(z);
3459 	double		raw = (double) bits * s_log2[r];
3460 
3461 	return (int) (raw + 0.999999);
3462 }
3463 
3464 static mp_size
3465 s_inlen(int len, mp_size r)
3466 {
3467 	double		raw = (double) len / s_log2[r];
3468 	mp_size		bits = (mp_size) (raw + 0.5);
3469 
3470 	return (mp_size) ((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1;
3471 }
3472 
/*
 * Convert the character c to its digit value in radix r.
 *
 * Decimal digits map to 0..9; for radixes above 10, letters map to 10 and up
 * regardless of case.  Returns -1 if c is not a valid digit in radix r.
 */
static int
s_ch2val(char c, int r)
{
	const unsigned char uc = (unsigned char) c;
	int			out;

	if (isdigit(uc))
		out = c - '0';
	else if (r > 10 && isalpha(uc))
		out = toupper(uc) - 'A' + 10;
	else
		return -1;

	/*
	 * In some locales, isalpha() accepts characters outside the range A-Z,
	 * which can yield out < 0 or out >= 36.  Reject both explicitly so that
	 * every invalid character produces the same -1 result.
	 */
	return (out < 0 || out >= r) ? -1 : out;
}
3493 
/*
 * Convert the digit value v (which must be nonnegative) to its character:
 * '0'..'9' for values below 10, letters from 'a' onward above that, in upper
 * case when caps is nonzero.
 */
static char
s_val2ch(int v, int caps)
{
	char		letter;

	assert(v >= 0);

	if (v < 10)
		return v + '0';

	letter = (v - 10) + 'a';

	return caps ? toupper((unsigned char) letter) : letter;
}
3517 
/*
 * Replace the len-byte big-endian value in buf with its two's complement:
 * invert every byte and add one, working from the last byte towards the
 * first.  A carry out of the first byte is dropped.
 */
static void
s_2comp(unsigned char *buf, int len)
{
	unsigned short carry = 1;

	for (int i = len; i-- > 0;)
	{
		const unsigned short sum = (unsigned char) ~buf[i] + carry;

		buf[i] = (unsigned char) (sum & UCHAR_MAX);
		carry = sum >> CHAR_BIT;
	}
}
3536 
3537 static mp_result
3538 s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3539 {
3540 	int			pos = 0,
3541 				limit = *limpos;
3542 	mp_size		uz = MP_USED(z);
3543 	mp_digit   *dz = MP_DIGITS(z);
3544 
3545 	while (uz > 0 && pos < limit)
3546 	{
3547 		mp_digit	d = *dz++;
3548 		int			i;
3549 
3550 		for (i = sizeof(mp_digit); i > 0 && pos < limit; --i)
3551 		{
3552 			buf[pos++] = (unsigned char) d;
3553 			d >>= CHAR_BIT;
3554 
3555 			/* Don't write leading zeroes */
3556 			if (d == 0 && uz == 1)
3557 				i = 0;			/* exit loop without signaling truncation */
3558 		}
3559 
3560 		/* Detect truncation (loop exited with pos >= limit) */
3561 		if (i > 0)
3562 			break;
3563 
3564 		--uz;
3565 	}
3566 
3567 	if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1)))
3568 	{
3569 		if (pos < limit)
3570 		{
3571 			buf[pos++] = 0;
3572 		}
3573 		else
3574 		{
3575 			uz = 1;
3576 		}
3577 	}
3578 
3579 	/* Digits are in reverse order, fix that */
3580 	REV(buf, pos);
3581 
3582 	/* Return the number of bytes actually written */
3583 	*limpos = pos;
3584 
3585 	return (uz == 0) ? MP_OK : MP_TRUNC;
3586 }
3587 
3588 /* Here there be dragons */
3589