1 /*
2 * Copyright (c) 2005 Beeyond Software Holding BV
3 *
4 * This library is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU Lesser General Public
6 * License as published by the Free Software Foundation; either
7 * version 2.1 of the License, or (at your option) any later version.
8 *
9 * This library is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 * Lesser General Public License for more details.
13 *
14 * You should have received a copy of the GNU Lesser General Public
15 * License along with this library; if not, write to the Free Software
16 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17 */
18
19 #define BEECRYPT_CXX_DLL_EXPORT
20
21 #ifdef HAVE_CONFIG_H
22 # include "config.h"
23 #endif
24
25 #include "beecrypt/c++/math/BigInteger.h"
26 #include "beecrypt/c++/lang/Character.h"
27 using beecrypt::lang::Character;
28 #include "beecrypt/c++/lang/StringBuilder.h"
29 using beecrypt::lang::StringBuilder;
30 #include "beecrypt/c++/lang/ArithmeticException.h"
31 using beecrypt::lang::ArithmeticException;
32 #include "beecrypt/c++/lang/OutOfMemoryError.h"
33 using beecrypt::lang::OutOfMemoryError;
34
35 #include <cstdlib>
36
37 using namespace beecrypt::math;
38
39 namespace {
40 const String STRZERO("0");
41 }
42
43 const BigInteger BigInteger::ZERO;
44 const BigInteger BigInteger::ONE(1);
45 const BigInteger BigInteger::TEN(10);
46
BigInteger()47 BigInteger::BigInteger() : size(0), data(0), sign(0)
48 {
49 }
50
BigInteger(jlong val)51 BigInteger::BigInteger(jlong val)
52 {
53 if (val == 0)
54 {
55 size = 0;
56 data = 0;
57 sign = 0;
58 }
59 else
60 {
61 if (val < 0)
62 {
63 sign = -1;
64 val = -val;
65 }
66 else
67 sign = 1;
68
69 if (sizeof(jlong) == sizeof(mpw))
70 size = 1;
71 else
72 size = 2;
73
74 data = (mpw*) malloc(size * sizeof(mpw));
75 if (data == 0)
76 throw OutOfMemoryError();
77
78 if (sizeof(jlong) == sizeof(mpw))
79 {
80 data[0] = val;
81 }
82 else
83 {
84 data[0] = (val >> 32);
85 data[1] = (val );
86 }
87 }
88 }
89
BigInteger(size_t size,mpw * data,int sign)90 BigInteger::BigInteger(size_t size, mpw* data, int sign) : size(size), data(data), sign(sign)
91 {
92 }
93
BigInteger(const bytearray & val)94 BigInteger::BigInteger(const bytearray& val)
95 {
96 if (val.size() == 0)
97 {
98 size = 0;
99 data = 0;
100 sign = 0;
101 }
102 else
103 {
104 int skip = 0;
105 while ((skip < val.size()) && (val[skip] == 0))
106 skip++;
107
108 size = MP_BYTES_TO_WORDS(val.size() - skip + MP_WBYTES - 1);
109
110 data = (mpw*) malloc(size * sizeof(mpw));
111 if (data == 0)
112 throw OutOfMemoryError();
113
114 os2ip(data, size, val.data(), val.size());
115
116 if (val[0] & 0x80)
117 {
118 mpneg(size, data);
119 sign = -1;
120 }
121 else
122 sign = 1;
123 }
124 }
125
BigInteger(const mpnumber & n)126 BigInteger::BigInteger(const mpnumber& n)
127 {
128 if (mpz(n.size, n.data))
129 {
130 size = 0;
131 data = 0;
132 sign = 0;
133 }
134 else
135 {
136 size_t sigbits = mpbits(n.size, n.data);
137
138 size = MP_BITS_TO_WORDS(sigbits + MP_WBITS - 1);
139 data = (mpw*) malloc(size * sizeof(mpw));
140 if (data == 0)
141 throw new OutOfMemoryError();
142
143 // eliminate zero most-significant-words
144 mpcopy(size, data, n.data + n.size - size);
145
146 sign = 1;
147 }
148 }
149
BigInteger(const mpbarrett & b)150 BigInteger::BigInteger(const mpbarrett& b)
151 {
152 if (mpz(b.size, b.modl))
153 {
154 size = 0;
155 data = 0;
156 sign = 0;
157 }
158 else
159 {
160 size_t sigbits = mpbits(b.size, b.modl);
161
162 size = MP_BITS_TO_WORDS(sigbits + MP_WBITS - 1);
163 data = (mpw*) malloc(size * sizeof(mpw));
164 if (data == 0)
165 throw new OutOfMemoryError();
166
167 // eliminate zero most-significant-words
168 mpcopy(size, data, b.modl + b.size - size);
169
170 sign = 1;
171 }
172 }
173
/**
 * Deep-copies another BigInteger; zero values share no buffer.
 *
 * @throws OutOfMemoryError if the word buffer cannot be allocated
 */
BigInteger::BigInteger(const BigInteger& copy) : size(copy.size), data(0), sign(copy.sign)
{
	if (sign == 0)
		return;

	data = (mpw*) malloc(size * sizeof(mpw));
	if (!data)
		throw OutOfMemoryError();

	mpcopy(size, data, copy.data);
}
186
~BigInteger()187 BigInteger::~BigInteger()
188 {
189 if (sign)
190 free(data);
191 }
192
/**
 * Copy-assigns another BigInteger, reusing the existing buffer when the
 * word counts match.
 *
 * If reallocation fails, *this is left unchanged (strong guarantee).
 *
 * @throws OutOfMemoryError if the word buffer cannot be (re)allocated
 */
BigInteger& BigInteger::operator=(const BigInteger& copy)
{
	// self-assignment would mpcopy a buffer onto itself
	if (this == &copy)
		return *this;

	if (copy.sign == 0)
	{
		if (sign)
		{
			// the buffer came from malloc/realloc, so it must be released with free
			free(data);
			size = 0;
			data = 0;
			sign = 0;
		}
	}
	else
	{
		if (size != copy.size)
		{
			// realloc into a temporary: on failure the old buffer is still
			// valid and still owned by *this
			mpw* tmp = (mpw*) realloc(data, copy.size * sizeof(mpw));
			if (tmp == 0)
				throw OutOfMemoryError();
			data = tmp;
			size = copy.size;
		}
		mpcopy(size, data, copy.data);
		sign = copy.sign;
	}

	return *this;
}
219
operator ==(const BigInteger & val) const220 bool BigInteger::operator==(const BigInteger& val) const throw ()
221 {
222 return (sign == val.sign) && ((sign == 0) || mpeqx(size, data, val.size, val.data));
223 }
224
operator !=(const BigInteger & val) const225 bool BigInteger::operator!=(const BigInteger& val) const throw ()
226 {
227 return (sign != val.sign) || ((sign != 0) && mpnex(size, data, val.size, val.data));
228 }
229
valueOf(jlong val)230 BigInteger BigInteger::valueOf(jlong val)
231 {
232 return BigInteger(val);
233 }
234
hashCode() const235 jint BigInteger::hashCode() const throw ()
236 {
237 return 0;
238 }
239
byteValue() const240 jbyte BigInteger::byteValue() const throw ()
241 {
242 return (jbyte) (sign * data[size-1]);
243 }
244
shortValue() const245 jshort BigInteger::shortValue() const throw ()
246 {
247 return (jshort) (sign * data[size-1]);
248 }
249
intValue() const250 jint BigInteger::intValue() const throw ()
251 {
252 return (jint) (sign * data[size-1]);
253 }
254
longValue() const255 jlong BigInteger::longValue() const throw ()
256 {
257 #if MP_WBITS == 64
258 return (jlong) (sign * data[size-1]);
259 #else
260 if (size == 1)
261 return (jlong) (sign * data[size-1]);
262 else
263 return (jlong) (sign * ((data[size-2] << 32) + data[size-1]));
264 #endif
265 }
266
compareTo(const BigInteger & val) const267 jint BigInteger::compareTo(const BigInteger& val) const throw ()
268 {
269 if (sign == val.sign)
270 {
271 if (sign == 0)
272 return 0;
273
274 return sign * mpcmpx(size, data, val.size, val.data);
275 }
276 else
277 return (sign > val.sign) ? 1 : -1;
278 }
279
equals(const Object * obj) const280 bool BigInteger::equals(const Object* obj) const throw ()
281 {
282 if (this == obj)
283 return true;
284
285 const BigInteger* cmp = dynamic_cast<const BigInteger*>(obj);
286 if (cmp)
287 {
288 if (sign != cmp->sign)
289 return false;
290
291 if (sign && mpnex(size, data, cmp->size, cmp->data))
292 return false;
293
294 return true;
295 }
296
297 return false;
298 }
299
toString() const300 String BigInteger::toString() const throw ()
301 {
302 return toString(10);
303 }
304
toString(int radix) const305 String BigInteger::toString(int radix) const
306 {
307 if (radix < Character::MIN_RADIX || radix > Character::MAX_RADIX)
308 radix = 10;
309
310 if (sign == 0)
311 return STRZERO;
312
313 StringBuilder tmp;
314
315 /* allocate enough space to hold a copy of this (size+1), result (size+2) and workspace (2) */
316 mpw* rdata = (mpw*) malloc((2*size+5) * sizeof(mpw));
317 if (rdata)
318 {
319 mpw* result = rdata+size+1;
320 mpw* wksp = result+size+2;
321 mpw nradix = radix;
322
323 mpsetx(size+1, rdata, size, data);
324
325 size_t shift = mpnorm(1, &nradix);
326 mplshift(size+1, rdata, shift);
327
328 do
329 {
330 int remainder;
331
332 mpndivmod(result, size+1, rdata, 1, &nradix, wksp);
333 remainder = (result[size+1] >> shift);
334 mpcopy(size+1, rdata, result);
335 mplshift(size+1, rdata, shift);
336 if (remainder < 10)
337 tmp.append((jchar)(48 + remainder));
338 else
339 tmp.append((jchar)(55 + remainder));
340 } while (mpnz(size+1, rdata));
341
342 free(rdata);
343
344 if (sign < 0)
345 tmp.append('-');
346
347 return tmp.reverse().toString();
348 }
349 else
350 throw OutOfMemoryError();
351 }
352
signum() const353 jint BigInteger::signum() const throw ()
354 {
355 return sign;
356 }
357
/**
 * Returns this + val.
 *
 * @throws OutOfMemoryError if the result buffer cannot be allocated
 */
BigInteger BigInteger::add(const BigInteger& val) const
{
	// x + 0 == x, 0 + y == y
	if (val.sign == 0)
		return *this;
	if (sign == 0)
		return val;

	if (sign != val.sign)
	{
		// opposite signs: subtract the smaller magnitude from the larger one;
		// the result takes the sign of the operand with the larger magnitude
		const int order = mpcmpx(size, data, val.size, val.data);
		if (order == 0)
			return ZERO;

		const BigInteger& larger = (order < 0) ? val : *this;
		const BigInteger& smaller = (order < 0) ? *this : val;

		size_t n = larger.size;
		mpw* words = (mpw*) malloc(n * sizeof(mpw));
		if (words == 0)
			throw OutOfMemoryError();

		mpcopy(n, words, larger.data);
		mpsubx(n, words, smaller.size, smaller.data);

		// strip zero most-significant words
		size_t lead = 0;
		while ((lead < n) && (words[lead] == 0))
			++lead;

		if (lead == n)
		{
			free(words);
			return ZERO;
		}

		if (lead != 0)
		{
			n -= lead;
			mpmove(n, words, words + lead);
		}

		return BigInteger(n, words, order * sign);
	}

	// same signs: add magnitudes; one spare word absorbs a final carry
	size_t n = (size > val.size) ? size : val.size;
	mpw* words = (mpw*) malloc((n + 1) * sizeof(mpw));
	if (words == 0)
		throw OutOfMemoryError();

	mpsetx(n, words, size, data);
	if (mpaddx(n, words, val.size, val.data))
	{
		// carry out of the top word: shift everything down and prepend a 1
		mpmove(n, words + 1, words);
		words[0] = 1;
		++n;
	}

	return BigInteger(n, words, sign);
}
429
/**
 * Returns this - val.
 *
 * @throws OutOfMemoryError if the result buffer cannot be allocated
 */
BigInteger BigInteger::subtract(const BigInteger& val) const
{
	// 0 - y == -y, x - 0 == x
	if (sign == 0)
		return val.negate();
	if (val.sign == 0)
		return *this;

	if (sign == val.sign)
	{
		// same signs: subtract the smaller magnitude from the larger one so
		// no borrow escapes; the result sign flips when val is larger
		const int order = mpcmpx(size, data, val.size, val.data);
		if (order == 0)
			return ZERO;

		const BigInteger& larger = (order < 0) ? val : *this;
		const BigInteger& smaller = (order < 0) ? *this : val;

		size_t n = larger.size;
		mpw* words = (mpw*) malloc(n * sizeof(mpw));
		if (words == 0)
			throw OutOfMemoryError();

		mpcopy(n, words, larger.data);
		mpsubx(n, words, smaller.size, smaller.data);

		// strip zero most-significant words
		size_t lead = 0;
		while ((lead < n) && (words[lead] == 0))
			++lead;

		if (lead != 0)
		{
			n -= lead;
			mpmove(n, words, words + lead);
		}

		return BigInteger(n, words, order * sign);
	}

	// opposite signs: magnitudes add up and the result keeps this sign;
	// one spare word absorbs a final carry
	size_t n = (size > val.size) ? size : val.size;
	mpw* words = (mpw*) malloc((n + 1) * sizeof(mpw));
	if (words == 0)
		throw OutOfMemoryError();

	mpsetx(n, words, size, data);
	if (mpaddx(n, words, val.size, val.data))
	{
		// carry out of the top word: shift everything down and prepend a 1
		mpmove(n, words + 1, words);
		words[0] = 1;
		++n;
	}

	return BigInteger(n, words, sign);
}
493
/**
 * Returns this * val.
 *
 * @throws OutOfMemoryError if the result buffer cannot be allocated
 */
BigInteger BigInteger::multiply(const BigInteger& val) const
{
	if (sign == 0 || val.sign == 0)
		return BigInteger();

	// the product of size- and val.size-word numbers fits in their sum
	size_t rsize = size + val.size;
	mpw* rdata = (mpw*) malloc(rsize * sizeof(mpw));
	if (rdata == 0)
		throw OutOfMemoryError();

	mpmul(rdata, size, data, val.size, val.data);

	// strip every zero most-significant word, not just one: operands that
	// carry leading zero words can leave several in the product
	size_t skip = 0;
	while ((skip < rsize) && (rdata[skip] == 0))
		skip++;

	if (skip == rsize)
	{
		free(rdata);
		return BigInteger();
	}

	if (skip)
	{
		rsize -= skip;
		mpmove(rsize, rdata, rdata + skip);
	}

	return BigInteger(rsize, rdata, sign * val.sign);
}
513
#if 0
// Disabled code. mod() has been corrected so it is safe to enable once a
// matching declaration exists; modPow() is still unfinished.

/**
 * Returns this mod m, always in the range [0, m).
 *
 * @throws ArithmeticException if m <= 0
 * @throws OutOfMemoryError if the scratch buffer cannot be allocated
 */
BigInteger BigInteger::mod(const BigInteger& m) const throw (ArithmeticException)
{
	if (m.compareTo(ZERO) <= 0)
		throw ArithmeticException("m must be > 0");

	if (sign == 0)
		return ZERO;

	if (mpltx(size, data, m.size, m.data))
	{
		// |this| < m: a negative value maps to m + this (== m - |this|)
		if (sign == -1)
			return m.add(*this);
		else
			return *this;
	}
	else
	{
		// size words for the remainder plus 2*m.size+1 words of workspace;
		// sizes are in words, so scale by sizeof(mpw)
		mpw* tmp = (mpw*) malloc((size + 2*m.size + 1) * sizeof(mpw));
		if (tmp == 0)
			throw OutOfMemoryError();

		mpmod(tmp, size, data, m.size, m.data, tmp+size);

		// a zero remainder is zero for either sign (not m)
		if (mpz(size, tmp))
		{
			free(tmp);
			return ZERO;
		}

		if (sign == -1)
		{
			// non-zero remainder r of |this|: the result is m - r
			mpneg(size, tmp);
			mpaddx(size, tmp, m.size, m.data);
		}

		size_t skip = 0;

		while ((skip < size) && (tmp[skip] == 0))
			skip++;

		if (skip)
			mpmove(size - skip, tmp, tmp + skip);

		return BigInteger(size - skip, tmp, 1);
	}
}

BigInteger BigInteger::modPow(const BigInteger& exponent, const BigInteger& m) const
{
	// if the modulus is not positive, bail out
	if (m.sign <= 0)
		throw ArithmeticException("modulus must be > 0");

	// this ^ 0 mod m is 1 except when m == 1, when it is 0
	// 1 ^ exponent mod m is 1 except when m == 1, when it is 0
	if (exponent.sign == 0 || equals(&ONE))
	{
		if (m.equals(&ONE))
			return ZERO;
		else
			return ONE;
	}

	// 0 ^ exponent mod m is 0 when exponent != 0 (which was already excluded earlier)
	if (equals(&ZERO))
		return ZERO;

	// we need to bring this into the range 0 .. (m-1)

	// if the exponent is negative, compute the modular inverse of this before

	// TODO: unfinished - control currently falls off the end without returning a value
}
#endif
583
negate() const584 BigInteger BigInteger::negate() const
585 {
586 if (mpz(size, data))
587 return BigInteger();
588
589 mpw* rdata = (mpw*) malloc(size * sizeof(mpw));
590 if (rdata == 0)
591 throw OutOfMemoryError();
592
593 mpcopy(size, rdata, data);
594
595 return BigInteger(size, rdata, -sign);
596 }
597
toByteArray() const598 bytearray* BigInteger::toByteArray() const
599 {
600 bytearray* result = new bytearray();
601
602 toByteArray(*result);
603
604 return result;
605 }
606
toByteArray(bytearray & b) const607 void BigInteger::toByteArray(bytearray& b) const
608 {
609 if (sign == 0)
610 {
611 b.resize(1);
612 b[0] = 0;
613 }
614 else if (sign == 1)
615 {
616 size_t sigbits = mpbits(size, data);
617 size_t req = (sigbits+8) >> 3;
618
619 b.resize(req);
620
621 i2osp(b.data(), req, data, size);
622 }
623 else
624 {
625 size_t sigbits = mpbits(size, data);
626 size_t req = (sigbits+7) >> 3;
627
628 b.resize(req);
629
630 mpw* tmp = (mpw*) malloc(size * sizeof(mpw));
631 if (tmp == 0)
632 throw OutOfMemoryError();
633
634 mpcopy(size, tmp, data);
635 mpneg(size, tmp);
636
637 i2osp(b.data(), req, tmp, size);
638
639 free(tmp);
640 }
641 }
642
#if 0
/**
 * Disabled: copies a non-negative mpnumber into an existing BigInteger,
 * dropping zero most-significant words as BigInteger(const mpnumber&) does.
 *
 * If reallocation fails, b is left unchanged.
 */
void beecrypt::math::transform(BigInteger& b, const mpnumber& n)
{
	if (mpz(n.size, n.data))
	{
		if (b.sign)
		{
			free(b.data);
			b.size = 0;
			b.data = 0;
			b.sign = 0;
		}
	}
	else
	{
		size_t sigbits = mpbits(n.size, n.data);
		size_t words = MP_BITS_TO_WORDS(sigbits + MP_WBITS - 1);

		if (b.size != words)
		{
			// size the buffer for the new word count (not the old b.size) and
			// realloc via a temporary so the old buffer isn't leaked on failure
			mpw* tmp = (mpw*) realloc(b.data, words * sizeof(mpw));
			if (tmp == 0)
				throw OutOfMemoryError();
			b.data = tmp;
			b.size = words;
		}
		mpcopy(words, b.data, n.data + n.size - words);
		b.sign = 1;
	}
}
#endif
669
transform(mpnumber & n,const BigInteger & val)670 void beecrypt::math::transform(mpnumber& n, const BigInteger& val)
671 {
672 switch (val.sign)
673 {
674 case 0:
675 mpnfree(&n);
676 break;
677 case 1:
678 mpnset(&n, val.size, val.data);
679 break;
680 default:
681 throw IllegalArgumentException("can only transform non-negative numbers");
682 }
683 }
684
transform(mpbarrett & b,const BigInteger & val)685 void beecrypt::math::transform(mpbarrett& b, const BigInteger& val)
686 {
687 switch (val.sign)
688 {
689 case 0:
690 mpbfree(&b);
691 break;
692 case 1:
693 mpbset(&b, val.size, val.data);
694 break;
695 default:
696 throw IllegalArgumentException("can only transform non-negative numbers");
697 }
698 }
699