1 /*
2  * Copyright (c) 2005 Beeyond Software Holding BV
3  *
4  * This library is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * This library is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with this library; if not, write to the Free Software
16  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17  */
18 
19 #define BEECRYPT_CXX_DLL_EXPORT
20 
21 #ifdef HAVE_CONFIG_H
22 # include "config.h"
23 #endif
24 
25 #include "beecrypt/c++/math/BigInteger.h"
26 #include "beecrypt/c++/lang/Character.h"
27 using beecrypt::lang::Character;
28 #include "beecrypt/c++/lang/StringBuilder.h"
29 using beecrypt::lang::StringBuilder;
30 #include "beecrypt/c++/lang/ArithmeticException.h"
31 using beecrypt::lang::ArithmeticException;
32 #include "beecrypt/c++/lang/OutOfMemoryError.h"
33 using beecrypt::lang::OutOfMemoryError;
34 
35 #include <cstdlib>
36 
37 using namespace beecrypt::math;
38 
39 namespace {
40 	const String STRZERO("0");
41 }
42 
43 const BigInteger BigInteger::ZERO;
44 const BigInteger BigInteger::ONE(1);
45 const BigInteger BigInteger::TEN(10);
46 
BigInteger()47 BigInteger::BigInteger() : size(0), data(0), sign(0)
48 {
49 }
50 
BigInteger(jlong val)51 BigInteger::BigInteger(jlong val)
52 {
53 	if (val == 0)
54 	{
55 		size = 0;
56 		data = 0;
57 		sign = 0;
58 	}
59 	else
60 	{
61 		if (val < 0)
62 		{
63 			sign = -1;
64 			val = -val;
65 		}
66 		else
67 			sign = 1;
68 
69 		if (sizeof(jlong) == sizeof(mpw))
70 			size = 1;
71 		else
72 			size = 2;
73 
74 		data = (mpw*) malloc(size * sizeof(mpw));
75 		if (data == 0)
76 			throw OutOfMemoryError();
77 
78 		if (sizeof(jlong) == sizeof(mpw))
79 		{
80 			data[0] = val;
81 		}
82 		else
83 		{
84 			data[0] = (val >> 32);
85 			data[1] = (val      );
86 		}
87 	}
88 }
89 
BigInteger(size_t size,mpw * data,int sign)90 BigInteger::BigInteger(size_t size, mpw* data, int sign) : size(size), data(data), sign(sign)
91 {
92 }
93 
BigInteger(const bytearray & val)94 BigInteger::BigInteger(const bytearray& val)
95 {
96 	if (val.size() == 0)
97 	{
98 		size = 0;
99 		data = 0;
100 		sign = 0;
101 	}
102 	else
103 	{
104 		int skip = 0;
105 		while ((skip < val.size()) && (val[skip] == 0))
106 			skip++;
107 
108 		size = MP_BYTES_TO_WORDS(val.size() - skip + MP_WBYTES - 1);
109 
110 		data = (mpw*) malloc(size * sizeof(mpw));
111 		if (data == 0)
112 			throw OutOfMemoryError();
113 
114 		os2ip(data, size, val.data(), val.size());
115 
116 		if (val[0] & 0x80)
117 		{
118 			mpneg(size, data);
119 			sign = -1;
120 		}
121 		else
122 			sign = 1;
123 	}
124 }
125 
BigInteger(const mpnumber & n)126 BigInteger::BigInteger(const mpnumber& n)
127 {
128 	if (mpz(n.size, n.data))
129 	{
130 		size = 0;
131 		data = 0;
132 		sign = 0;
133 	}
134 	else
135 	{
136 		size_t sigbits = mpbits(n.size, n.data);
137 
138 		size = MP_BITS_TO_WORDS(sigbits + MP_WBITS - 1);
139 		data = (mpw*) malloc(size * sizeof(mpw));
140 		if (data == 0)
141 			throw new OutOfMemoryError();
142 
143 		// eliminate zero most-significant-words
144 		mpcopy(size, data, n.data + n.size - size);
145 
146 		sign = 1;
147 	}
148 }
149 
BigInteger(const mpbarrett & b)150 BigInteger::BigInteger(const mpbarrett& b)
151 {
152 	if (mpz(b.size, b.modl))
153 	{
154 		size = 0;
155 		data = 0;
156 		sign = 0;
157 	}
158 	else
159 	{
160 		size_t sigbits = mpbits(b.size, b.modl);
161 
162 		size = MP_BITS_TO_WORDS(sigbits + MP_WBITS - 1);
163 		data = (mpw*) malloc(size * sizeof(mpw));
164 		if (data == 0)
165 			throw new OutOfMemoryError();
166 
167 		// eliminate zero most-significant-words
168 		mpcopy(size, data, b.modl + b.size - size);
169 
170 		sign = 1;
171 	}
172 }
173 
/* Deep copy: duplicates the magnitude words of a non-zero value. */
BigInteger::BigInteger(const BigInteger& copy) : size(copy.size), data(0), sign(copy.sign)
{
	if (sign == 0)
		return;

	mpw* words = (mpw*) malloc(size * sizeof(mpw));
	if (words == 0)
		throw OutOfMemoryError();

	mpcopy(size, words, copy.data);
	data = words;
}
186 
~BigInteger()187 BigInteger::~BigInteger()
188 {
189 	if (sign)
190 		free(data);
191 }
192 
/*
 * Copy assignment.
 *
 * Storage is malloc()-based, so it is released with free(). It is resized
 * with realloc() only when the word count differs. If the reallocation
 * fails, OutOfMemoryError is thrown and *this is left unchanged.
 */
BigInteger& BigInteger::operator=(const BigInteger& copy)
{
	if (this == &copy)
		return *this;

	if (copy.sign == 0)
	{
		if (sign)
			free(data); // was 'delete', which is undefined for malloc()ed memory
		size = 0;
		data = 0;
		sign = 0;
	}
	else
	{
		if (size != copy.size)
		{
			// keep the old block until realloc succeeds so a failure does not leak it
			mpw* resized = (mpw*) realloc(data, copy.size * sizeof(mpw));
			if (resized == 0)
				throw OutOfMemoryError();
			data = resized;
			size = copy.size;
		}
		mpcopy(size, data, copy.data);
		sign = copy.sign;
	}

	return *this;
}
219 
operator ==(const BigInteger & val) const220 bool BigInteger::operator==(const BigInteger& val) const throw ()
221 {
222 	return (sign == val.sign) && ((sign == 0) || mpeqx(size, data, val.size, val.data));
223 }
224 
operator !=(const BigInteger & val) const225 bool BigInteger::operator!=(const BigInteger& val) const throw ()
226 {
227 	return (sign != val.sign) || ((sign != 0) && mpnex(size, data, val.size, val.data));
228 }
229 
valueOf(jlong val)230 BigInteger BigInteger::valueOf(jlong val)
231 {
232 	return BigInteger(val);
233 }
234 
hashCode() const235 jint BigInteger::hashCode() const throw ()
236 {
237 	return 0;
238 }
239 
byteValue() const240 jbyte BigInteger::byteValue() const throw ()
241 {
242 	return (jbyte) (sign * data[size-1]);
243 }
244 
shortValue() const245 jshort BigInteger::shortValue() const throw ()
246 {
247 	return (jshort) (sign * data[size-1]);
248 }
249 
intValue() const250 jint BigInteger::intValue() const throw ()
251 {
252 	return (jint) (sign * data[size-1]);
253 }
254 
longValue() const255 jlong BigInteger::longValue() const throw ()
256 {
257 	#if MP_WBITS == 64
258 	return (jlong) (sign * data[size-1]);
259 	#else
260 	if (size == 1)
261 		return (jlong) (sign * data[size-1]);
262 	else
263 		return (jlong) (sign * ((data[size-2] << 32) + data[size-1]));
264 	#endif
265 }
266 
compareTo(const BigInteger & val) const267 jint BigInteger::compareTo(const BigInteger& val) const throw ()
268 {
269 	if (sign == val.sign)
270 	{
271 		if (sign == 0)
272 			return 0;
273 
274 		return sign * mpcmpx(size, data, val.size, val.data);
275 	}
276 	else
277 		return (sign > val.sign) ? 1 : -1;
278 }
279 
equals(const Object * obj) const280 bool BigInteger::equals(const Object* obj) const throw ()
281 {
282 	if (this == obj)
283 		return true;
284 
285 	const BigInteger* cmp = dynamic_cast<const BigInteger*>(obj);
286 	if (cmp)
287 	{
288 		if (sign != cmp->sign)
289 			return false;
290 
291 		if (sign && mpnex(size, data, cmp->size, cmp->data))
292 			return false;
293 
294 		return true;
295 	}
296 
297 	return false;
298 }
299 
toString() const300 String BigInteger::toString() const throw ()
301 {
302 	return toString(10);
303 }
304 
toString(int radix) const305 String BigInteger::toString(int radix) const
306 {
307 	if (radix < Character::MIN_RADIX || radix > Character::MAX_RADIX)
308 		radix = 10;
309 
310 	if (sign == 0)
311 		return STRZERO;
312 
313 	StringBuilder tmp;
314 
315 	/* allocate enough space to hold a copy of this (size+1), result (size+2) and workspace (2) */
316 	mpw* rdata = (mpw*) malloc((2*size+5) * sizeof(mpw));
317 	if (rdata)
318 	{
319 		mpw* result = rdata+size+1;
320 		mpw* wksp = result+size+2;
321 		mpw nradix = radix;
322 
323 		mpsetx(size+1, rdata, size, data);
324 
325 		size_t shift = mpnorm(1, &nradix);
326 		mplshift(size+1, rdata, shift);
327 
328 		do
329 		{
330 			int remainder;
331 
332 			mpndivmod(result, size+1, rdata, 1, &nradix, wksp);
333 			remainder = (result[size+1] >> shift);
334 			mpcopy(size+1, rdata, result);
335 			mplshift(size+1, rdata, shift);
336 			if (remainder < 10)
337 				tmp.append((jchar)(48 + remainder));
338 			else
339 				tmp.append((jchar)(55 + remainder));
340 		} while (mpnz(size+1, rdata));
341 
342 		free(rdata);
343 
344 		if (sign < 0)
345 			tmp.append('-');
346 
347 		return tmp.reverse().toString();
348 	}
349 	else
350 		throw OutOfMemoryError();
351 }
352 
signum() const353 jint BigInteger::signum() const throw ()
354 {
355 	return sign;
356 }
357 
add(const BigInteger & val) const358 BigInteger BigInteger::add(const BigInteger& val) const
359 {
360 	if (val.sign == 0)
361 		return *this;
362 
363 	if (sign == 0)
364 		return val;
365 
366 	if (sign == val.sign)
367 	{
368 		size_t rsize = size > val.size ? size : val.size;
369 		// allocate one extra word for addition carry-over
370 		mpw* rdata = (mpw*) malloc((rsize+1) * sizeof(mpw));
371 		if (rdata == 0)
372 			throw OutOfMemoryError();
373 
374 		mpsetx(rsize, rdata, size, data);
375 		if (mpaddx(rsize, rdata, val.size, val.data))
376 		{
377 			// there was a carry-over; move result up by one word
378 			mpmove(rsize, rdata+1, rdata);
379 			rsize++;
380 			rdata[0] = 1;
381 		}
382 
383 		return BigInteger(rsize, rdata, sign);
384 	}
385 	else
386 	{
387 		int cmp = mpcmpx(size, data, val.size, val.data);
388 		if (cmp == 0)
389 			return ZERO;
390 
391 		// subtract the smallest from the biggest value
392 
393 		size_t rsize = (cmp < 0) ? val.size : size;
394 		mpw* rdata = (mpw*) malloc(rsize * sizeof(mpw));
395 		if (rdata == 0)
396 			throw OutOfMemoryError();
397 
398 		if (cmp < 0)
399 		{
400 			mpcopy(rsize, rdata, val.data);
401 			mpsubx(rsize, rdata, size, data);
402 		}
403 		else
404 		{
405 			mpcopy(rsize, rdata, data);
406 			mpsubx(rsize, rdata, val.size, val.data);
407 		}
408 
409 		size_t skip = 0;
410 
411 		while ((skip < rsize) && rdata[skip] == 0)
412 			skip++;
413 
414 		if (skip == rsize)
415 		{
416 			free(rdata);
417 			return ZERO;
418 		}
419 
420 		if (skip)
421 		{
422 			rsize -= skip;
423 			mpmove(rsize, rdata, rdata+skip);
424 		}
425 
426 		return BigInteger(rsize, rdata, cmp * sign);
427 	}
428 }
429 
subtract(const BigInteger & val) const430 BigInteger BigInteger::subtract(const BigInteger& val) const
431 {
432 	if (sign == 0)
433 		return val.negate();
434 
435 	if (val.sign == 0)
436 		return *this;
437 
438 	if (sign != val.sign)
439 	{
440 		size_t rsize = size > val.size ? size : val.size;
441 		// allocate one extra word for addition carry-over
442 		mpw* rdata = (mpw*) malloc((rsize+1) * sizeof(mpw));
443 		if (rdata == 0)
444 			throw OutOfMemoryError();
445 
446 		mpsetx(rsize, rdata, size, data);
447 		if (mpaddx(rsize, rdata, val.size, val.data))
448 		{
449 			// there was a carry-over; move result up by one word
450 			mpmove(rsize, rdata+1, rdata);
451 			rsize++;
452 			rdata[0] = 1;
453 		}
454 
455 		return BigInteger(rsize, rdata, sign);
456 	}
457 	else
458 	{
459 		int cmp = mpcmpx(size, data, val.size, val.data);
460 		if (cmp == 0)
461 			return ZERO;
462 
463 		// subtract the smallest from the biggest value, so we don't get a carry
464 		size_t rsize = (cmp < 0) ? val.size : size;
465 		mpw* rdata = (mpw*) malloc(rsize * sizeof(mpw));
466 		if (rdata == 0)
467 			throw OutOfMemoryError();
468 
469 		if (cmp < 0)
470 		{
471 			mpcopy(rsize, rdata, val.data);
472 			mpsubx(rsize, rdata, size, data);
473 		}
474 		else
475 		{
476 			mpcopy(rsize, rdata, data);
477 			mpsubx(rsize, rdata, val.size, val.data);
478 		}
479 
480 		size_t skip = 0;
481 		while ((skip < rsize) && (rdata[skip] == 0))
482 			skip++;
483 
484 		if (skip)
485 		{
486 			rsize -= skip;
487 			mpmove(rsize, rdata, rdata+skip);
488 		}
489 
490 		return BigInteger(rsize, rdata, cmp * sign);
491 	}
492 }
493 
/*
 * Returns this * val.
 *
 * The product buffer has size + val.size words. At most one leading zero
 * word is stripped, which suffices when neither operand has leading zero
 * words.
 */
BigInteger BigInteger::multiply(const BigInteger& val) const
{
	// anything times zero is zero
	if (!sign || !val.sign)
		return BigInteger();

	size_t n = val.size + size;
	mpw* product = (mpw*) malloc(n * sizeof(mpw));
	if (!product)
		throw OutOfMemoryError();

	mpmul(product, size, data, val.size, val.data);

	if (!product[0])
	{
		--n;
		mpmove(n, product, product + 1);
	}

	return BigInteger(n, product, val.sign * sign);
}
513 
/*
 * NOTE(review): compiled out (#if 0). This is an unfinished draft of mod()
 * and modPow(). Issues visible in the draft, to resolve before enabling it:
 *  - mod(): malloc(size + 2*m.size+1) requests bytes, not words
 *    (the "* sizeof(mpw)" factor is missing).
 *  - mod(): for a negative value with a smaller magnitude than m it returns
 *    m.subtract(*this), which is m + |this|; m.add(*this) (= m - |this|)
 *    was presumably intended -- confirm.
 *  - mod(): on the negative path a zero remainder becomes 0 + m = m rather
 *    than 0, since the add happens before the zero check.
 *  - modPow(): falls off the end without returning a value.
 */
#if 0
BigInteger BigInteger::mod(const BigInteger& m) const throw (ArithmeticException)
{
	if (m.compareTo(ZERO) <= 0)
		throw ArithmeticException("m must be > 0");

	if (mpltx(size, data, m.size, m.data))
	{
		if (sign == -1)
			return m.subtract(*this);
		else
			return *this;
	}
	else
	{
		mpw* tmp = (mpw*) malloc(size + 2*m.size+1);
		if (tmp == 0)
			throw OutOfMemoryError();

		mpmod(tmp, size, data, m.size, m.data, tmp+size);
		if (sign == -1)
		{
			mpneg(size, tmp);
			mpaddx(size, tmp, m.size, m.data);
		}

		size_t skip = 0;

		while ((skip < size) && (tmp[skip] == 0))
			skip++;

		if (skip == size)
		{
			free(tmp);
			return ZERO;
		}

		if (skip)
			mpmove(size - skip, tmp, tmp + skip);

		return BigInteger(size - skip, tmp, 1);
	}
}

BigInteger BigInteger::modPow(const BigInteger& exponent, const BigInteger& m) const
{
	// if the modulus is not positive, bail out
	if (m.sign <= 0)
		throw ArithmeticException("modulus must be > 0");

	// this ^ 0 mod m is 1 except when m == 1, when it is 0
	// 1 ^ exponent mod m is 1 except when m == 1, when it is 0
	if (exponent.sign == 0 || equals(&ONE))
	{
		if (m.equals(&ONE))
			return ZERO;
		else
			return ONE;
	}

	// 0 ^ exponent mod m is 0 when exponent != 0 (which was already excluded earlier)
	if (equals(&ZERO))
		return ZERO;

	// TODO: we need to bring this into the range 0 .. (m-1)

	// TODO: if the exponent is negative, compute the modular inverse of this
	// before exponentiating (not implemented; the function ends here without a return)
}
#endif
583 
negate() const584 BigInteger BigInteger::negate() const
585 {
586 	if (mpz(size, data))
587 		return BigInteger();
588 
589 	mpw* rdata = (mpw*) malloc(size * sizeof(mpw));
590 	if (rdata == 0)
591 		throw OutOfMemoryError();
592 
593 	mpcopy(size, rdata, data);
594 
595 	return BigInteger(size, rdata, -sign);
596 }
597 
toByteArray() const598 bytearray* BigInteger::toByteArray() const
599 {
600 	bytearray* result = new bytearray();
601 
602 	toByteArray(*result);
603 
604 	return result;
605 }
606 
toByteArray(bytearray & b) const607 void BigInteger::toByteArray(bytearray& b) const
608 {
609 	if (sign == 0)
610 	{
611 		b.resize(1);
612 		b[0] = 0;
613 	}
614 	else if (sign == 1)
615 	{
616 		size_t sigbits = mpbits(size, data);
617 		size_t req = (sigbits+8) >> 3;
618 
619 		b.resize(req);
620 
621 		i2osp(b.data(), req, data, size);
622 	}
623 	else
624 	{
625 		size_t sigbits = mpbits(size, data);
626 		size_t req = (sigbits+7) >> 3;
627 
628 		b.resize(req);
629 
630 		mpw* tmp = (mpw*) malloc(size * sizeof(mpw));
631 		if (tmp == 0)
632 			throw OutOfMemoryError();
633 
634 		mpcopy(size, tmp, data);
635 		mpneg(size, tmp);
636 
637 		i2osp(b.data(), req, tmp, size);
638 
639 		free(tmp);
640 	}
641 }
642 
/*
 * NOTE(review): compiled out (#if 0). Draft of transform(BigInteger&, const mpnumber&).
 * Issues to fix before enabling it:
 *  - when b.size != n.size it reallocates to the old b.size rather than n.size,
 *    and never assigns b.size = n.size;
 *  - it then copies b.size words from n.data, which may over- or under-read n;
 *  - unlike BigInteger(const mpnumber&), it does not strip zero
 *    most-significant words.
 */
#if 0
void beecrypt::math::transform(BigInteger& b, const mpnumber& n)
{
	if (mpz(n.size, n.data))
	{
		if (b.sign)
		{
			free(b.data);
			b.size = 0;
			b.data = 0;
			b.sign = 0;
		}
	}
	else
	{
		if (b.size != n.size)
		{
			b.data = (mpw*) realloc(b.data, b.size * sizeof(mpw));
			if (b.data == 0)
				throw OutOfMemoryError();
		}
		mpcopy(b.size, b.data, n.data);
		b.sign = 1;
	}
}
#endif
669 
transform(mpnumber & n,const BigInteger & val)670 void beecrypt::math::transform(mpnumber& n, const BigInteger& val)
671 {
672 	switch (val.sign)
673 	{
674 	case 0:
675 		mpnfree(&n);
676 		break;
677 	case 1:
678 		mpnset(&n, val.size, val.data);
679 		break;
680 	default:
681 		throw IllegalArgumentException("can only transform non-negative numbers");
682 	}
683 }
684 
transform(mpbarrett & b,const BigInteger & val)685 void beecrypt::math::transform(mpbarrett& b, const BigInteger& val)
686 {
687 	switch (val.sign)
688 	{
689 	case 0:
690 		mpbfree(&b);
691 		break;
692 	case 1:
693 		mpbset(&b, val.size, val.data);
694 		break;
695 	default:
696 		throw IllegalArgumentException("can only transform non-negative numbers");
697 	}
698 }
699