1 /*
2   ==============================================================================
3 
4    This file is part of the JUCE library.
5    Copyright (c) 2017 - ROLI Ltd.
6 
7    JUCE is an open source library subject to commercial or open-source
8    licensing.
9 
10    The code included in this file is provided under the terms of the ISC license
11    http://www.isc.org/downloads/software-support-policy/isc-license. Permission
12    To use, copy, modify, and/or distribute this software for any purpose with or
13    without fee is hereby granted provided that the above copyright notice and
14    this permission notice appear in all copies.
15 
16    JUCE IS PROVIDED "AS IS" WITHOUT ANY WARRANTY, AND ALL WARRANTIES, WHETHER
17    EXPRESSED OR IMPLIED, INCLUDING MERCHANTABILITY AND FITNESS FOR PURPOSE, ARE
18    DISCLAIMED.
19 
20   ==============================================================================
21 */
22 
23 namespace juce
24 {
25 
26 namespace
27 {
bitToMask(const int bit)28     inline uint32 bitToMask  (const int bit) noexcept           { return (uint32) 1 << (bit & 31); }
bitToIndex(const int bit)29     inline size_t bitToIndex (const int bit) noexcept           { return (size_t) (bit >> 5); }
sizeNeededToHold(int highestBit)30     inline size_t sizeNeededToHold (int highestBit) noexcept    { return (size_t) (highestBit >> 5) + 1; }
31 }
32 
findHighestSetBit(uint32 n)33 int findHighestSetBit (uint32 n) noexcept
34 {
35     jassert (n != 0); // (the built-in functions may not work for n = 0)
36 
37   #if JUCE_GCC || JUCE_CLANG
38     return 31 - __builtin_clz (n);
39   #elif JUCE_MSVC
40     unsigned long highest;
41     _BitScanReverse (&highest, n);
42     return (int) highest;
43   #else
44     n |= (n >> 1);
45     n |= (n >> 2);
46     n |= (n >> 4);
47     n |= (n >> 8);
48     n |= (n >> 16);
49     return countNumberOfBits (n >> 1);
50   #endif
51 }
52 
53 //==============================================================================
BigInteger()54 BigInteger::BigInteger()
55     : allocatedSize (numPreallocatedInts)
56 {
57     for (int i = 0; i < numPreallocatedInts; ++i)
58         preallocated[i] = 0;
59 }
60 
BigInteger(const int32 value)61 BigInteger::BigInteger (const int32 value)
62     : allocatedSize (numPreallocatedInts),
63       highestBit (31),
64       negative (value < 0)
65 {
66     preallocated[0] = (uint32) std::abs (value);
67 
68     for (int i = 1; i < numPreallocatedInts; ++i)
69         preallocated[i] = 0;
70 
71     highestBit = getHighestBit();
72 }
73 
BigInteger(const uint32 value)74 BigInteger::BigInteger (const uint32 value)
75     : allocatedSize (numPreallocatedInts),
76       highestBit (31)
77 {
78     preallocated[0] = value;
79 
80     for (int i = 1; i < numPreallocatedInts; ++i)
81         preallocated[i] = 0;
82 
83     highestBit = getHighestBit();
84 }
85 
BigInteger(int64 value)86 BigInteger::BigInteger (int64 value)
87     : allocatedSize (numPreallocatedInts),
88       highestBit (63),
89       negative (value < 0)
90 {
91     if (value < 0)
92         value = -value;
93 
94     preallocated[0] = (uint32) value;
95     preallocated[1] = (uint32) (value >> 32);
96 
97     for (int i = 2; i < numPreallocatedInts; ++i)
98         preallocated[i] = 0;
99 
100     highestBit = getHighestBit();
101 }
102 
// Deep copy. Duplicates other's value, sign and storage size, allocating a heap
// block when other's storage is larger than the inline words.
BigInteger::BigInteger (const BigInteger& other)
    : allocatedSize (other.allocatedSize),
      highestBit (other.getHighestBit()),
      negative (other.negative)
{
    const auto numWords = allocatedSize;

    if (numWords > numPreallocatedInts)
        heapAllocation.malloc (numWords);

    auto* destination = getValues();
    memcpy (destination, other.getValues(), numWords * sizeof (uint32));
}
113 
BigInteger(BigInteger && other)114 BigInteger::BigInteger (BigInteger&& other) noexcept
115     : heapAllocation (std::move (other.heapAllocation)),
116       allocatedSize (other.allocatedSize),
117       highestBit (other.highestBit),
118       negative (other.negative)
119 {
120     memcpy (preallocated, other.preallocated, sizeof (preallocated));
121 }
122 
operator =(BigInteger && other)123 BigInteger& BigInteger::operator= (BigInteger&& other) noexcept
124 {
125     heapAllocation = std::move (other.heapAllocation);
126     memcpy (preallocated, other.preallocated, sizeof (preallocated));
127     allocatedSize = other.allocatedSize;
128     highestBit = other.highestBit;
129     negative = other.negative;
130     return *this;
131 }
132 
~BigInteger()133 BigInteger::~BigInteger()
134 {
135 }
136 
swapWith(BigInteger & other)137 void BigInteger::swapWith (BigInteger& other) noexcept
138 {
139     for (int i = 0; i < numPreallocatedInts; ++i)
140         std::swap (preallocated[i], other.preallocated[i]);
141 
142     heapAllocation.swapWith (other.heapAllocation);
143     std::swap (allocatedSize, other.allocatedSize);
144     std::swap (highestBit, other.highestBit);
145     std::swap (negative, other.negative);
146 }
147 
// Copy assignment. Sizes this object's storage to fit other's highest bit
// (at least the inline word count), then copies the words and the sign.
BigInteger& BigInteger::operator= (const BigInteger& other)
{
    if (this != &other)
    {
        highestBit = other.getHighestBit();
        auto newAllocatedSize = (size_t) jmax ((size_t) numPreallocatedInts, sizeNeededToHold (highestBit));

        if (newAllocatedSize <= numPreallocatedInts)
            heapAllocation.free();
        else if (newAllocatedSize != allocatedSize || heapAllocation == nullptr)
            // Allocate when the size changes, and also when there is no heap
            // block yet. Otherwise getValues() would return the inline array
            // and the memcpy below would write newAllocatedSize words past its end.
            heapAllocation.malloc (newAllocatedSize);

        allocatedSize = newAllocatedSize;

        memcpy (getValues(), other.getValues(), sizeof (uint32) * allocatedSize);
        negative = other.negative;
    }

    return *this;
}
168 
getValues() const169 uint32* BigInteger::getValues() const noexcept
170 {
171     jassert (heapAllocation != nullptr || allocatedSize <= numPreallocatedInts);
172 
173     return heapAllocation != nullptr ? heapAllocation
174                                      : const_cast<uint32*> (preallocated);
175 }
176 
ensureSize(const size_t numVals)177 uint32* BigInteger::ensureSize (const size_t numVals)
178 {
179     if (numVals > allocatedSize)
180     {
181         auto oldSize = allocatedSize;
182         allocatedSize = ((numVals + 2) * 3) / 2;
183 
184         if (heapAllocation == nullptr)
185         {
186             heapAllocation.calloc (allocatedSize);
187             memcpy (heapAllocation, preallocated, sizeof (uint32) * numPreallocatedInts);
188         }
189         else
190         {
191             heapAllocation.realloc (allocatedSize);
192 
193             for (auto* values = getValues(); oldSize < allocatedSize; ++oldSize)
194                 values[oldSize] = 0;
195         }
196     }
197 
198     return getValues();
199 }
200 
201 //==============================================================================
operator [](const int bit) const202 bool BigInteger::operator[] (const int bit) const noexcept
203 {
204     return bit <= highestBit && bit >= 0
205              && ((getValues() [bitToIndex (bit)] & bitToMask (bit)) != 0);
206 }
207 
toInteger() const208 int BigInteger::toInteger() const noexcept
209 {
210     auto n = (int) (getValues()[0] & 0x7fffffff);
211     return negative ? -n : n;
212 }
213 
toInt64() const214 int64 BigInteger::toInt64() const noexcept
215 {
216     auto* values = getValues();
217     auto n = (((int64) (values[1] & 0x7fffffff)) << 32) | values[0];
218     return negative ? -n : n;
219 }
220 
getBitRange(int startBit,int numBits) const221 BigInteger BigInteger::getBitRange (int startBit, int numBits) const
222 {
223     BigInteger r;
224     numBits = jmax (0, jmin (numBits, getHighestBit() + 1 - startBit));
225     auto* destValues = r.ensureSize (sizeNeededToHold (numBits));
226     r.highestBit = numBits;
227 
228     for (int i = 0; numBits > 0;)
229     {
230         destValues[i++] = getBitRangeAsInt (startBit, (int) jmin (32, numBits));
231         numBits -= 32;
232         startBit += 32;
233     }
234 
235     r.highestBit = r.getHighestBit();
236     return r;
237 }
238 
// Returns up to 32 bits of this value's magnitude starting at startBit, packed
// into the low bits of the result. Positions above highestBit read as zero.
// NOTE(review): startBit is assumed non-negative. A negative value is turned
// into a huge word index by bitToIndex(); confirm that callers never pass one.
uint32 BigInteger::getBitRangeAsInt (const int startBit, int numBits) const noexcept
{
    if (numBits > 32)
    {
        jassertfalse;  // use getBitRange() if you need more than 32 bits..
        numBits = 32;
    }

    // Never read beyond the highest set bit.
    numBits = jmin (numBits, highestBit + 1 - startBit);

    if (numBits <= 0)
        return 0;

    auto pos = bitToIndex (startBit);      // word containing startBit
    auto offset = startBit & 31;           // bit position of startBit within that word
    auto endSpace = 32 - numBits;          // unused high bits of the result
    auto* values = getValues();

    auto n = ((uint32) values [pos]) >> offset;

    // If the range crosses into the next word, bring in its low bits.
    // (offset > endSpace implies offset > 0, so the shift below is < 32.)
    if (offset > endSpace)
        n |= ((uint32) values [pos + 1]) << (32 - offset);

    // Mask off everything above the requested bits.
    return n & (((uint32) 0xffffffff) >> endSpace);
}
264 
setBitRangeAsInt(const int startBit,int numBits,uint32 valueToSet)265 void BigInteger::setBitRangeAsInt (const int startBit, int numBits, uint32 valueToSet)
266 {
267     if (numBits > 32)
268     {
269         jassertfalse;
270         numBits = 32;
271     }
272 
273     for (int i = 0; i < numBits; ++i)
274     {
275         setBit (startBit + i, (valueToSet & 1) != 0);
276         valueToSet >>= 1;
277     }
278 }
279 
280 //==============================================================================
clear()281 void BigInteger::clear() noexcept
282 {
283     heapAllocation.free();
284     allocatedSize = numPreallocatedInts;
285     highestBit = -1;
286     negative = false;
287 
288     for (int i = 0; i < numPreallocatedInts; ++i)
289         preallocated[i] = 0;
290 }
291 
setBit(const int bit)292 void BigInteger::setBit (const int bit)
293 {
294     if (bit >= 0)
295     {
296         if (bit > highestBit)
297         {
298             ensureSize (sizeNeededToHold (bit));
299             highestBit = bit;
300         }
301 
302         getValues() [bitToIndex (bit)] |= bitToMask (bit);
303     }
304 }
305 
setBit(const int bit,const bool shouldBeSet)306 void BigInteger::setBit (const int bit, const bool shouldBeSet)
307 {
308     if (shouldBeSet)
309         setBit (bit);
310     else
311         clearBit (bit);
312 }
313 
clearBit(const int bit)314 void BigInteger::clearBit (const int bit) noexcept
315 {
316     if (bit >= 0 && bit <= highestBit)
317     {
318         getValues() [bitToIndex (bit)] &= ~bitToMask (bit);
319 
320         if (bit == highestBit)
321             highestBit = getHighestBit();
322     }
323 }
324 
setRange(int startBit,int numBits,const bool shouldBeSet)325 void BigInteger::setRange (int startBit, int numBits, const bool shouldBeSet)
326 {
327     while (--numBits >= 0)
328         setBit (startBit++, shouldBeSet);
329 }
330 
insertBit(const int bit,const bool shouldBeSet)331 void BigInteger::insertBit (const int bit, const bool shouldBeSet)
332 {
333     if (bit >= 0)
334         shiftBits (1, bit);
335 
336     setBit (bit, shouldBeSet);
337 }
338 
339 //==============================================================================
isZero() const340 bool BigInteger::isZero() const noexcept
341 {
342     return getHighestBit() < 0;
343 }
344 
isOne() const345 bool BigInteger::isOne() const noexcept
346 {
347     return getHighestBit() == 0 && ! negative;
348 }
349 
isNegative() const350 bool BigInteger::isNegative() const noexcept
351 {
352     return negative && ! isZero();
353 }
354 
setNegative(const bool neg)355 void BigInteger::setNegative (const bool neg) noexcept
356 {
357     negative = neg;
358 }
359 
negate()360 void BigInteger::negate() noexcept
361 {
362     negative = (! negative) && ! isZero();
363 }
364 
365 #if JUCE_MSVC && ! defined (__INTEL_COMPILER)
366  #pragma intrinsic (_BitScanReverse)
367 #endif
368 
countNumberOfSetBits() const369 int BigInteger::countNumberOfSetBits() const noexcept
370 {
371     int total = 0;
372     auto* values = getValues();
373 
374     for (int i = (int) sizeNeededToHold (highestBit); --i >= 0;)
375         total += countNumberOfBits (values[i]);
376 
377     return total;
378 }
379 
getHighestBit() const380 int BigInteger::getHighestBit() const noexcept
381 {
382     auto* values = getValues();
383 
384     for (int i = (int) bitToIndex (highestBit); i >= 0; --i)
385         if (uint32 n = values[i])
386             return findHighestSetBit (n) + (i << 5);
387 
388     return -1;
389 }
390 
findNextSetBit(int i) const391 int BigInteger::findNextSetBit (int i) const noexcept
392 {
393     auto* values = getValues();
394 
395     for (; i <= highestBit; ++i)
396         if ((values [bitToIndex (i)] & bitToMask (i)) != 0)
397             return i;
398 
399     return -1;
400 }
401 
findNextClearBit(int i) const402 int BigInteger::findNextClearBit (int i) const noexcept
403 {
404     auto* values = getValues();
405 
406     for (; i <= highestBit; ++i)
407         if ((values [bitToIndex (i)] & bitToMask (i)) == 0)
408             break;
409 
410     return i;
411 }
412 
413 //==============================================================================
// Adds other to this value in place. Sign combinations are reduced to magnitude
// operations: adding a negative value becomes a subtraction, and a negative *this
// is handled by negating around a subtraction. Only the final branch adds
// magnitudes word by word.
BigInteger& BigInteger::operator+= (const BigInteger& other)
{
    // Adding a value to itself: use a copy so that writes to *this cannot
    // affect the words being read from 'other'.
    if (this == &other)
        return operator+= (BigInteger (other));

    // a + (-b) == a - b
    if (other.isNegative())
        return operator-= (-other);

    if (isNegative())
    {
        // Here *this is negative and other is non-negative.
        if (compareAbsolute (other) < 0)
        {
            // |this| < other, so the result is other - |this| (non-negative).
            auto temp = *this;
            temp.negate();
            *this = other;
            *this -= temp;
        }
        else
        {
            // |this| >= other, so the result is -(|this| - other).
            negate();
            *this -= other;
            negate();
        }
    }
    else
    {
        // Both non-negative: add the magnitudes with a 64-bit carry accumulator.
        // highestBit is raised by one to leave room for a final carry bit.
        highestBit = jmax (highestBit, other.highestBit) + 1;

        auto numInts = sizeNeededToHold (highestBit);
        auto* values = ensureSize (numInts);
        auto* otherValues = other.getValues();
        int64 remainder = 0;

        for (size_t i = 0; i < numInts; ++i)
        {
            remainder += values[i];

            // Words beyond other's storage count as zero.
            if (i < other.allocatedSize)
                remainder += otherValues[i];

            values[i] = (uint32) remainder;
            remainder >>= 32;   // carry into the next word
        }

        // The extra top bit reserved above must have absorbed any carry.
        jassert (remainder == 0);
        highestBit = getHighestBit();
    }

    return *this;
}
464 
// Subtracts other from this value in place. Every sign combination is reduced to
// subtracting a smaller-or-equal magnitude from a larger one, which is then done
// word by word with borrowing.
BigInteger& BigInteger::operator-= (const BigInteger& other)
{
    // x - x == 0. This also avoids aliasing between *this and other.
    if (this == &other)
    {
        clear();
        return *this;
    }

    // a - (-b) == a + b
    if (other.isNegative())
        return operator+= (-other);

    // (-a) - b == -(a + b)
    if (isNegative())
    {
        negate();
        *this += other;
        negate();
        return *this;
    }

    // a - b with a < b: compute b - a and negate the result.
    if (compareAbsolute (other) < 0)
    {
        auto temp = other;
        swapWith (temp);
        *this -= temp;
        negate();
        return *this;
    }

    // Here both are non-negative and *this >= other.
    auto numInts = sizeNeededToHold (getHighestBit());
    auto maxOtherInts = sizeNeededToHold (other.getHighestBit());
    jassert (numInts >= maxOtherInts);
    auto* values = getValues();
    auto* otherValues = other.getValues();
    int64 amountToSubtract = 0;   // other's current word plus any borrow (0 or 1)

    for (size_t i = 0; i < numInts; ++i)
    {
        if (i < maxOtherInts)
            amountToSubtract += (int64) otherValues[i];

        if (values[i] >= amountToSubtract)
        {
            values[i] = (uint32) (values[i] - amountToSubtract);
            amountToSubtract = 0;
        }
        else
        {
            // Borrow 2^32 from the next word.
            const int64 n = ((int64) values[i] + (((int64) 1) << 32)) - amountToSubtract;
            values[i] = (uint32) n;
            amountToSubtract = 1;
        }
    }

    highestBit = getHighestBit();
    return *this;
}
521 
// Multiplies this value by other in place. Uses schoolbook (word-by-word) long
// multiplication into a fresh accumulator, then swaps the result in. The sign of
// the result is the XOR of the operands' signs.
BigInteger& BigInteger::operator*= (const BigInteger& other)
{
    // Squaring: multiply by a copy so that 'other' is unaffected by writes to *this.
    if (this == &other)
        return operator*= (BigInteger (other));

    auto n = getHighestBit();
    auto t = other.getHighestBit();

    // Work with the magnitude of *this; the sign is restored at the end.
    auto wasNegative = isNegative();
    setNegative (false);

    // The product's highest bit is at most n + t + 1. One extra word is
    // reserved for the final carry store below.
    BigInteger total;
    total.highestBit = n + t + 1;
    auto* totalValues = total.ensureSize (sizeNeededToHold (total.highestBit) + 1);

    // Convert highest-bit indices into highest-word indices.
    n >>= 5;
    t >>= 5;

    auto m = other;
    m.setNegative (false);

    auto* mValues = m.getValues();
    auto* values = getValues();

    for (int i = 0; i <= t; ++i)
    {
        uint32 c = 0;   // carry between partial products in this row

        for (int j = 0; j <= n; ++j)
        {
            // The 32x32-bit product plus an existing word plus the carry is at
            // most 2^64 - 1, so it fits in uint64.
            auto uv = (uint64) totalValues[i + j] + (uint64) values[j] * (uint64) mValues[i] + (uint64) c;
            totalValues[i + j] = (uint32) uv;
            c = uv >> 32;
        }

        totalValues[i + n + 1] = c;
    }

    total.highestBit = total.getHighestBit();
    total.setNegative (wasNegative ^ other.isNegative());
    swapWith (total);

    return *this;
}
566 
divideBy(const BigInteger & divisor,BigInteger & remainder)567 void BigInteger::divideBy (const BigInteger& divisor, BigInteger& remainder)
568 {
569     if (this == &divisor)
570         return divideBy (BigInteger (divisor), remainder);
571 
572     jassert (this != &remainder); // (can't handle passing itself in to get the remainder)
573 
574     auto divHB = divisor.getHighestBit();
575     auto ourHB = getHighestBit();
576 
577     if (divHB < 0 || ourHB < 0)
578     {
579         // division by zero
580         remainder.clear();
581         clear();
582     }
583     else
584     {
585         auto wasNegative = isNegative();
586 
587         swapWith (remainder);
588         remainder.setNegative (false);
589         clear();
590 
591         BigInteger temp (divisor);
592         temp.setNegative (false);
593 
594         auto leftShift = ourHB - divHB;
595         temp <<= leftShift;
596 
597         while (leftShift >= 0)
598         {
599             if (remainder.compareAbsolute (temp) >= 0)
600             {
601                 remainder -= temp;
602                 setBit (leftShift);
603             }
604 
605             if (--leftShift >= 0)
606                 temp >>= 1;
607         }
608 
609         negative = wasNegative ^ divisor.isNegative();
610         remainder.setNegative (wasNegative);
611     }
612 }
613 
// Replaces this value with the quotient this / other; the remainder is discarded.
BigInteger& BigInteger::operator/= (const BigInteger& other)
{
    BigInteger discardedRemainder;
    divideBy (other, discardedRemainder);
    return *this;
}
620 
// Bitwise OR of the magnitudes, in place. Signs are not taken into account.
BigInteger& BigInteger::operator|= (const BigInteger& other)
{
    // x | x == x
    if (this == &other)
        return *this;

    // this operation doesn't take into account negative values..
    jassert (isNegative() == other.isNegative());

    if (other.highestBit < 0)
        return *this;

    auto* ours = ensureSize (sizeNeededToHold (other.highestBit));
    auto* theirs = other.getValues();
    const auto numWords = bitToIndex (other.highestBit) + 1;

    for (size_t i = 0; i < numWords; ++i)
        ours[i] |= theirs[i];

    highestBit = jmax (highestBit, other.highestBit);
    highestBit = getHighestBit();
    return *this;
}
647 
// Bitwise AND of the magnitudes, in place. Signs are not taken into account.
BigInteger& BigInteger::operator&= (const BigInteger& other)
{
    // x & x == x
    if (this == &other)
        return *this;

    // this operation doesn't take into account negative values..
    jassert (isNegative() == other.isNegative());

    auto* ours = getValues();
    auto* theirs = other.getValues();
    const auto sharedWords = jmin (allocatedSize, other.allocatedSize);

    // Words that other does not store count as zero in the AND.
    for (auto i = allocatedSize; i > sharedWords; --i)
        ours[i - 1] = 0;

    for (size_t i = 0; i < sharedWords; ++i)
        ours[i] &= theirs[i];

    highestBit = jmin (highestBit, other.highestBit);
    highestBit = getHighestBit();
    return *this;
}
673 
// Bitwise XOR of the magnitudes, in place. Signs are not taken into account.
BigInteger& BigInteger::operator^= (const BigInteger& other)
{
    // x ^ x == 0
    if (this == &other)
    {
        clear();
        return *this;
    }

    // this operation will only work with the absolute values
    jassert (isNegative() == other.isNegative());

    if (other.highestBit < 0)
        return *this;

    auto* ours = ensureSize (sizeNeededToHold (other.highestBit));
    auto* theirs = other.getValues();
    const auto numWords = bitToIndex (other.highestBit) + 1;

    for (size_t i = 0; i < numWords; ++i)
        ours[i] ^= theirs[i];

    highestBit = jmax (highestBit, other.highestBit);
    highestBit = getHighestBit();
    return *this;
}
703 
// Replaces this value with the remainder of this / divisor. The remainder takes
// the dividend's sign (see divideBy).
BigInteger& BigInteger::operator%= (const BigInteger& divisor)
{
    BigInteger leftOver;
    divideBy (divisor, leftOver);
    swapWith (leftOver);   // keep the remainder and discard the quotient
    return *this;
}
711 
operator ++()712 BigInteger& BigInteger::operator++()      { return operator+= (1); }
operator --()713 BigInteger& BigInteger::operator--()      { return operator-= (1); }
operator ++(int)714 BigInteger  BigInteger::operator++ (int)  { const auto old (*this); operator+= (1); return old; }
operator --(int)715 BigInteger  BigInteger::operator-- (int)  { const auto old (*this); operator-= (1); return old; }
716 
// Value-returning operators: each copies *this and applies the matching in-place
// operation to the copy. The shift-assign operators forward to shiftBits, with a
// negative count meaning a right shift.
BigInteger  BigInteger::operator-() const                            { BigInteger result (*this); result.negate();        return result; }
BigInteger  BigInteger::operator+   (const BigInteger& other) const  { BigInteger result (*this); result += other;        return result; }
BigInteger  BigInteger::operator-   (const BigInteger& other) const  { BigInteger result (*this); result -= other;        return result; }
BigInteger  BigInteger::operator*   (const BigInteger& other) const  { BigInteger result (*this); result *= other;        return result; }
BigInteger  BigInteger::operator/   (const BigInteger& other) const  { BigInteger result (*this); result /= other;        return result; }
BigInteger  BigInteger::operator|   (const BigInteger& other) const  { BigInteger result (*this); result |= other;        return result; }
BigInteger  BigInteger::operator&   (const BigInteger& other) const  { BigInteger result (*this); result &= other;        return result; }
BigInteger  BigInteger::operator^   (const BigInteger& other) const  { BigInteger result (*this); result ^= other;        return result; }
BigInteger  BigInteger::operator%   (const BigInteger& other) const  { BigInteger result (*this); result %= other;        return result; }
BigInteger  BigInteger::operator<<  (const int numBits) const        { BigInteger result (*this); result <<= numBits;     return result; }
BigInteger  BigInteger::operator>>  (const int numBits) const        { BigInteger result (*this); result >>= numBits;     return result; }
BigInteger& BigInteger::operator<<= (const int numBits)              { shiftBits (numBits, 0);  return *this; }
BigInteger& BigInteger::operator>>= (const int numBits)              { shiftBits (-numBits, 0); return *this; }
730 
731 //==============================================================================
compare(const BigInteger & other) const732 int BigInteger::compare (const BigInteger& other) const noexcept
733 {
734     auto isNeg = isNegative();
735 
736     if (isNeg == other.isNegative())
737     {
738         auto absComp = compareAbsolute (other);
739         return isNeg ? -absComp : absComp;
740     }
741 
742     return isNeg ? -1 : 1;
743 }
744 
compareAbsolute(const BigInteger & other) const745 int BigInteger::compareAbsolute (const BigInteger& other) const noexcept
746 {
747     auto h1 = getHighestBit();
748     auto h2 = other.getHighestBit();
749 
750     if (h1 > h2) return 1;
751     if (h1 < h2) return -1;
752 
753     auto* values = getValues();
754     auto* otherValues = other.getValues();
755 
756     for (int i = (int) bitToIndex (h1); i >= 0; --i)
757         if (values[i] != otherValues[i])
758             return values[i] > otherValues[i] ? 1 : -1;
759 
760     return 0;
761 }
762 
operator ==(const BigInteger & other) const763 bool BigInteger::operator== (const BigInteger& other) const noexcept    { return compare (other) == 0; }
operator !=(const BigInteger & other) const764 bool BigInteger::operator!= (const BigInteger& other) const noexcept    { return compare (other) != 0; }
operator <(const BigInteger & other) const765 bool BigInteger::operator<  (const BigInteger& other) const noexcept    { return compare (other) <  0; }
operator <=(const BigInteger & other) const766 bool BigInteger::operator<= (const BigInteger& other) const noexcept    { return compare (other) <= 0; }
operator >(const BigInteger & other) const767 bool BigInteger::operator>  (const BigInteger& other) const noexcept    { return compare (other) >  0; }
operator >=(const BigInteger & other) const768 bool BigInteger::operator>= (const BigInteger& other) const noexcept    { return compare (other) >= 0; }
769 
770 //==============================================================================
// Moves the bits at positions >= startBit up by 'bits' places and fills the
// vacated positions [startBit, startBit + bits) with zeros. startBit <= 0 shifts
// the whole value. Callers in this file pass a positive 'bits'.
void BigInteger::shiftLeft (int bits, const int startBit)
{
    if (startBit > 0)
    {
        // Partial shift, one bit at a time. Going top-down means each source bit
        // is read before anything is written over it.
        for (int i = highestBit; i >= startBit; --i)
            setBit (i + bits, (*this) [i]);

        // Zero the gap opened up above startBit.
        while (--bits >= 0)
            clearBit (bits + startBit);
    }
    else
    {
        // Whole-value shift: move whole words first, then shift within words.
        auto* values = ensureSize (sizeNeededToHold (highestBit + bits));
        auto wordsToMove = bitToIndex (bits);
        auto numOriginalInts = bitToIndex (highestBit);
        highestBit += bits;

        if (wordsToMove > 0)
        {
            // Copy words upwards, starting at the top so sources are not overwritten.
            for (int i = (int) numOriginalInts; i >= 0; --i)
                values[(size_t) i + wordsToMove] = values[i];

            for (size_t j = 0; j < wordsToMove; ++j)
                values[j] = 0;

            bits &= 31;   // remaining shift within a word
        }

        if (bits != 0)
        {
            auto invBits = 32 - bits;

            // Each word keeps its own bits shifted up, plus the top bits carried
            // in from the word below.
            for (size_t i = bitToIndex (highestBit); i > wordsToMove; --i)
                values[i] = (values[i] << bits) | (values[i - 1] >> invBits);

            values[wordsToMove] = values[wordsToMove] << bits;
        }

        highestBit = getHighestBit();
    }
}
812 
// Moves the bits above startBit down by 'bits' places. The bits in
// [startBit, startBit + bits) are discarded, bits below startBit are left alone,
// and the vacated high bits become zero. startBit <= 0 shifts the whole value.
void BigInteger::shiftRight (int bits, const int startBit)
{
    if (startBit > 0)
    {
        // Partial shift, one bit at a time, bottom-up. Positions read beyond
        // highestBit come back as clear, which zeroes the vacated top bits.
        for (int i = startBit; i <= highestBit; ++i)
            setBit (i, (*this) [i + bits]);

        highestBit = getHighestBit();
    }
    else
    {
        if (bits > highestBit)
        {
            // Every set bit is shifted out.
            clear();
        }
        else
        {
            auto wordsToMove = bitToIndex (bits);
            auto top = 1 + bitToIndex (highestBit) - wordsToMove;   // number of words left after the word move
            highestBit -= bits;
            auto* values = getValues();

            if (wordsToMove > 0)
            {
                // Move whole words down and zero the words vacated at the top.
                for (size_t i = 0; i < top; ++i)
                    values[i] = values[i + wordsToMove];

                for (size_t i = 0; i < wordsToMove; ++i)
                    values[top + i] = 0;

                bits &= 31;   // remaining shift within a word
            }

            if (bits != 0)
            {
                auto invBits = 32 - bits;
                --top;   // now the index of the highest remaining word

                // Each word keeps its own bits shifted down, plus the low bits of the word above.
                for (size_t i = 0; i < top; ++i)
                    values[i] = (values[i] >> bits) | (values[i + 1] << invBits);

                values[top] = (values[top] >> bits);
            }

            highestBit = getHighestBit();
        }
    }
}
861 
shiftBits(int bits,const int startBit)862 void BigInteger::shiftBits (int bits, const int startBit)
863 {
864     if (highestBit >= 0)
865     {
866         if (bits < 0)
867             shiftRight (-bits, startBit);
868         else if (bits > 0)
869             shiftLeft (bits, startBit);
870     }
871 }
872 
873 //==============================================================================
simpleGCD(BigInteger * m,BigInteger * n)874 static BigInteger simpleGCD (BigInteger* m, BigInteger* n)
875 {
876     while (! m->isZero())
877     {
878         if (n->compareAbsolute (*m) > 0)
879             std::swap (m, n);
880 
881         *m -= *n;
882     }
883 
884     return *n;
885 }
886 
findGreatestCommonDivisor(BigInteger n) const887 BigInteger BigInteger::findGreatestCommonDivisor (BigInteger n) const
888 {
889     auto m = *this;
890 
891     while (! n.isZero())
892     {
893         if (std::abs (m.getHighestBit() - n.getHighestBit()) <= 16)
894             return simpleGCD (&m, &n);
895 
896         BigInteger temp2;
897         m.divideBy (n, temp2);
898 
899         m.swapWith (n);
900         n.swapWith (temp2);
901     }
902 
903     return m;
904 }
905 
exponentModulo(const BigInteger & exponent,const BigInteger & modulus)906 void BigInteger::exponentModulo (const BigInteger& exponent, const BigInteger& modulus)
907 {
908     *this %= modulus;
909     auto exp = exponent;
910     exp %= modulus;
911 
912     if (modulus.getHighestBit() <= 32 || modulus % 2 == 0)
913     {
914         auto a = *this;
915         auto n = exp.getHighestBit();
916 
917         for (int i = n; --i >= 0;)
918         {
919             *this *= *this;
920 
921             if (exp[i])
922                 *this *= a;
923 
924             if (compareAbsolute (modulus) >= 0)
925                 *this %= modulus;
926         }
927     }
928     else
929     {
930         auto Rfactor = modulus.getHighestBit() + 1;
931         BigInteger R (1);
932         R.shiftLeft (Rfactor, 0);
933 
934         BigInteger R1, m1, g;
935         g.extendedEuclidean (modulus, R, m1, R1);
936 
937         if (! g.isOne())
938         {
939             BigInteger a (*this);
940 
941             for (int i = exp.getHighestBit(); --i >= 0;)
942             {
943                 *this *= *this;
944 
945                 if (exp[i])
946                     *this *= a;
947 
948                 if (compareAbsolute (modulus) >= 0)
949                     *this %= modulus;
950             }
951         }
952         else
953         {
954             auto am  = (*this * R) % modulus;
955             auto xm = am;
956             auto um = R % modulus;
957 
958             for (int i = exp.getHighestBit(); --i >= 0;)
959             {
960                 xm.montgomeryMultiplication (xm, modulus, m1, Rfactor);
961 
962                 if (exp[i])
963                     xm.montgomeryMultiplication (am, modulus, m1, Rfactor);
964             }
965 
966             xm.montgomeryMultiplication (1, modulus, m1, Rfactor);
967             swapWith (xm);
968         }
969     }
970 }
971 
// One Montgomery multiplication step with R = 2^k. Presumably computes
// (*this * other * R^-1) mod modulus, with modulusp being the Montgomery
// constant for modulus and R (in this file it is obtained via
// extendedEuclidean in exponentModulo).
// NOTE(review): confirm the sign convention that modulusp is expected to have.
void BigInteger::montgomeryMultiplication (const BigInteger& other, const BigInteger& modulus,
                                           const BigInteger& modulusp, const int k)
{
    *this *= other;
    auto t = *this;                               // T = a * b

    setRange (k, highestBit - k + 1, false);      // keep the low k bits: T mod R
    *this *= modulusp;

    setRange (k, highestBit - k + 1, false);      // m = ((T mod R) * modulusp) mod R
    *this *= modulus;
    *this += t;                                   // T + m * modulus
    shiftRight (k, 0);                            // shift the magnitude down by k bits (divide by R)

    // A single conditional correction towards the range [0, modulus).
    if (compare (modulus) >= 0)
        *this -= modulus;
    else if (isNegative())
        *this += modulus;
}
991 
extendedEuclidean(const BigInteger & a,const BigInteger & b,BigInteger & x,BigInteger & y)992 void BigInteger::extendedEuclidean (const BigInteger& a, const BigInteger& b,
993                                     BigInteger& x, BigInteger& y)
994 {
995     BigInteger p(a), q(b), gcd(1);
996     Array<BigInteger> tempValues;
997 
998     while (! q.isZero())
999     {
1000         tempValues.add (p / q);
1001         gcd = q;
1002         q = p % q;
1003         p = gcd;
1004     }
1005 
1006     x.clear();
1007     y = 1;
1008 
1009     for (int i = 1; i < tempValues.size(); ++i)
1010     {
1011         auto& v = tempValues.getReference (tempValues.size() - i - 1);
1012 
1013         if ((i & 1) != 0)
1014             x += y * v;
1015         else
1016             y += x * v;
1017     }
1018 
1019     if (gcd.compareAbsolute (y * b - x * a) != 0)
1020     {
1021         x.negate();
1022         x.swapWith (y);
1023         x.negate();
1024     }
1025 
1026     swapWith (gcd);
1027 }
1028 
inverseModulo(const BigInteger & modulus)1029 void BigInteger::inverseModulo (const BigInteger& modulus)
1030 {
1031     if (modulus.isOne() || modulus.isNegative())
1032     {
1033         clear();
1034         return;
1035     }
1036 
1037     if (isNegative() || compareAbsolute (modulus) >= 0)
1038         *this %= modulus;
1039 
1040     if (isOne())
1041         return;
1042 
1043     if (findGreatestCommonDivisor (modulus) != 1)
1044     {
1045         clear();  // not invertible!
1046         return;
1047     }
1048 
1049     BigInteger a1 (modulus), a2 (*this),
1050                b1 (modulus), b2 (1);
1051 
1052     while (! a2.isOne())
1053     {
1054         BigInteger temp1, multiplier (a1);
1055         multiplier.divideBy (a2, temp1);
1056 
1057         temp1 = a2;
1058         temp1 *= multiplier;
1059         auto temp2 = a1;
1060         temp2 -= temp1;
1061         a1 = a2;
1062         a2 = temp2;
1063 
1064         temp1 = b2;
1065         temp1 *= multiplier;
1066         temp2 = b1;
1067         temp2 -= temp1;
1068         b1 = b2;
1069         b2 = temp2;
1070     }
1071 
1072     while (b2.isNegative())
1073         b2 += modulus;
1074 
1075     b2 %= modulus;
1076     swapWith (b2);
1077 }
1078 
1079 //==============================================================================
operator <<(OutputStream & stream,const BigInteger & value)1080 OutputStream& JUCE_CALLTYPE operator<< (OutputStream& stream, const BigInteger& value)
1081 {
1082     return stream << value.toString (10);
1083 }
1084 
toString(const int base,const int minimumNumCharacters) const1085 String BigInteger::toString (const int base, const int minimumNumCharacters) const
1086 {
1087     String s;
1088     auto v = *this;
1089 
1090     if (base == 2 || base == 8 || base == 16)
1091     {
1092         auto bits = (base == 2) ? 1 : (base == 8 ? 3 : 4);
1093         static const char hexDigits[] = "0123456789abcdef";
1094 
1095         for (;;)
1096         {
1097             auto remainder = v.getBitRangeAsInt (0, bits);
1098             v >>= bits;
1099 
1100             if (remainder == 0 && v.isZero())
1101                 break;
1102 
1103             s = String::charToString ((juce_wchar) (uint8) hexDigits [remainder]) + s;
1104         }
1105     }
1106     else if (base == 10)
1107     {
1108         const BigInteger ten (10);
1109         BigInteger remainder;
1110 
1111         for (;;)
1112         {
1113             v.divideBy (ten, remainder);
1114 
1115             if (remainder.isZero() && v.isZero())
1116                 break;
1117 
1118             s = String (remainder.getBitRangeAsInt (0, 8)) + s;
1119         }
1120     }
1121     else
1122     {
1123         jassertfalse; // can't do the specified base!
1124         return {};
1125     }
1126 
1127     s = s.paddedLeft ('0', minimumNumCharacters);
1128 
1129     return isNegative() ? "-" + s : s;
1130 }
1131 
parseString(StringRef text,const int base)1132 void BigInteger::parseString (StringRef text, const int base)
1133 {
1134     clear();
1135     auto t = text.text.findEndOfWhitespace();
1136 
1137     setNegative (*t == (juce_wchar) '-');
1138 
1139     if (base == 2 || base == 8 || base == 16)
1140     {
1141         auto bits = (base == 2) ? 1 : (base == 8 ? 3 : 4);
1142 
1143         for (;;)
1144         {
1145             auto c = t.getAndAdvance();
1146             auto digit = CharacterFunctions::getHexDigitValue (c);
1147 
1148             if (((uint32) digit) < (uint32) base)
1149             {
1150                 *this <<= bits;
1151                 *this += digit;
1152             }
1153             else if (c == 0)
1154             {
1155                 break;
1156             }
1157         }
1158     }
1159     else if (base == 10)
1160     {
1161         const BigInteger ten ((uint32) 10);
1162 
1163         for (;;)
1164         {
1165             auto c = t.getAndAdvance();
1166 
1167             if (c >= '0' && c <= '9')
1168             {
1169                 *this *= ten;
1170                 *this += (int) (c - '0');
1171             }
1172             else if (c == 0)
1173             {
1174                 break;
1175             }
1176         }
1177     }
1178 }
1179 
toMemoryBlock() const1180 MemoryBlock BigInteger::toMemoryBlock() const
1181 {
1182     auto numBytes = (getHighestBit() + 8) >> 3;
1183     MemoryBlock mb ((size_t) numBytes);
1184     auto* values = getValues();
1185 
1186     for (int i = 0; i < numBytes; ++i)
1187         mb[i] = (char) ((values[i / 4] >> ((i & 3) * 8)) & 0xff);
1188 
1189     return mb;
1190 }
1191 
loadFromMemoryBlock(const MemoryBlock & data)1192 void BigInteger::loadFromMemoryBlock (const MemoryBlock& data)
1193 {
1194     auto numBytes = data.getSize();
1195     auto numInts = 1 + (numBytes / sizeof (uint32));
1196     auto* values = ensureSize (numInts);
1197 
1198     for (int i = 0; i < (int) numInts - 1; ++i)
1199         values[i] = (uint32) ByteOrder::littleEndianInt (addBytesToPointer (data.getData(), (size_t) i * sizeof (uint32)));
1200 
1201     values[numInts - 1] = 0;
1202 
1203     for (int i = (int) (numBytes & ~3u); i < (int) numBytes; ++i)
1204         this->setBitRangeAsInt (i << 3, 8, (uint32) data [i]);
1205 
1206     highestBit = (int) numBytes * 8;
1207     highestBit = getHighestBit();
1208 }
1209 
1210 //==============================================================================
writeLittleEndianBitsInBuffer(void * buffer,uint32 startBit,uint32 numBits,uint32 value)1211 void writeLittleEndianBitsInBuffer (void* buffer, uint32 startBit, uint32 numBits, uint32 value) noexcept
1212 {
1213     jassert (buffer != nullptr);
1214     jassert (numBits > 0 && numBits <= 32);
1215     jassert (numBits == 32 || (value >> numBits) == 0);
1216 
1217     uint8* data = static_cast<uint8*> (buffer) + startBit / 8;
1218 
1219     if (const uint32 offset = (startBit & 7))
1220     {
1221         const uint32 bitsInByte = 8 - offset;
1222         const uint8 current = *data;
1223 
1224         if (bitsInByte >= numBits)
1225         {
1226             *data = (uint8) ((current & ~(((1u << numBits) - 1u) << offset)) | (value << offset));
1227             return;
1228         }
1229 
1230         *data++ = current ^ (uint8) (((value << offset) ^ current) & (((1u << bitsInByte) - 1u) << offset));
1231         numBits -= bitsInByte;
1232         value >>= bitsInByte;
1233     }
1234 
1235     while (numBits >= 8)
1236     {
1237         *data++ = (uint8) value;
1238         value >>= 8;
1239         numBits -= 8;
1240     }
1241 
1242     if (numBits > 0)
1243         *data = (uint8) ((*data & (uint32) (0xff << numBits)) | value);
1244 }
1245 
readLittleEndianBitsInBuffer(const void * buffer,uint32 startBit,uint32 numBits)1246 uint32 readLittleEndianBitsInBuffer (const void* buffer, uint32 startBit, uint32 numBits) noexcept
1247 {
1248     jassert (buffer != nullptr);
1249     jassert (numBits > 0 && numBits <= 32);
1250 
1251     uint32 result = 0;
1252     uint32 bitsRead = 0;
1253     const uint8* data = static_cast<const uint8*> (buffer) + startBit / 8;
1254 
1255     if (const uint32 offset = (startBit & 7))
1256     {
1257         const uint32 bitsInByte = 8 - offset;
1258         result = (uint32) (*data >> offset);
1259 
1260         if (bitsInByte >= numBits)
1261             return result & ((1u << numBits) - 1u);
1262 
1263         numBits -= bitsInByte;
1264         bitsRead += bitsInByte;
1265         ++data;
1266     }
1267 
1268     while (numBits >= 8)
1269     {
1270         result |= (((uint32) *data++) << bitsRead);
1271         bitsRead += 8;
1272         numBits -= 8;
1273     }
1274 
1275     if (numBits > 0)
1276         result |= ((*data & ((1u << numBits) - 1u)) << bitsRead);
1277 
1278     return result;
1279 }
1280 
1281 
1282 //==============================================================================
1283 //==============================================================================
1284 #if JUCE_UNIT_TESTS
1285 
1286 class BigIntegerTests  : public UnitTest
1287 {
1288 public:
BigIntegerTests()1289     BigIntegerTests()
1290         : UnitTest ("BigInteger", UnitTestCategories::maths)
1291     {}
1292 
getBigRandom(Random & r)1293     static BigInteger getBigRandom (Random& r)
1294     {
1295         BigInteger b;
1296 
1297         while (b < 2)
1298             r.fillBitsRandomly (b, 0, r.nextInt (150) + 1);
1299 
1300         return b;
1301     }
1302 
runTest()1303     void runTest() override
1304     {
1305         {
1306             beginTest ("BigInteger");
1307 
1308             Random r = getRandom();
1309 
1310             expect (BigInteger().isZero());
1311             expect (BigInteger(1).isOne());
1312 
1313             for (int j = 10000; --j >= 0;)
1314             {
1315                 BigInteger b1 (getBigRandom(r)),
1316                            b2 (getBigRandom(r));
1317 
1318                 BigInteger b3 = b1 + b2;
1319                 expect (b3 > b1 && b3 > b2);
1320                 expect (b3 - b1 == b2);
1321                 expect (b3 - b2 == b1);
1322 
1323                 BigInteger b4 = b1 * b2;
1324                 expect (b4 > b1 && b4 > b2);
1325                 expect (b4 / b1 == b2);
1326                 expect (b4 / b2 == b1);
1327                 expect (((b4 << 1) >> 1) == b4);
1328                 expect (((b4 << 10) >> 10) == b4);
1329                 expect (((b4 << 100) >> 100) == b4);
1330 
1331                 // TODO: should add tests for other ops (although they also get pretty well tested in the RSA unit test)
1332 
1333                 BigInteger b5;
1334                 b5.loadFromMemoryBlock (b3.toMemoryBlock());
1335                 expect (b3 == b5);
1336             }
1337         }
1338 
1339         {
1340             beginTest ("Bit setting");
1341 
1342             Random r = getRandom();
1343             static uint8 test[2048];
1344 
1345             for (int j = 100000; --j >= 0;)
1346             {
1347                 uint32 offset = static_cast<uint32> (r.nextInt (200) + 10);
1348                 uint32 num = static_cast<uint32> (r.nextInt (32) + 1);
1349                 uint32 value = static_cast<uint32> (r.nextInt());
1350 
1351                 if (num < 32)
1352                     value &= ((1u << num) - 1);
1353 
1354                 auto old1 = readLittleEndianBitsInBuffer (test, offset - 6, 6);
1355                 auto old2 = readLittleEndianBitsInBuffer (test, offset + num, 6);
1356                 writeLittleEndianBitsInBuffer (test, offset, num, value);
1357                 auto result = readLittleEndianBitsInBuffer (test, offset, num);
1358 
1359                 expect (result == value);
1360                 expect (old1 == readLittleEndianBitsInBuffer (test, offset - 6, 6));
1361                 expect (old2 == readLittleEndianBitsInBuffer (test, offset + num, 6));
1362             }
1363         }
1364     }
1365 };
1366 
1367 static BigIntegerTests bigIntegerTests;
1368 
1369 #endif
1370 
1371 } // namespace juce
1372