1 2 3 #include "BigInt.h" 4 #include <ctype.h> 5 #include <string.h> 6 7 #include "RakAlloca.h" 8 #include "RakMemoryOverride.h" 9 #include "Rand.h" 10 11 #if defined(_MSC_VER) && !defined(_DEBUG) && _MSC_VER > 1310 12 #include <intrin.h> 13 #endif 14 15 namespace big 16 { 17 static const char Bits256[] = { 18 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 19 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 20 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 21 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 22 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 23 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 24 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 25 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 26 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 27 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 28 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 29 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 30 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 31 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 32 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 33 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8 34 }; 35 36 // returns the degree of the base 2 monic polynomial 37 // (the number of bits used to represent the number) 38 // eg, 0 0 0 0 1 0 1 1 ... => 28 out of 32 used Degree(uint32_t v)39 uint32_t Degree(uint32_t v) 40 { 41 //#if defined(_MSC_VER) && !defined(_DEBUG) 42 // unsigned long index; 43 // return _BitScanReverse(&index, v) ? (index + 1) : 0; 44 //#else 45 uint32_t r, t = v >> 16; 46 47 if (t) r = (r = t >> 8) ? 24 + Bits256[r] : 16 + Bits256[t]; 48 else r = (r = v >> 8) ? 
8 + Bits256[r] : Bits256[v]; 49 50 return r; 51 //#endif 52 } 53 54 // returns the number of limbs that are actually used LimbDegree(const uint32_t * n,int limbs)55 int LimbDegree(const uint32_t *n, int limbs) 56 { 57 while (limbs--) 58 if (n[limbs]) 59 return limbs + 1; 60 61 return 0; 62 } 63 64 // return bits used Degree(const uint32_t * n,int limbs)65 uint32_t Degree(const uint32_t *n, int limbs) 66 { 67 uint32_t limb_degree = LimbDegree(n, limbs); 68 if (!limb_degree) return 0; 69 --limb_degree; 70 71 uint32_t msl_degree = Degree(n[limb_degree]); 72 73 return msl_degree + limb_degree*32; 74 } 75 Set(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)76 void Set(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 77 { 78 int min = lhs_limbs < rhs_limbs ? lhs_limbs : rhs_limbs; 79 80 memcpy(lhs, rhs, min*4); 81 memset(&lhs[min], 0, (lhs_limbs - min)*4); 82 } Set(uint32_t * lhs,int limbs,const uint32_t * rhs)83 void Set(uint32_t *lhs, int limbs, const uint32_t *rhs) 84 { 85 memcpy(lhs, rhs, limbs*4); 86 } Set32(uint32_t * lhs,int lhs_limbs,const uint32_t rhs)87 void Set32(uint32_t *lhs, int lhs_limbs, const uint32_t rhs) 88 { 89 *lhs = rhs; 90 memset(&lhs[1], 0, (lhs_limbs - 1)*4); 91 } 92 93 #if defined(__BIG_ENDIAN__) 94 95 // Flip the byte order as needed to make 'n' big-endian for sharing over a network ToLittleEndian(uint32_t * n,int limbs)96 void ToLittleEndian(uint32_t *n, int limbs) 97 { 98 for (int ii = 0; ii < limbs; ++ii) 99 { 100 swapLE(n[ii]); 101 } 102 } 103 104 // Flip the byte order as needed to make big-endian 'n' use the local byte order FromLittleEndian(uint32_t * n,int limbs)105 void FromLittleEndian(uint32_t *n, int limbs) 106 { 107 // Same operation as ToBigEndian() 108 ToLittleEndian(n, limbs); 109 } 110 111 #endif // __BIG_ENDIAN__ 112 Less(int limbs,const uint32_t * lhs,const uint32_t * rhs)113 bool Less(int limbs, const uint32_t *lhs, const uint32_t *rhs) 114 { 115 for (int ii = limbs-1; ii >= 0; --ii) 116 if 
(lhs[ii] != rhs[ii]) 117 return lhs[ii] < rhs[ii]; 118 119 return false; 120 } Greater(int limbs,const uint32_t * lhs,const uint32_t * rhs)121 bool Greater(int limbs, const uint32_t *lhs, const uint32_t *rhs) 122 { 123 for (int ii = limbs-1; ii >= 0; --ii) 124 if (lhs[ii] != rhs[ii]) 125 return lhs[ii] > rhs[ii]; 126 127 return false; 128 } Equal(int limbs,const uint32_t * lhs,const uint32_t * rhs)129 bool Equal(int limbs, const uint32_t *lhs, const uint32_t *rhs) 130 { 131 return 0 == memcmp(lhs, rhs, limbs*4); 132 } 133 Less(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)134 bool Less(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 135 { 136 if (lhs_limbs > rhs_limbs) 137 do if (lhs[--lhs_limbs] != 0) return false; while (lhs_limbs > rhs_limbs); 138 else if (lhs_limbs < rhs_limbs) 139 do if (rhs[--rhs_limbs] != 0) return true; while (lhs_limbs < rhs_limbs); 140 141 while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return lhs[lhs_limbs] < rhs[lhs_limbs]; 142 return false; // equal 143 } Greater(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)144 bool Greater(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 145 { 146 if (lhs_limbs > rhs_limbs) 147 do if (lhs[--lhs_limbs] != 0) return true; while (lhs_limbs > rhs_limbs); 148 else if (lhs_limbs < rhs_limbs) 149 do if (rhs[--rhs_limbs] != 0) return false; while (lhs_limbs < rhs_limbs); 150 151 while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return lhs[lhs_limbs] > rhs[lhs_limbs]; 152 return false; // equal 153 } Equal(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)154 bool Equal(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 155 { 156 if (lhs_limbs > rhs_limbs) 157 do if (lhs[--lhs_limbs] != 0) return false; while (lhs_limbs > rhs_limbs); 158 else if (lhs_limbs < rhs_limbs) 159 do if (rhs[--rhs_limbs] != 0) return false; while (lhs_limbs < rhs_limbs); 160 161 
while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return false; 162 return true; // equal 163 } 164 Greater32(const uint32_t * lhs,int lhs_limbs,uint32_t rhs)165 bool Greater32(const uint32_t *lhs, int lhs_limbs, uint32_t rhs) 166 { 167 if (*lhs > rhs) return true; 168 while (--lhs_limbs) 169 if (*++lhs) return true; 170 return false; 171 } Equal32(const uint32_t * lhs,int lhs_limbs,uint32_t rhs)172 bool Equal32(const uint32_t *lhs, int lhs_limbs, uint32_t rhs) 173 { 174 if (*lhs != rhs) return false; 175 while (--lhs_limbs) 176 if (*++lhs) return false; 177 return true; // equal 178 } 179 180 // out = in >>> shift 181 // Precondition: 0 <= shift < 31 ShiftRight(int limbs,uint32_t * out,const uint32_t * in,int shift)182 void ShiftRight(int limbs, uint32_t *out, const uint32_t *in, int shift) 183 { 184 if (!shift) 185 { 186 Set(out, limbs, in); 187 return; 188 } 189 190 uint32_t carry = 0; 191 192 for (int ii = limbs - 1; ii >= 0; --ii) 193 { 194 uint32_t r = in[ii]; 195 196 out[ii] = (r >> shift) | carry; 197 198 carry = r << (32 - shift); 199 } 200 } 201 202 // {out, carry} = in <<< shift 203 // Precondition: 0 <= shift < 31 ShiftLeft(int limbs,uint32_t * out,const uint32_t * in,int shift)204 uint32_t ShiftLeft(int limbs, uint32_t *out, const uint32_t *in, int shift) 205 { 206 if (!shift) 207 { 208 Set(out, limbs, in); 209 return 0; 210 } 211 212 uint32_t carry = 0; 213 214 for (int ii = 0; ii < limbs; ++ii) 215 { 216 uint32_t r = in[ii]; 217 218 out[ii] = (r << shift) | carry; 219 220 carry = r >> (32 - shift); 221 } 222 223 return carry; 224 } 225 226 // lhs += rhs, return carry out 227 // precondition: lhs_limbs >= rhs_limbs Add(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)228 uint32_t Add(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 229 { 230 int ii; 231 uint64_t r = (uint64_t)lhs[0] + rhs[0]; 232 lhs[0] = (uint32_t)r; 233 234 for (ii = 1; ii < rhs_limbs; ++ii) 235 { 236 r = ((uint64_t)lhs[ii] + rhs[ii]) + 
(uint32_t)(r >> 32); 237 lhs[ii] = (uint32_t)r; 238 } 239 240 for (; ii < lhs_limbs && (uint32_t)(r >>= 32) != 0; ++ii) 241 { 242 r += lhs[ii]; 243 lhs[ii] = (uint32_t)r; 244 } 245 246 return (uint32_t)(r >> 32); 247 } 248 249 // out = lhs + rhs, return carry out 250 // precondition: lhs_limbs >= rhs_limbs Add(uint32_t * out,const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)251 uint32_t Add(uint32_t *out, const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 252 { 253 int ii; 254 uint64_t r = (uint64_t)lhs[0] + rhs[0]; 255 out[0] = (uint32_t)r; 256 257 for (ii = 1; ii < rhs_limbs; ++ii) 258 { 259 r = ((uint64_t)lhs[ii] + rhs[ii]) + (uint32_t)(r >> 32); 260 out[ii] = (uint32_t)r; 261 } 262 263 for (; ii < lhs_limbs && (uint32_t)(r >>= 32) != 0; ++ii) 264 { 265 r += lhs[ii]; 266 out[ii] = (uint32_t)r; 267 } 268 269 return (uint32_t)(r >> 32); 270 } 271 272 // lhs += rhs, return carry out 273 // precondition: lhs_limbs > 0 Add32(uint32_t * lhs,int lhs_limbs,uint32_t rhs)274 uint32_t Add32(uint32_t *lhs, int lhs_limbs, uint32_t rhs) 275 { 276 uint32_t n = lhs[0]; 277 uint32_t r = n + rhs; 278 lhs[0] = r; 279 280 if (r >= n) 281 return 0; 282 283 for (int ii = 1; ii < lhs_limbs; ++ii) 284 if (++lhs[ii]) 285 return 0; 286 287 return 1; 288 } 289 290 // lhs -= rhs, return borrow out 291 // precondition: lhs_limbs >= rhs_limbs Subtract(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)292 int32_t Subtract(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs) 293 { 294 int ii; 295 int64_t r = (int64_t)lhs[0] - rhs[0]; 296 lhs[0] = (uint32_t)r; 297 298 for (ii = 1; ii < rhs_limbs; ++ii) 299 { 300 r = ((int64_t)lhs[ii] - rhs[ii]) + (int32_t)(r >> 32); 301 lhs[ii] = (uint32_t)r; 302 } 303 304 for (; ii < lhs_limbs && (int32_t)(r >>= 32) != 0; ++ii) 305 { 306 r += lhs[ii]; 307 lhs[ii] = (uint32_t)r; 308 } 309 310 return (int32_t)(r >> 32); 311 } 312 313 // out = lhs - rhs, return borrow out 314 // precondition: 
//   lhs_limbs >= rhs_limbs
//
// Three-operand subtraction: out = lhs - rhs.
// Returns 0 when no borrow leaves the top limb, -1 when one does.
// All lhs_limbs limbs of 'out' are written.  (The previous version stopped
// writing as soon as the borrow died out, leaving the high limbs of 'out'
// uninitialized whenever lhs_limbs > rhs_limbs.)
// 'out' may be the same buffer as 'lhs'.
int32_t Subtract(uint32_t *out, const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
{
	uint32_t borrow = 0;
	int ii = 0;

	// Limbs present in both operands
	for (; ii < rhs_limbs; ++ii)
	{
		// Unsigned 64-bit arithmetic wraps on underflow; bit 63 is then set
		uint64_t diff = (uint64_t)lhs[ii] - rhs[ii] - borrow;
		out[ii] = (uint32_t)diff;
		borrow = (uint32_t)(diff >> 63);
	}

	// Ripple the borrow through the remaining limbs of lhs until it dies out
	for (; borrow && ii < lhs_limbs; ++ii)
	{
		uint32_t limb = lhs[ii];
		out[ii] = limb - 1;
		borrow = (limb == 0);
	}

	// Borrow is gone: the rest of the result is just the rest of lhs
	for (; ii < lhs_limbs; ++ii)
		out[ii] = lhs[ii];

	return borrow ? -1 : 0;
}

// lhs -= rhs, return borrow out
// precondition: lhs_limbs > 0, result limbs = lhs_limbs
// Returns 0 when no borrow leaves the top limb, -1 when one does.
int32_t Subtract32(uint32_t *lhs, int lhs_limbs, uint32_t rhs)
{
	uint32_t n = lhs[0];
	uint32_t r = n - rhs;
	lhs[0] = r;

	// No wrap-around means no borrow to propagate
	if (r <= n)
		return 0;

	// Decrement higher limbs until one was non-zero before the decrement
	for (int ii = 1; ii < lhs_limbs; ++ii)
		if (lhs[ii]--)
			return 0;

	return -1;
}

// lhs = -rhs (two's complement over 'limbs' limbs)
// 'lhs' may be the same buffer as 'rhs'.
void Negate(int limbs, uint32_t *lhs, const uint32_t *rhs)
{
	int ii = 0;

	// Propagate negations until carries stop: zero limbs stay zero and the
	// first non-zero limb is negated.  The negation is done in unsigned
	// arithmetic because -(int32_t)0x80000000 overflows a signed int.
	while (ii < limbs)
	{
		uint32_t limb = rhs[ii];
		lhs[ii] = 0u - limb;
		++ii;
		if (limb) break;
	}

	// Then just invert the remaining words
	for (; ii < limbs; ++ii)
		lhs[ii] = ~rhs[ii];
}

// n = ~n, only invert bits up to the MSB, but none above that
// A value of zero is left unchanged.
void BitNot(uint32_t *n, int limbs)
{
	// Skip leading zero limbs to find the most significant used limb
	while (limbs > 0 && n[limbs - 1] == 0)
		--limbs;
	if (limbs <= 0)
		return;

	int top = limbs - 1;
	uint32_t high = n[top];

	// Smear the highest set bit downwards to get a mask of the used bits
	uint32_t used_bits = high;
	used_bits |= used_bits >> 1;
	used_bits |= used_bits >> 2;
	used_bits |= used_bits >> 4;
	used_bits |= used_bits >> 8;
	used_bits |= used_bits >> 16;

	n[top] = ~high & used_bits;

	// Every limb below the top one is inverted in full
	for (int ii = 0; ii < top; ++ii)
		n[ii] = ~n[ii];
}

// n = ~n, invert all bits, even ones above MSB
void LimbNot(uint32_t *n, int limbs)
{
	// Indexed loop: the old "*n++ = ~(*n)" read and modified n without sequencing
	for (int ii = 0; ii < limbs; ++ii)
		n[ii] = ~n[ii];
}

// lhs ^= rhs
// XORs every limb of rhs into the matching limb of lhs; both hold 'limbs' limbs
void Xor(int limbs, uint32_t *lhs, const uint32_t *rhs)
{
	for (int ii = 0; ii < limbs; ++ii)
		lhs[ii] ^= rhs[ii];
}

// Return the carry out from A += B << S
// precondition: 0 <= S < 32
// (S == 0 used to evaluate "last >> 32", which is an undefined shift)
uint32_t AddLeftShift32(
	int limbs,			// Number of limbs in parameter A and B
	uint32_t *A,		// Large number
	const uint32_t *B,	// Large number
	uint32_t S)			// 32-bit number
{
	uint64_t sum = 0;	// running limb sum; the high word is the addition carry
	uint32_t spill = 0;	// bits of the previous B limb that were shifted past bit 31

	while (limbs-- > 0)
	{
		// Widen before shifting so the bits leaving the limb are kept
		uint64_t shifted = (uint64_t)(*B++) << S;

		sum = (uint64_t)((uint32_t)shifted | spill) + *A + (uint32_t)(sum >> 32);

		spill = (uint32_t)(shifted >> 32);
		*A++ = (uint32_t)sum;
	}

	return (uint32_t)(sum >> 32) + spill;
}

// Return the carry out from result = A * B
// Writes nothing and returns 0 when limbs <= 0 (the previous loop form
// required limbs >= 1 and ran away for 0).
// 'result' may be the same buffer as 'A'.
uint32_t Multiply32(
	int limbs,			// Number of limbs in parameter A, result
	uint32_t *result,	// Large number
	const uint32_t *A,	// Large number
	uint32_t B)			// 32-bit number
{
	uint32_t carry = 0;

	for (int ii = 0; ii < limbs; ++ii)
	{
		uint64_t p = (uint64_t)A[ii] * B + carry;
		result[ii] = (uint32_t)p;
		carry = (uint32_t)(p >> 32);
	}

	return carry;
}

// Return the carry out from X = X * M + A
// Returns A unchanged when limbs <= 0.
uint32_t MultiplyAdd32(
	int limbs,		// Number of limbs in parameter X
	uint32_t *X,	// Large number
	uint32_t M,		// 32-bit multiplier
	uint32_t A)		// 32-bit addend
{
	// The addend enters as the carry into the lowest limb
	uint32_t carry = A;

	for (int ii = 0; ii < limbs; ++ii)
	{
		uint64_t p = (uint64_t)X[ii] * M + carry;
		X[ii] = (uint32_t)p;
		carry = (uint32_t)(p >> 32);
	}

	return carry;
}

// Return the carry out from A += B * M
// precondition: limbs >= 1 (the assembly path always processes the first limb;
// the portable path returns 0 and writes nothing for limbs <= 0)
uint32_t AddMultiply32(
	int limbs,			// Number of limbs in parameter A and B
	uint32_t *A,		// Large number
	const uint32_t *B,	// Large number
	uint32_t M)			// 32-bit number
{
	// This function is roughly 85% of the cost of exponentiation
#if defined(ASSEMBLY_INTEL_SYNTAX)
	ASSEMBLY_BLOCK // VS.NET, x86, 32-bit words
	{
		mov esi, [B]
		mov edi, [A]
		mov eax, [esi]
		mul [M]					; (edx,eax) = [M]*[esi]
		add eax, [edi]			; (edx,eax) += [edi]
		adc edx, 0
		; (edx,eax) = [B]*[M] + [A]

		mov [edi], eax
		; [A] = eax

		mov ecx, [limbs]
		sub ecx, 1
		jz loop_done
loop_head:
			lea esi, [esi + 4]	; ++B
			mov eax, [esi]		; eax = [B]
			mov ebx, edx		; ebx = last carry
			lea edi, [edi + 4]	; ++A
			mul [M]				; (edx,eax) = [M]*[esi]
			add eax, [edi]		; (edx,eax) += [edi]
			adc edx, 0
			add eax, ebx		; (edx,eax) += ebx
			adc edx, 0
			; (edx,eax) = [esi]*[M] + [edi] + (ebx=last carry)

			mov [edi], eax
			; [A] = eax

		sub ecx, 1
		jnz loop_head
loop_done:
		mov [M], edx			; Use [M] to copy the carry into C++ land
	}

	return M;
#else
	uint32_t carry = 0;

	for (int ii = 0; ii < limbs; ++ii)
	{
		// Cannot overflow: (2^32-1)^2 + 2*(2^32-1) == 2^64 - 1
		uint64_t p = (uint64_t)B[ii] * M + A[ii] + carry;
		A[ii] = (uint32_t)p;
		carry = (uint32_t)(p >> 32);
	}

	return carry;
#endif
}

// product = x * y
// Schoolbook multiplication: one Multiply32/AddMultiply32 row per limb of y.
// Does nothing when limbs <= 0.
void SimpleMultiply(
	int limbs,			// Number of limbs in parameters x, y
	uint32_t *product,	// Large number; buffer size = limbs*2
	const uint32_t *x,	// Large number
	const uint32_t *y)	// Large number
{
	if (limbs <= 0) return;

	// Roughly 25% of the cost of exponentiation
	// First row initializes product[0..limbs]
	product[limbs] = Multiply32(limbs, product, x, y[0]);

	// Each further row is accumulated one limb higher; its carry out
	// becomes the next, not yet written, limb of the product
	for (int ii = 1; ii < limbs; ++ii)
		product[limbs + ii] = AddMultiply32(limbs, product + ii, x, y[ii]);
}

// product = low half of x * y product
SimpleMultiplyLowHalf(int limbs,uint32_t * product,const uint32_t * x,const uint32_t * y)532 void SimpleMultiplyLowHalf( 533 int limbs, // Number of limbs in parameters x, y 534 uint32_t *product, // Large number; buffer size = limbs 535 const uint32_t *x, // Large number 536 const uint32_t *y) // Large number 537 { 538 Multiply32(limbs, product, x, y[0]); 539 540 while (--limbs) 541 { 542 ++product; 543 ++y; 544 AddMultiply32(limbs, product, x, y[0]); 545 } 546 } 547 548 // product = x ^ 2 SimpleSquare(int limbs,uint32_t * product,const uint32_t * x)549 void SimpleSquare( 550 int limbs, // Number of limbs in parameter x 551 uint32_t *product, // Large number; buffer size = limbs*2 552 const uint32_t *x) // Large number 553 { 554 // Seems about 15% faster than SimpleMultiply() in practice 555 uint32_t *cross_product = (uint32_t*)alloca(limbs*2*4); 556 557 // Calculate square-less and repeat-less cross products 558 cross_product[limbs] = Multiply32(limbs - 1, cross_product + 1, x + 1, x[0]); 559 for (int ii = 1; ii < limbs - 1; ++ii) 560 { 561 cross_product[limbs + ii] = AddMultiply32(limbs - ii - 1, cross_product + ii*2 + 1, x + ii + 1, x[ii]); 562 } 563 564 // Calculate square products 565 for (int ii = 0; ii < limbs; ++ii) 566 { 567 uint32_t xi = x[ii]; 568 uint64_t si = (uint64_t)xi * xi; 569 product[ii*2] = (uint32_t)si; 570 product[ii*2+1] = (uint32_t)(si >> 32); 571 } 572 573 // Multiply the cross product by 2 and add it to the square products 574 product[limbs*2 - 1] += AddLeftShift32(limbs*2 - 2, product + 1, cross_product + 1, 1); 575 } 576 577 // product = xy 578 // memory space for product may not overlap with x,y Multiply(int limbs,uint32_t * product,const uint32_t * x,const uint32_t * y)579 void Multiply( 580 int limbs, // Number of limbs in x,y 581 uint32_t *product, // Product; buffer size = limbs*2 582 const uint32_t *x, // Large number; buffer size = limbs 583 const uint32_t *y) // Large number; buffer size = limbs 584 { 585 // Stop recursing under 
640 bits or odd limb count 586 if (limbs < 30 || (limbs & 1)) 587 { 588 SimpleMultiply(limbs, product, x, y); 589 return; 590 } 591 592 // Compute high and low products 593 Multiply(limbs/2, product, x, y); 594 Multiply(limbs/2, product + limbs, x + limbs/2, y + limbs/2); 595 596 // Compute (x1 + x2), xc = carry out 597 uint32_t *xsum = (uint32_t*)alloca((limbs/2)*4); 598 uint32_t xcarry = Add(xsum, x, limbs/2, x + limbs/2, limbs/2); 599 600 // Compute (y1 + y2), yc = carry out 601 uint32_t *ysum = (uint32_t*)alloca((limbs/2)*4); 602 uint32_t ycarry = Add(ysum, y, limbs/2, y + limbs/2, limbs/2); 603 604 // Compute (x1 + x2) * (y1 + y2) 605 uint32_t *cross_product = (uint32_t*)alloca(limbs*4); 606 Multiply(limbs/2, cross_product, xsum, ysum); 607 608 // Subtract out the high and low products 609 int32_t cross_carry = Subtract(cross_product, limbs, product, limbs); 610 cross_carry += Subtract(cross_product, limbs, product + limbs, limbs); 611 612 // Fix the extra high carry bits of the result 613 if (ycarry) cross_carry += Add(cross_product + limbs/2, limbs/2, xsum, limbs/2); 614 if (xcarry) cross_carry += Add(cross_product + limbs/2, limbs/2, ysum, limbs/2); 615 cross_carry += (xcarry & ycarry); 616 617 // Add the cross product into the result 618 cross_carry += Add(product + limbs/2, limbs*3/2, cross_product, limbs); 619 620 // Add in the fixed high carry bits 621 if (cross_carry) Add32(product + limbs*3/2, limbs/2, cross_carry); 622 } 623 624 // product = x^2 625 // memory space for product may not overlap with x Square(int limbs,uint32_t * product,const uint32_t * x)626 void Square( 627 int limbs, // Number of limbs in x 628 uint32_t *product, // Product; buffer size = limbs*2 629 const uint32_t *x) // Large number; buffer size = limbs 630 { 631 // Stop recursing under 1280 bits or odd limb count 632 if (limbs < 40 || (limbs & 1)) 633 { 634 SimpleSquare(limbs, product, x); 635 return; 636 } 637 638 // Compute high and low squares 639 Square(limbs/2, product, x); 
640 Square(limbs/2, product + limbs, x + limbs/2); 641 642 // Generate the cross product 643 uint32_t *cross_product = (uint32_t*)alloca(limbs*4); 644 Multiply(limbs/2, cross_product, x, x + limbs/2); 645 646 // Multiply the cross product by 2 and add it to the result 647 uint32_t cross_carry = AddLeftShift32(limbs, product + limbs/2, cross_product, 1); 648 649 // Roll the carry out up to the highest limb 650 if (cross_carry) Add32(product + limbs*3/2, limbs/2, cross_carry); 651 } 652 653 // Returns the remainder of N / divisor for a 32-bit divisor Modulus32(int limbs,const uint32_t * N,uint32_t divisor)654 uint32_t Modulus32( 655 int limbs, // Number of limbs in parameter N 656 const uint32_t *N, // Large number, buffer size = limbs 657 uint32_t divisor) // 32-bit number 658 { 659 uint32_t remainder = N[limbs-1] < divisor ? N[limbs-1] : 0; 660 uint32_t counter = N[limbs-1] < divisor ? limbs-1 : limbs; 661 662 while (counter--) remainder = (uint32_t)((((uint64_t)remainder << 32) | N[counter]) % divisor); 663 664 return remainder; 665 } 666 667 /* 668 * 'A' is overwritten with the quotient of the operation 669 * Returns the remainder of 'A' / divisor for a 32-bit divisor 670 * 671 * Does not check for divide-by-zero 672 */ Divide32(int limbs,uint32_t * A,uint32_t divisor)673 uint32_t Divide32( 674 int limbs, // Number of limbs in parameter A 675 uint32_t *A, // Large number, buffer size = limbs 676 uint32_t divisor) // 32-bit number 677 { 678 uint64_t r = 0; 679 for (int ii = limbs-1; ii >= 0; --ii) 680 { 681 uint64_t n = (r << 32) | A[ii]; 682 A[ii] = (uint32_t)(n / divisor); 683 r = n % divisor; 684 } 685 686 return (uint32_t)r; 687 } 688 689 // returns (n ^ -1) Mod 2^32 MulInverse32(uint32_t n)690 uint32_t MulInverse32(uint32_t n) 691 { 692 // {u1, g1} = 2^32 / n 693 uint32_t hb = (~(n - 1) >> 31); 694 uint32_t u1 = -(int32_t)(0xFFFFFFFF / n + hb); 695 uint32_t g1 = ((-(int32_t)hb) & (0xFFFFFFFF % n + 1)) - n; 696 697 if (!g1) { 698 if (n != 1) return 0; 699 else 
return 1; 700 } 701 702 uint32_t q, u = 1, g = n; 703 704 for (;;) { 705 q = g / g1; 706 g %= g1; 707 708 if (!g) { 709 if (g1 != 1) return 0; 710 else return u1; 711 } 712 713 u -= q*u1; 714 q = g1 / g; 715 g1 %= g; 716 717 if (!g1) { 718 if (g != 1) return 0; 719 else return u; 720 } 721 722 u1 -= q*u; 723 } 724 } 725 726 /* 727 * Computes multiplicative inverse of given number 728 * Such that: result * u = 1 729 * Using Extended Euclid's Algorithm (GCDe) 730 * 731 * This is not always possible, so it will return false iff not possible. 732 */ MulInverse(int limbs,const uint32_t * u,uint32_t * result)733 bool MulInverse( 734 int limbs, // Limbs in u and result 735 const uint32_t *u, // Large number, buffer size = limbs 736 uint32_t *result) // Large number, buffer size = limbs 737 { 738 uint32_t *u1 = (uint32_t*)alloca(limbs*4); 739 uint32_t *u3 = (uint32_t*)alloca(limbs*4); 740 uint32_t *v1 = (uint32_t*)alloca(limbs*4); 741 uint32_t *v3 = (uint32_t*)alloca(limbs*4); 742 uint32_t *t1 = (uint32_t*)alloca(limbs*4); 743 uint32_t *t3 = (uint32_t*)alloca(limbs*4); 744 uint32_t *q = (uint32_t*)alloca((limbs+1)*4); 745 uint32_t *w = (uint32_t*)alloca((limbs+1)*4); 746 747 // Unrolled first iteration 748 { 749 Set32(u1, limbs, 0); 750 Set32(v1, limbs, 1); 751 Set(v3, limbs, u); 752 } 753 754 // Unrolled second iteration 755 if (!LimbDegree(v3, limbs)) 756 return false; 757 758 // {q, t3} <- R / v3 759 Set32(w, limbs, 0); 760 w[limbs] = 1; 761 Divide(w, limbs+1, v3, limbs, q, t3); 762 763 SimpleMultiplyLowHalf(limbs, t1, q, v1); 764 Add(t1, limbs, u1, limbs); 765 766 for (;;) 767 { 768 if (!LimbDegree(t3, limbs)) 769 { 770 Set(result, limbs, v1); 771 return Equal32(v3, limbs, 1); 772 } 773 774 Divide(v3, limbs, t3, limbs, q, u3); 775 SimpleMultiplyLowHalf(limbs, u1, q, t1); 776 Add(u1, limbs, v1, limbs); 777 778 if (!LimbDegree(u3, limbs)) 779 { 780 Negate(limbs, result, t1); 781 return Equal32(t3, limbs, 1); 782 } 783 784 Divide(t3, limbs, u3, limbs, q, v3); 785 
SimpleMultiplyLowHalf(limbs, v1, q, u1); 786 Add(v1, limbs, t1, limbs); 787 788 if (!LimbDegree(v3, limbs)) 789 { 790 Set(result, limbs, u1); 791 return Equal32(u3, limbs, 1); 792 } 793 794 Divide(u3, limbs, v3, limbs, q, t3); 795 SimpleMultiplyLowHalf(limbs, t1, q, v1); 796 Add(t1, limbs, u1, limbs); 797 798 if (!LimbDegree(t3, limbs)) 799 { 800 Negate(limbs, result, v1); 801 return Equal32(v3, limbs, 1); 802 } 803 804 Divide(v3, limbs, t3, limbs, q, u3); 805 SimpleMultiplyLowHalf(limbs, u1, q, t1); 806 Add(u1, limbs, v1, limbs); 807 808 if (!LimbDegree(u3, limbs)) 809 { 810 Set(result, limbs, t1); 811 return Equal32(t3, limbs, 1); 812 } 813 814 Divide(t3, limbs, u3, limbs, q, v3); 815 SimpleMultiplyLowHalf(limbs, v1, q, u1); 816 Add(v1, limbs, t1, limbs); 817 818 if (!LimbDegree(v3, limbs)) 819 { 820 Negate(limbs, result, u1); 821 return Equal32(u3, limbs, 1); 822 } 823 824 Divide(u3, limbs, v3, limbs, q, t3); 825 SimpleMultiplyLowHalf(limbs, t1, q, v1); 826 Add(t1, limbs, u1, limbs); 827 } 828 } 829 830 // {q, r} = u / v 831 // q is not u or v 832 // Return false on divide by zero Divide(const uint32_t * u,int u_limbs,const uint32_t * v,int v_limbs,uint32_t * q,uint32_t * r)833 bool Divide( 834 const uint32_t *u, // numerator, size = u_limbs 835 int u_limbs, 836 const uint32_t *v, // denominator, size = v_limbs 837 int v_limbs, 838 uint32_t *q, // quotient, size = u_limbs 839 uint32_t *r) // remainder, size = v_limbs 840 { 841 // calculate v_used and u_used 842 int v_used = LimbDegree(v, v_limbs); 843 if (!v_used) return false; 844 845 int u_used = LimbDegree(u, u_limbs); 846 847 // if u < v, avoid division 848 if (u_used <= v_used && Less(u, u_used, v, v_used)) 849 { 850 // r = u, q = 0 851 Set(r, v_limbs, u, u_used); 852 Set32(q, u_limbs, 0); 853 return true; 854 } 855 856 // if v is 32 bits, use faster Divide32 code 857 if (v_used == 1) 858 { 859 // {q, r} = u / v[0] 860 Set(q, u_limbs, u); 861 Set32(r, v_limbs, Divide32(u_limbs, q, v[0])); 862 return true; 
863 } 864 865 // calculate high zero bits in v's high used limb 866 int shift = 32 - Degree(v[v_used - 1]); 867 int uu_used = u_used; 868 if (shift > 0) uu_used++; 869 870 uint32_t *uu = (uint32_t*)alloca(uu_used*4); 871 uint32_t *vv = (uint32_t*)alloca(v_used*4); 872 873 // shift left to fill high MSB of divisor 874 if (shift > 0) 875 { 876 ShiftLeft(v_used, vv, v, shift); 877 uu[u_used] = ShiftLeft(u_used, uu, u, shift); 878 } 879 else 880 { 881 Set(uu, u_used, u); 882 Set(vv, v_used, v); 883 } 884 885 int q_high_index = uu_used - v_used; 886 887 if (GreaterOrEqual(uu + q_high_index, v_used, vv, v_used)) 888 { 889 Subtract(uu + q_high_index, v_used, vv, v_used); 890 Set32(q + q_high_index, u_used - q_high_index, 1); 891 } 892 else 893 { 894 Set32(q + q_high_index, u_used - q_high_index, 0); 895 } 896 897 uint32_t *vq_product = (uint32_t*)alloca((v_used+1)*4); 898 899 // for each limb, 900 for (int ii = q_high_index - 1; ii >= 0; --ii) 901 { 902 uint64_t q_full = *(uint64_t*)(uu + ii + v_used - 1) / vv[v_used - 1]; 903 uint32_t q_low = (uint32_t)q_full; 904 uint32_t q_high = (uint32_t)(q_full >> 32); 905 906 vq_product[v_used] = Multiply32(v_used, vq_product, vv, q_low); 907 908 if (q_high) // it must be '1' 909 Add(vq_product + 1, v_used, vv, v_used); 910 911 if (Subtract(uu + ii, v_used + 1, vq_product, v_used + 1)) 912 { 913 --q_low; 914 if (Add(uu + ii, v_used + 1, vv, v_used) == 0) 915 { 916 --q_low; 917 Add(uu + ii, v_used + 1, vv, v_used); 918 } 919 } 920 921 q[ii] = q_low; 922 } 923 924 memset(r + v_used, 0, (v_limbs - v_used)*4); 925 ShiftRight(v_used, r, uu, shift); 926 927 return true; 928 } 929 930 // r = u % v 931 // Return false on divide by zero Modulus(const uint32_t * u,int u_limbs,const uint32_t * v,int v_limbs,uint32_t * r)932 bool Modulus( 933 const uint32_t *u, // numerator, size = u_limbs 934 int u_limbs, 935 const uint32_t *v, // denominator, size = v_limbs 936 int v_limbs, 937 uint32_t *r) // remainder, size = v_limbs 938 { 939 // calculate 
v_used and u_used 940 int v_used = LimbDegree(v, v_limbs); 941 if (!v_used) return false; 942 943 int u_used = LimbDegree(u, u_limbs); 944 945 // if u < v, avoid division 946 if (u_used <= v_used && Less(u, u_used, v, v_used)) 947 { 948 // r = u, q = 0 949 Set(r, v_limbs, u, u_used); 950 //Set32(q, u_limbs, 0); 951 return true; 952 } 953 954 // if v is 32 bits, use faster Divide32 code 955 if (v_used == 1) 956 { 957 // {q, r} = u / v[0] 958 //Set(q, u_limbs, u); 959 Set32(r, v_limbs, Modulus32(u_limbs, u, v[0])); 960 return true; 961 } 962 963 // calculate high zero bits in v's high used limb 964 int shift = 32 - Degree(v[v_used - 1]); 965 int uu_used = u_used; 966 if (shift > 0) uu_used++; 967 968 uint32_t *uu = (uint32_t*)alloca(uu_used*4); 969 uint32_t *vv = (uint32_t*)alloca(v_used*4); 970 971 // shift left to fill high MSB of divisor 972 if (shift > 0) 973 { 974 ShiftLeft(v_used, vv, v, shift); 975 uu[u_used] = ShiftLeft(u_used, uu, u, shift); 976 } 977 else 978 { 979 Set(uu, u_used, u); 980 Set(vv, v_used, v); 981 } 982 983 int q_high_index = uu_used - v_used; 984 985 if (GreaterOrEqual(uu + q_high_index, v_used, vv, v_used)) 986 { 987 Subtract(uu + q_high_index, v_used, vv, v_used); 988 //Set32(q + q_high_index, u_used - q_high_index, 1); 989 } 990 else 991 { 992 //Set32(q + q_high_index, u_used - q_high_index, 0); 993 } 994 995 uint32_t *vq_product = (uint32_t*)alloca((v_used+1)*4); 996 997 // for each limb, 998 for (int ii = q_high_index - 1; ii >= 0; --ii) 999 { 1000 uint64_t q_full = *(uint64_t*)(uu + ii + v_used - 1) / vv[v_used - 1]; 1001 uint32_t q_low = (uint32_t)q_full; 1002 uint32_t q_high = (uint32_t)(q_full >> 32); 1003 1004 vq_product[v_used] = Multiply32(v_used, vq_product, vv, q_low); 1005 1006 if (q_high) // it must be '1' 1007 Add(vq_product + 1, v_used, vv, v_used); 1008 1009 if (Subtract(uu + ii, v_used + 1, vq_product, v_used + 1)) 1010 { 1011 //--q_low; 1012 if (Add(uu + ii, v_used + 1, vv, v_used) == 0) 1013 { 1014 //--q_low; 1015 
Add(uu + ii, v_used + 1, vv, v_used); 1016 } 1017 } 1018 1019 //q[ii] = q_low; 1020 } 1021 1022 memset(r + v_used, 0, (v_limbs - v_used)*4); 1023 ShiftRight(v_used, r, uu, shift); 1024 1025 return true; 1026 } 1027 1028 // m_inv ~= 2^(2k)/m 1029 // Generates m_inv parameter of BarrettModulus() 1030 // It is limbs in size, chopping off the 2^k bit 1031 // Only works for m with the high bit set BarrettModulusPrecomp(int limbs,const uint32_t * m,uint32_t * m_inv)1032 void BarrettModulusPrecomp( 1033 int limbs, // Number of limbs in m and m_inv 1034 const uint32_t *m, // Modulus, size = limbs 1035 uint32_t *m_inv) // Large number result, size = limbs 1036 { 1037 uint32_t *q = (uint32_t*)alloca((limbs*2+1)*4); 1038 1039 // q = 2^(2k) 1040 big::Set32(q, limbs*2, 0); 1041 q[limbs*2] = 1; 1042 1043 // q /= m 1044 big::Divide(q, limbs*2+1, m, limbs, q, m_inv); 1045 1046 // m_inv = q 1047 Set(m_inv, limbs, q); 1048 } 1049 1050 // r = x mod m 1051 // Using Barrett's method with precomputed m_inv BarrettModulus(int limbs,const uint32_t * x,const uint32_t * m,const uint32_t * m_inv,uint32_t * result)1052 void BarrettModulus( 1053 int limbs, // Number of limbs in m and m_inv 1054 const uint32_t *x, // Number to reduce, size = limbs*2 1055 const uint32_t *m, // Modulus, size = limbs 1056 const uint32_t *m_inv, // R/Modulus, precomputed, size = limbs 1057 uint32_t *result) // Large number result 1058 { 1059 // q2 = x * m_inv 1060 // Skips the low limbs+1 words and some high limbs too 1061 // Needs to partially calculate the next 2 words below for carries 1062 uint32_t *q2 = (uint32_t*)alloca((limbs+3)*4); 1063 int ii, jj = limbs - 1; 1064 1065 // derived from the fact that m_inv[limbs] was always 1, so m_inv is the same length as modulus now 1066 *(uint64_t*)q2 = (uint64_t)m_inv[jj] * x[jj]; 1067 *(uint64_t*)(q2 + 1) = (uint64_t)q2[1] + x[jj]; 1068 1069 for (ii = 1; ii < limbs; ++ii) 1070 *(uint64_t*)(q2 + ii + 1) = ((uint64_t)q2[ii + 1] + x[jj + ii]) + AddMultiply32(ii + 1, q2, 
m_inv + jj - ii, x[jj + ii]); 1071 1072 *(uint64_t*)(q2 + ii + 1) = ((uint64_t)q2[ii + 1] + x[jj + ii]) + AddMultiply32(ii, q2 + 1, m_inv, x[jj + ii]); 1073 1074 q2 += 2; 1075 1076 // r2 = (q3 * m2) mod b^(k+1) 1077 uint32_t *r2 = (uint32_t*)alloca((limbs+1)*4); 1078 1079 // Skip high words in product, also input limbs are different by 1 1080 Multiply32(limbs + 1, r2, q2, m[0]); 1081 for (int ii = 1; ii < limbs; ++ii) 1082 AddMultiply32(limbs + 1 - ii, r2 + ii, q2, m[ii]); 1083 1084 // Correct the error of up to two modulii 1085 uint32_t *r = (uint32_t*)alloca((limbs+1)*4); 1086 if (Subtract(r, x, limbs+1, r2, limbs+1)) 1087 { 1088 while (!Subtract(r, limbs+1, m, limbs)); 1089 } 1090 else 1091 { 1092 while (GreaterOrEqual(r, limbs+1, m, limbs)) 1093 Subtract(r, limbs+1, m, limbs); 1094 } 1095 1096 Set(result, limbs, r); 1097 } 1098 1099 // result = (x * y) (Mod modulus) MulMod(int limbs,const uint32_t * x,const uint32_t * y,const uint32_t * modulus,uint32_t * result)1100 bool MulMod( 1101 int limbs, // Number of limbs in x,y,modulus 1102 const uint32_t *x, // Large number x 1103 const uint32_t *y, // Large number y 1104 const uint32_t *modulus, // Large number modulus 1105 uint32_t *result) // Large number result 1106 { 1107 uint32_t *product = (uint32_t*)alloca(limbs*2*4); 1108 1109 Multiply(limbs, product, x, y); 1110 1111 return Modulus(product, limbs * 2, modulus, limbs, result); 1112 } 1113 1114 // Convert bigint to string 1115 /* 1116 std::string ToStr(const uint32_t *n, int limbs, int base) 1117 { 1118 limbs = LimbDegree(n, limbs); 1119 if (!limbs) return "0"; 1120 1121 std::string out; 1122 char ch; 1123 1124 uint32_t *m = (uint32_t*)alloca(limbs*4); 1125 Set(m, limbs, n, limbs); 1126 1127 while (limbs) 1128 { 1129 uint32_t mod = Divide32(limbs, m, base); 1130 if (mod <= 9) ch = '0' + mod; 1131 else ch = 'A' + mod - 10; 1132 out = ch + out; 1133 limbs = LimbDegree(m, limbs); 1134 } 1135 1136 return out; 1137 } 1138 */ 1139 1140 // Convert string to bigint 
1141 // Return 0 if string contains non-digit characters, else number of limbs used ToInt(uint32_t * lhs,int max_limbs,const char * rhs,uint32_t base)1142 int ToInt(uint32_t *lhs, int max_limbs, const char *rhs, uint32_t base) 1143 { 1144 if (max_limbs < 2) return 0; 1145 1146 lhs[0] = 0; 1147 int used = 1; 1148 1149 char ch; 1150 while ((ch = *rhs++)) 1151 { 1152 uint32_t mod; 1153 if (ch >= '0' && ch <= '9') mod = ch - '0'; 1154 else mod = toupper(ch) - 'A' + 10; 1155 if (mod >= base) return 0; 1156 1157 // lhs *= base 1158 uint32_t carry = MultiplyAdd32(used, lhs, base, mod); 1159 1160 // react to running out of room 1161 if (carry) 1162 { 1163 if (used >= max_limbs) 1164 return 0; 1165 1166 lhs[used++] = carry; 1167 } 1168 } 1169 1170 if (used < max_limbs) 1171 Set32(lhs+used, max_limbs-used, 0); 1172 1173 return used; 1174 } 1175 1176 /* 1177 * Computes: result = GCD(a, b) (greatest common divisor) 1178 * 1179 * Length of result is the length of the smallest argument 1180 */ GCD(const uint32_t * a,int a_limbs,const uint32_t * b,int b_limbs,uint32_t * result)1181 void GCD( 1182 const uint32_t *a, // Large number, buffer size = a_limbs 1183 int a_limbs, // Size of a 1184 const uint32_t *b, // Large number, buffer size = b_limbs 1185 int b_limbs, // Size of b 1186 uint32_t *result) // Large number, buffer size = min(a, b) 1187 { 1188 int limbs = (a_limbs <= b_limbs) ? 
a_limbs : b_limbs; 1189 1190 uint32_t *g = (uint32_t*)alloca(limbs*4); 1191 uint32_t *g1 = (uint32_t*)alloca(limbs*4); 1192 1193 if (a_limbs <= b_limbs) 1194 { 1195 // g = a, g1 = b (mod a) 1196 Set(g, limbs, a, a_limbs); 1197 Modulus(b, b_limbs, a, a_limbs, g1); 1198 } 1199 else 1200 { 1201 // g = b, g1 = a (mod b) 1202 Set(g, limbs, b, b_limbs); 1203 Modulus(a, a_limbs, b, b_limbs, g1); 1204 } 1205 1206 for (;;) { 1207 // g = (g mod g1) 1208 Modulus(g, limbs, g1, limbs, g); 1209 1210 if (!LimbDegree(g, limbs)) { 1211 Set(result, limbs, g1, limbs); 1212 return; 1213 } 1214 1215 // g1 = (g1 mod g) 1216 Modulus(g1, limbs, g, limbs, g1); 1217 1218 if (!LimbDegree(g1, limbs)) { 1219 Set(result, limbs, g, limbs); 1220 return; 1221 } 1222 } 1223 } 1224 1225 /* 1226 * Computes: result = (1/u) (Mod v) 1227 * Such that: result * u (Mod v) = 1 1228 * Using Extended Euclid's Algorithm (GCDe) 1229 * 1230 * This is not always possible, so it will return false iff not possible. 1231 */ InvMod(const uint32_t * u,int u_limbs,const uint32_t * v,int limbs,uint32_t * result)1232 bool InvMod( 1233 const uint32_t *u, // Large number, buffer size = u_limbs 1234 int u_limbs, // Limbs in u 1235 const uint32_t *v, // Large number, buffer size = limbs 1236 int limbs, // Limbs in modulus(v) and result 1237 uint32_t *result) // Large number, buffer size = limbs 1238 { 1239 uint32_t *u1 = (uint32_t*)alloca(limbs*4); 1240 uint32_t *u3 = (uint32_t*)alloca(limbs*4); 1241 uint32_t *v1 = (uint32_t*)alloca(limbs*4); 1242 uint32_t *v3 = (uint32_t*)alloca(limbs*4); 1243 uint32_t *t1 = (uint32_t*)alloca(limbs*4); 1244 uint32_t *t3 = (uint32_t*)alloca(limbs*4); 1245 uint32_t *q = (uint32_t*)alloca((limbs + u_limbs)*4); 1246 1247 // Unrolled first iteration 1248 { 1249 Set32(u1, limbs, 0); 1250 Set32(v1, limbs, 1); 1251 Set(u3, limbs, v); 1252 1253 // v3 = u % v 1254 Modulus(u, u_limbs, v, limbs, v3); 1255 } 1256 1257 for (;;) 1258 { 1259 if (!LimbDegree(v3, limbs)) 1260 { 1261 Subtract(result, v, 
limbs, u1, limbs); 1262 return Equal32(u3, limbs, 1); 1263 } 1264 1265 Divide(u3, limbs, v3, limbs, q, t3); 1266 SimpleMultiplyLowHalf(limbs, t1, q, v1); 1267 Add(t1, limbs, u1, limbs); 1268 1269 if (!LimbDegree(t3, limbs)) 1270 { 1271 Set(result, limbs, v1); 1272 return Equal32(v3, limbs, 1); 1273 } 1274 1275 Divide(v3, limbs, t3, limbs, q, u3); 1276 SimpleMultiplyLowHalf(limbs, u1, q, t1); 1277 Add(u1, limbs, v1, limbs); 1278 1279 if (!LimbDegree(u3, limbs)) 1280 { 1281 Subtract(result, v, limbs, t1, limbs); 1282 return Equal32(t3, limbs, 1); 1283 } 1284 1285 Divide(t3, limbs, u3, limbs, q, v3); 1286 SimpleMultiplyLowHalf(limbs, v1, q, u1); 1287 Add(v1, limbs, t1, limbs); 1288 1289 if (!LimbDegree(v3, limbs)) 1290 { 1291 Set(result, limbs, u1); 1292 return Equal32(u3, limbs, 1); 1293 } 1294 1295 Divide(u3, limbs, v3, limbs, q, t3); 1296 SimpleMultiplyLowHalf(limbs, t1, q, v1); 1297 Add(t1, limbs, u1, limbs); 1298 1299 if (!LimbDegree(t3, limbs)) 1300 { 1301 Subtract(result, v, limbs, v1, limbs); 1302 return Equal32(v3, limbs, 1); 1303 } 1304 1305 Divide(v3, limbs, t3, limbs, q, u3); 1306 SimpleMultiplyLowHalf(limbs, u1, q, t1); 1307 Add(u1, limbs, v1, limbs); 1308 1309 if (!LimbDegree(u3, limbs)) 1310 { 1311 Set(result, limbs, t1); 1312 return Equal32(t3, limbs, 1); 1313 } 1314 1315 Divide(t3, limbs, u3, limbs, q, v3); 1316 SimpleMultiplyLowHalf(limbs, v1, q, u1); 1317 Add(v1, limbs, t1, limbs); 1318 } 1319 } 1320 1321 // root = sqrt(square) 1322 // Based on Newton-Raphson iteration: root_n+1 = (root_n + square/root_n) / 2 1323 // Doubles number of correct bits each iteration 1324 // Precondition: The high limb of square is non-zero 1325 // Returns false if it was unable to determine the root SquareRoot(int limbs,const uint32_t * square,uint32_t * root)1326 bool SquareRoot( 1327 int limbs, // Number of limbs in root 1328 const uint32_t *square, // Square to root, size = limbs * 2 1329 uint32_t *root) // Output root, size = limbs 1330 { 1331 uint32_t *q = 
(uint32_t*)alloca(limbs*2*4); 1332 uint32_t *r = (uint32_t*)alloca((limbs+1)*4); 1333 1334 // Take high limbs of square as the initial root guess 1335 Set(root, limbs, square + limbs); 1336 1337 int ctr = 64; 1338 while (ctr--) 1339 { 1340 // {q, r} = square / root 1341 Divide(square, limbs*2, root, limbs, q, r); 1342 1343 // root = (root + q) / 2, assuming high limbs of q = 0 1344 Add(q, limbs+1, root, limbs); 1345 1346 // Round division up to the nearest bit 1347 // Fixes a problem where root is off by 1 1348 if (q[0] & 1) Add32(q, limbs+1, 2); 1349 1350 ShiftRight(limbs+1, q, q, 1); 1351 1352 // Return success if there was no change 1353 if (Equal(limbs, q, root)) 1354 return true; 1355 1356 // Else update root and continue 1357 Set(root, limbs, q); 1358 } 1359 1360 // In practice only takes about 9 iterations, as many as 31 1361 // Varies slightly as number of limbs increases but not by much 1362 return false; 1363 } 1364 1365 // Calculates mod_inv from low limb of modulus for Mon*() MonReducePrecomp(uint32_t modulus0)1366 uint32_t MonReducePrecomp(uint32_t modulus0) 1367 { 1368 // mod_inv = -M ^ -1 (Mod 2^32) 1369 return MulInverse32(-(int32_t)modulus0); 1370 } 1371 1372 // Compute n_residue for Montgomery reduction MonInputResidue(const uint32_t * n,int n_limbs,const uint32_t * modulus,int m_limbs,uint32_t * n_residue)1373 void MonInputResidue( 1374 const uint32_t *n, // Large number, buffer size = n_limbs 1375 int n_limbs, // Number of limbs in n 1376 const uint32_t *modulus, // Large number, buffer size = m_limbs 1377 int m_limbs, // Number of limbs in modulus 1378 uint32_t *n_residue) // Result, buffer size = m_limbs 1379 { 1380 // p = n * 2^(k*m) 1381 uint32_t *p = (uint32_t*)alloca((n_limbs+m_limbs)*4); 1382 Set(p+m_limbs, n_limbs, n, n_limbs); 1383 Set32(p, m_limbs, 0); 1384 1385 // n_residue = p (Mod modulus) 1386 Modulus(p, n_limbs+m_limbs, modulus, m_limbs, n_residue); 1387 } 1388 1389 // result = a * b * r^-1 (Mod modulus) in Montgomery domain 
MonPro(int limbs,const uint32_t * a_residue,const uint32_t * b_residue,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1390 void MonPro( 1391 int limbs, // Number of limbs in each parameter 1392 const uint32_t *a_residue, // Large number, buffer size = limbs 1393 const uint32_t *b_residue, // Large number, buffer size = limbs 1394 const uint32_t *modulus, // Large number, buffer size = limbs 1395 uint32_t mod_inv, // MonReducePrecomp() return 1396 uint32_t *result) // Large number, buffer size = limbs 1397 { 1398 uint32_t *t = (uint32_t*)alloca(limbs*2*4); 1399 1400 Multiply(limbs, t, a_residue, b_residue); 1401 MonReduce(limbs, t, modulus, mod_inv, result); 1402 } 1403 1404 // result = a^-1 (Mod modulus) in Montgomery domain MonInverse(int limbs,const uint32_t * a_residue,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1405 void MonInverse( 1406 int limbs, // Number of limbs in each parameter 1407 const uint32_t *a_residue, // Large number, buffer size = limbs 1408 const uint32_t *modulus, // Large number, buffer size = limbs 1409 uint32_t mod_inv, // MonReducePrecomp() return 1410 uint32_t *result) // Large number, buffer size = limbs 1411 { 1412 Set(result, limbs, a_residue); 1413 MonFinish(limbs, result, modulus, mod_inv); 1414 InvMod(result, limbs, modulus, limbs, result); 1415 MonInputResidue(result, limbs, modulus, limbs, result); 1416 } 1417 1418 // result = a * r^-1 (Mod modulus) in Montgomery domain 1419 // The result may be greater than the modulus, but this is okay since 1420 // the result is still in the RNS. MonFinish() corrects this at the end. 
MonReduce(int limbs,uint32_t * s,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1421 void MonReduce( 1422 int limbs, // Number of limbs in modulus 1423 uint32_t *s, // Large number, buffer size = limbs*2, gets clobbered 1424 const uint32_t *modulus, // Large number, buffer size = limbs 1425 uint32_t mod_inv, // MonReducePrecomp() return 1426 uint32_t *result) // Large number, buffer size = limbs 1427 { 1428 // This function is roughly 60% of the cost of exponentiation 1429 for (int ii = 0; ii < limbs; ++ii) 1430 { 1431 uint32_t q = s[0] * mod_inv; 1432 s[0] = AddMultiply32(limbs, s, modulus, q); 1433 ++s; 1434 } 1435 1436 // Add the saved carries 1437 if (Add(result, s, limbs, s - limbs, limbs)) 1438 { 1439 // Reduce the result only when needed 1440 Subtract(result, limbs, modulus, limbs); 1441 } 1442 } 1443 1444 // result = a * r^-1 (Mod modulus) in Montgomery domain MonFinish(int limbs,uint32_t * n,const uint32_t * modulus,uint32_t mod_inv)1445 void MonFinish( 1446 int limbs, // Number of limbs in each parameter 1447 uint32_t *n, // Large number, buffer size = limbs 1448 const uint32_t *modulus, // Large number, buffer size = limbs 1449 uint32_t mod_inv) // MonReducePrecomp() return 1450 { 1451 uint32_t *t = (uint32_t*)alloca(limbs*2*4); 1452 memcpy(t, n, limbs*4); 1453 memset(t + limbs, 0, limbs*4); 1454 1455 // Reduce the number 1456 MonReduce(limbs, t, modulus, mod_inv, n); 1457 1458 // Fix MonReduce() results greater than the modulus 1459 if (!Less(limbs, n, modulus)) 1460 Subtract(n, limbs, modulus, limbs); 1461 } 1462 1463 // Simple internal version without windowing for small exponents SimpleMonExpMod(const uint32_t * base,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1464 static void SimpleMonExpMod( 1465 const uint32_t *base, // Base for exponentiation, buffer size = mod_limbs 1466 const uint32_t *exponent,// Exponent, buffer size = exponent_limbs 1467 int 
exponent_limbs, // Number of limbs in exponent 1468 const uint32_t *modulus, // Modulus, buffer size = mod_limbs 1469 int mod_limbs, // Number of limbs in modulus 1470 uint32_t mod_inv, // MonReducePrecomp() return 1471 uint32_t *result) // Result, buffer size = mod_limbs 1472 { 1473 bool set = false; 1474 1475 uint32_t *temp = (uint32_t*)alloca((mod_limbs*2)*4); 1476 1477 // Run down exponent bits and use the squaring method 1478 for (int ii = exponent_limbs - 1; ii >= 0; --ii) 1479 { 1480 uint32_t e_i = exponent[ii]; 1481 1482 for (uint32_t mask = 0x80000000; mask; mask >>= 1) 1483 { 1484 if (set) 1485 { 1486 // result = result^2 1487 Square(mod_limbs, temp, result); 1488 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1489 1490 if (e_i & mask) 1491 { 1492 // result *= base 1493 Multiply(mod_limbs, temp, result, base); 1494 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1495 } 1496 } 1497 else 1498 { 1499 if (e_i & mask) 1500 { 1501 // result = base 1502 Set(result, mod_limbs, base, mod_limbs); 1503 set = true; 1504 } 1505 } 1506 } 1507 } 1508 } 1509 1510 // Precompute a window for ExpMod() and MonExpMod() 1511 // Requires 2^window_bits multiplies PrecomputeWindow(const uint32_t * base,const uint32_t * modulus,int limbs,uint32_t mod_inv,int window_bits)1512 uint32_t *PrecomputeWindow(const uint32_t *base, const uint32_t *modulus, int limbs, uint32_t mod_inv, int window_bits) 1513 { 1514 uint32_t *temp = (uint32_t*)alloca(limbs*2*4); 1515 1516 uint32_t *base_squared = (uint32_t*)alloca(limbs*4); 1517 Square(limbs, temp, base); 1518 MonReduce(limbs, temp, modulus, mod_inv, base_squared); 1519 1520 // precomputed window starts with 000001, 000011, 000101, 000111, ... 
1521 uint32_t k = (1 << (window_bits - 1)); 1522 1523 uint32_t *window = RakNet::OP_NEW_ARRAY<uint32_t>(limbs * k, __FILE__, __LINE__ ); 1524 1525 uint32_t *cw = window; 1526 Set(window, limbs, base); 1527 1528 while (--k) 1529 { 1530 // cw+1 = cw * base^2 1531 Multiply(limbs, temp, cw, base_squared); 1532 MonReduce(limbs, temp, modulus, mod_inv, cw + limbs); 1533 cw += limbs; 1534 } 1535 1536 return window; 1537 }; 1538 1539 // Computes: result = base ^ exponent (Mod modulus) 1540 // Using Montgomery multiplication with simple squaring method 1541 // Base parameter must be a Montgomery Residue created with MonInputResidue() MonExpMod(const uint32_t * base,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1542 void MonExpMod( 1543 const uint32_t *base, // Base for exponentiation, buffer size = mod_limbs 1544 const uint32_t *exponent,// Exponent, buffer size = exponent_limbs 1545 int exponent_limbs, // Number of limbs in exponent 1546 const uint32_t *modulus, // Modulus, buffer size = mod_limbs 1547 int mod_limbs, // Number of limbs in modulus 1548 uint32_t mod_inv, // MonReducePrecomp() return 1549 uint32_t *result) // Result, buffer size = mod_limbs 1550 { 1551 // Calculate the number of window bits to use (decent approximation..) 
1552 int window_bits = Degree(exponent_limbs); 1553 1554 // If the window bits are too small, might as well just use left-to-right S&M method 1555 if (window_bits < 4) 1556 { 1557 SimpleMonExpMod(base, exponent, exponent_limbs, modulus, mod_limbs, mod_inv, result); 1558 return; 1559 } 1560 1561 // Precompute a window of the size determined above 1562 uint32_t *window = PrecomputeWindow(base, modulus, mod_limbs, mod_inv, window_bits); 1563 1564 bool seen_bits = false; 1565 uint32_t e_bits=0, trailing_zeroes=0, used_bits = 0; 1566 1567 uint32_t *temp = (uint32_t*)alloca((mod_limbs*2)*4); 1568 1569 for (int ii = exponent_limbs - 1; ii >= 0; --ii) 1570 { 1571 uint32_t e_i = exponent[ii]; 1572 1573 int wordbits = 32; 1574 while (wordbits--) 1575 { 1576 // If we have been accumulating bits, 1577 if (used_bits) 1578 { 1579 // If this new bit is set, 1580 if (e_i >> 31) 1581 { 1582 e_bits <<= 1; 1583 e_bits |= 1; 1584 1585 trailing_zeroes = 0; 1586 } 1587 else // the new bit is unset 1588 { 1589 e_bits <<= 1; 1590 1591 ++trailing_zeroes; 1592 } 1593 1594 ++used_bits; 1595 1596 // If we have used up the window bits, 1597 if (used_bits == (uint32_t) window_bits) 1598 { 1599 // Select window index 1011 from "101110" 1600 uint32_t window_index = e_bits >> (trailing_zeroes + 1); 1601 1602 if (seen_bits) 1603 { 1604 uint32_t ctr = used_bits - trailing_zeroes; 1605 while (ctr--) 1606 { 1607 // result = result^2 1608 Square(mod_limbs, temp, result); 1609 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1610 } 1611 1612 // result = result * window[index] 1613 Multiply(mod_limbs, temp, result, &window[window_index * mod_limbs]); 1614 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1615 } 1616 else 1617 { 1618 // result = window[index] 1619 Set(result, mod_limbs, &window[window_index * mod_limbs]); 1620 seen_bits = true; 1621 } 1622 1623 while (trailing_zeroes--) 1624 { 1625 // result = result^2 1626 Square(mod_limbs, temp, result); 1627 MonReduce(mod_limbs, temp, modulus, 
mod_inv, result); 1628 } 1629 1630 used_bits = 0; 1631 } 1632 } 1633 else 1634 { 1635 // If this new bit is set, 1636 if (e_i >> 31) 1637 { 1638 used_bits = 1; 1639 e_bits = 1; 1640 trailing_zeroes = 0; 1641 } 1642 else // the new bit is unset 1643 { 1644 // If we have processed any bits yet, 1645 if (seen_bits) 1646 { 1647 // result = result^2 1648 Square(mod_limbs, temp, result); 1649 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1650 } 1651 } 1652 } 1653 1654 e_i <<= 1; 1655 } 1656 } 1657 1658 if (used_bits) 1659 { 1660 // Select window index 1011 from "101110" 1661 uint32_t window_index = e_bits >> (trailing_zeroes + 1); 1662 1663 if (seen_bits) 1664 { 1665 uint32_t ctr = used_bits - trailing_zeroes; 1666 while (ctr--) 1667 { 1668 // result = result^2 1669 Square(mod_limbs, temp, result); 1670 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1671 } 1672 1673 // result = result * window[index] 1674 Multiply(mod_limbs, temp, result, &window[window_index * mod_limbs]); 1675 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1676 } 1677 else 1678 { 1679 // result = window[index] 1680 Set(result, mod_limbs, &window[window_index * mod_limbs]); 1681 //seen_bits = true; 1682 } 1683 1684 while (trailing_zeroes--) 1685 { 1686 // result = result^2 1687 Square(mod_limbs, temp, result); 1688 MonReduce(mod_limbs, temp, modulus, mod_inv, result); 1689 } 1690 1691 //e_bits = 0; 1692 } 1693 1694 RakNet::OP_DELETE_ARRAY(window, __FILE__, __LINE__); 1695 } 1696 1697 // Computes: result = base ^ exponent (Mod modulus) 1698 // Using Montgomery multiplication with simple squaring method ExpMod(const uint32_t * base,int base_limbs,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1699 void ExpMod( 1700 const uint32_t *base, // Base for exponentiation, buffer size = base_limbs 1701 int base_limbs, // Number of limbs in base 1702 const uint32_t *exponent,// Exponent, buffer size = exponent_limbs 1703 
int exponent_limbs, // Number of limbs in exponent 1704 const uint32_t *modulus, // Modulus, buffer size = mod_limbs 1705 int mod_limbs, // Number of limbs in modulus 1706 uint32_t mod_inv, // MonReducePrecomp() return 1707 uint32_t *result) // Result, buffer size = mod_limbs 1708 { 1709 uint32_t *mon_base = (uint32_t*)alloca(mod_limbs*4); 1710 MonInputResidue(base, base_limbs, modulus, mod_limbs, mon_base); 1711 1712 MonExpMod(mon_base, exponent, exponent_limbs, modulus, mod_limbs, mod_inv, result); 1713 1714 MonFinish(mod_limbs, result, modulus, mod_inv); 1715 } 1716 1717 // returns b ^ e (Mod m) ExpMod(uint32_t b,uint32_t e,uint32_t m)1718 uint32_t ExpMod(uint32_t b, uint32_t e, uint32_t m) 1719 { 1720 // validate arguments 1721 if (b == 0 || m <= 1) return 0; 1722 if (e == 0) return 1; 1723 1724 // find high bit of exponent 1725 uint32_t mask = 0x80000000; 1726 while ((e & mask) == 0) mask >>= 1; 1727 1728 // seen 1 set bit, so result = base so far 1729 uint32_t r = b; 1730 1731 while (mask >>= 1) 1732 { 1733 // VS.NET does a poor job recognizing that the division 1734 // is just an IDIV with a 32-bit dividend (not 64-bit) :-( 1735 1736 // r = r^2 (mod m) 1737 r = (uint32_t)(((uint64_t)r * r) % m); 1738 1739 // if exponent bit is set, r = r*b (mod m) 1740 if (e & mask) r = (uint32_t)(((uint64_t)r * b) % m); 1741 } 1742 1743 return r; 1744 } 1745 1746 // Rabin-Miller method for finding a strong pseudo-prime 1747 // Preconditions: High bit and low bit of n = 1 RabinMillerPrimeTest(const uint32_t * n,int limbs,uint32_t k)1748 bool RabinMillerPrimeTest( 1749 const uint32_t *n, // Number to check for primality 1750 int limbs, // Number of limbs in n 1751 uint32_t k) // Confidence level (40 is pretty good) 1752 { 1753 // n1 = n - 1 1754 uint32_t *n1 = (uint32_t *)alloca(limbs*4); 1755 Set(n1, limbs, n); 1756 Subtract32(n1, limbs, 1); 1757 1758 // d = n1 1759 uint32_t *d = (uint32_t *)alloca(limbs*4); 1760 Set(d, limbs, n1); 1761 1762 // remove factors of two from d 
1763 while (!(d[0] & 1)) 1764 ShiftRight(limbs, d, d, 1); 1765 1766 uint32_t *a = (uint32_t *)alloca(limbs*4); 1767 uint32_t *t = (uint32_t *)alloca(limbs*4); 1768 uint32_t *p = (uint32_t *)alloca((limbs*2)*4); 1769 uint32_t n_inv = MonReducePrecomp(n[0]); 1770 1771 // iterate k times 1772 while (k--) 1773 { 1774 //do Random::ref()->generate(a, limbs*4); 1775 do fillBufferMT(a,limbs*4); 1776 while (GreaterOrEqual(a, limbs, n, limbs)); 1777 1778 // a = a ^ d (Mod n) 1779 ExpMod(a, limbs, d, limbs, n, limbs, n_inv, a); 1780 1781 Set(t, limbs, d); 1782 while (!Equal(limbs, t, n1) && 1783 !Equal32(a, limbs, 1) && 1784 !Equal(limbs, a, n1)) 1785 { 1786 // a = a^2 (Mod n), non-critical path 1787 Square(limbs, p, a); 1788 Modulus(p, limbs*2, n, limbs, a); 1789 1790 // t <<= 1 1791 ShiftLeft(limbs, t, t, 1); 1792 } 1793 1794 if (!Equal(limbs, a, n1) && !(t[0] & 1)) return false; 1795 } 1796 1797 return true; 1798 } 1799 1800 // Generate a strong pseudo-prime using the Rabin-Miller primality test GenerateStrongPseudoPrime(uint32_t * n,int limbs)1801 void GenerateStrongPseudoPrime( 1802 uint32_t *n, // Output prime 1803 int limbs) // Number of limbs in n 1804 { 1805 do { 1806 fillBufferMT(n,limbs*4); 1807 n[limbs-1] |= 0x80000000; 1808 n[0] |= 1; 1809 } while (!RabinMillerPrimeTest(n, limbs, 40)); // 40 iterations 1810 } 1811 } 1812 1813 1814