1 /* 2 This program is free software; you can redistribute it and/or modify 3 it under the terms of the GNU General Public License as published by 4 the Free Software Foundation; either version 2 of the License, or 5 (at your option) any later version. 6 7 This program is distributed in the hope that it will be useful, 8 but WITHOUT ANY WARRANTY; without even the implied warranty of 9 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 GNU General Public License for more details. 11 12 You should have received a copy of the GNU General Public License along 13 with this program; if not, write to the Free Software Foundation, Inc., 14 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 15 16 xrdp: A Remote Desktop Protocol server. 17 Copyright (C) Jay Sorg 2004-2005 18 19 ssl calls 20 21 */ 22 23 #include "precomp.h" 24 25 /*****************************************************************************/ 26 static void * g_malloc(int size, int zero) 27 { 28 void * p; 29 30 p = CryptMemAlloc(size); 31 if (zero) 32 { 33 memset(p, 0, size); 34 } 35 return p; 36 } 37 38 /*****************************************************************************/ 39 static void g_free(void * in) 40 { 41 CryptMemFree(in); 42 } 43 44 struct rc4_state 45 { 46 HCRYPTPROV hCryptProv; 47 HCRYPTKEY hKey; 48 }; 49 /*****************************************************************************/ 50 void* 51 rdssl_rc4_info_create(void) 52 { 53 struct rc4_state *info = g_malloc(sizeof(struct rc4_state), 1); 54 BOOL ret; 55 DWORD dwErr; 56 if (!info) 57 { 58 error("rdssl_rc4_info_create no memory\n"); 59 return NULL; 60 } 61 ret = CryptAcquireContext(&info->hCryptProv, 62 L"MSTSC", 63 MS_ENHANCED_PROV, 64 PROV_RSA_FULL, 65 0); 66 if (!ret) 67 { 68 dwErr = GetLastError(); 69 if (dwErr == NTE_BAD_KEYSET) 70 { 71 ret = CryptAcquireContext(&info->hCryptProv, 72 L"MSTSC", 73 MS_ENHANCED_PROV, 74 PROV_RSA_FULL, 75 CRYPT_NEWKEYSET); 76 } 77 } 78 if (!ret) 79 { 80 dwErr = GetLastError(); 81 error("CryptAcquireContext failed with %lx\n", dwErr); 82 g_free(info); 83 return NULL; 84 } 85 return info; 86 } 87 88 /*****************************************************************************/ 89 void 90 rdssl_rc4_info_delete(void* rc4_info) 91 { 92 struct rc4_state *info = rc4_info; 93 BOOL ret = TRUE; 94 DWORD dwErr; 95 if (!info) 96 { 97 //error("rdssl_rc4_info_delete rc4_info is null\n"); 98 return; 99 } 100 if (info->hKey) 101 { 102 ret = CryptDestroyKey(info->hKey); 103 if (!ret) 104 { 105 dwErr = GetLastError(); 106 error("CryptDestroyKey failed with %lx\n", dwErr); 107 } 108 } 109 if (info->hCryptProv) 110 { 111 ret = CryptReleaseContext(info->hCryptProv, 0); 112 if (!ret) 113 { 114 dwErr = GetLastError(); 115 error("CryptReleaseContext failed with %lx\n", dwErr); 116 } 117 } 118 g_free(rc4_info); 119 } 120 121 /*****************************************************************************/ 122 void 123 rdssl_rc4_set_key(void* rc4_info, char* key, int len) 124 { 125 struct rc4_state *info = rc4_info; 126 BOOL ret; 127 DWORD dwErr; 128 BYTE * blob; 129 PUBLICKEYSTRUC *desc; 130 DWORD * keySize; 131 BYTE * keyBuf; 132 if (!rc4_info || !key || !len || !info->hCryptProv) 133 { 134 error("rdssl_rc4_set_key %p %p %d\n", rc4_info, key, len); 135 return; 136 } 137 blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len, 0); 138 if (!blob) 139 { 140 error("rdssl_rc4_set_key no memory\n"); 141 return; 142 } 143 desc = (PUBLICKEYSTRUC *)blob; 144 keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC)); 145 keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD); 146 desc->aiKeyAlg = CALG_RC4; 147 desc->bType = PLAINTEXTKEYBLOB; 148 desc->bVersion = CUR_BLOB_VERSION; 149 desc->reserved = 0; 150 *keySize = len; 151 memcpy(keyBuf, key, len); 152 if (info->hKey) 153 { 154 CryptDestroyKey(info->hKey); 155 info->hKey = 0; 156 } 157 ret = CryptImportKey(info->hCryptProv, 158 blob, 159 sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len, 160 0, 161 CRYPT_EXPORTABLE, 162 &info->hKey); 163 g_free(blob); 164 if (!ret) 165 { 166 dwErr = GetLastError(); 167 error("CryptImportKey failed with %lx\n", dwErr); 168 } 169 } 170 171 /*****************************************************************************/ 172 void 173 rdssl_rc4_crypt(void* rc4_info, char* in_data, char* out_data, int len) 174 { 175 struct rc4_state *info = rc4_info; 176 BOOL ret; 177 DWORD dwErr; 178 BYTE * intermediate_data; 179 DWORD dwLen = len; 180 if (!rc4_info || !in_data || !out_data || !len || !info->hKey) 181 { 182 error("rdssl_rc4_crypt %p %p %p %d\n", rc4_info, in_data, out_data, len); 183 return; 184 } 185 intermediate_data = g_malloc(len, 0); 186 if (!intermediate_data) 187 { 188 error("rdssl_rc4_set_key no memory\n"); 189 return; 190 } 191 memcpy(intermediate_data, in_data, len); 192 ret = CryptEncrypt(info->hKey, 193 0, 194 FALSE, 195 0, 196 intermediate_data, 197 &dwLen, 198 dwLen); 199 if (!ret) 200 { 201 dwErr = GetLastError(); 202 g_free(intermediate_data); 203 error("CryptEncrypt failed with %lx\n", dwErr); 204 return; 205 } 206 memcpy(out_data, intermediate_data, len); 207 g_free(intermediate_data); 208 } 209 210 struct hash_context 211 { 212 HCRYPTPROV hCryptProv; 213 HCRYPTKEY hHash; 214 }; 215 216 /*****************************************************************************/ 217 void* 218 rdssl_hash_info_create(ALG_ID id) 219 { 220 struct hash_context *info = g_malloc(sizeof(struct hash_context), 1); 221 BOOL ret; 222 DWORD dwErr; 223 if (!info) 224 { 225 error("rdssl_hash_info_create %d no memory\n", id); 226 return NULL; 227 } 228 ret = CryptAcquireContext(&info->hCryptProv, 229 L"MSTSC", 230 MS_ENHANCED_PROV, 231 PROV_RSA_FULL, 232 0); 233 if (!ret) 234 { 235 dwErr = GetLastError(); 236 if (dwErr == NTE_BAD_KEYSET) 237 { 238 ret = CryptAcquireContext(&info->hCryptProv, 239 L"MSTSC", 240 MS_ENHANCED_PROV, 241 PROV_RSA_FULL, 242 CRYPT_NEWKEYSET); 243 } 244 } 245 if (!ret) 246 { 247 dwErr = GetLastError(); 248 g_free(info); 249 error("CryptAcquireContext failed with %lx\n", dwErr); 250 return NULL; 251 } 252 ret = CryptCreateHash(info->hCryptProv, 253 id, 254 0, 255 0, 256 &info->hHash); 257 if (!ret) 258 { 259 dwErr = GetLastError(); 260 CryptReleaseContext(info->hCryptProv, 0); 261 g_free(info); 262 error("CryptCreateHash failed with %lx\n", dwErr); 263 return NULL; 264 } 265 return info; 266 } 267 268 /*****************************************************************************/ 269 void 270 rdssl_hash_info_delete(void* hash_info) 271 { 272 struct hash_context *info = hash_info; 273 if (!info) 274 { 275 //error("ssl_hash_info_delete hash_info is null\n"); 276 return; 277 } 278 if (info->hHash) 279 { 280 CryptDestroyHash(info->hHash); 281 } 282 if (info->hCryptProv) 283 { 284 CryptReleaseContext(info->hCryptProv, 0); 285 } 286 g_free(hash_info); 287 } 288 289 /*****************************************************************************/ 290 void 291 rdssl_hash_clear(void* hash_info, ALG_ID id) 292 { 293 struct hash_context *info = hash_info; 294 BOOL ret; 295 DWORD dwErr; 296 if (!info || !info->hHash || !info->hCryptProv) 297 { 298 error("rdssl_hash_clear %p\n", info); 299 return; 300 } 301 ret = CryptDestroyHash(info->hHash); 302 if (!ret) 303 { 304 dwErr = GetLastError(); 305 error("CryptDestroyHash failed with %lx\n", dwErr); 306 return; 307 } 308 ret = CryptCreateHash(info->hCryptProv, 309 id, 310 0, 311 0, 312 &info->hHash); 313 if (!ret) 314 { 315 dwErr = GetLastError(); 316 error("CryptCreateHash failed with %lx\n", dwErr); 317 } 318 } 319 320 void 321 rdssl_hash_transform(void* hash_info, char* data, int len) 322 { 323 struct hash_context *info = hash_info; 324 BOOL ret; 325 DWORD dwErr; 326 if (!info || !info->hHash || !info->hCryptProv || !data || !len) 327 { 328 error("rdssl_hash_transform %p %p %d\n", hash_info, data, len); 329 return; 330 } 331 ret = CryptHashData(info->hHash, 332 (BYTE *)data, 333 len, 334 0); 335 if (!ret) 336 { 337 dwErr = GetLastError(); 338 error("CryptHashData failed with %lx\n", dwErr); 339 } 340 } 341 342 /*****************************************************************************/ 343 void 344 rdssl_hash_complete(void* hash_info, char* data) 345 { 346 struct hash_context *info = hash_info; 347 BOOL ret; 348 DWORD dwErr, dwDataLen; 349 if (!info || !info->hHash || !info->hCryptProv || !data) 350 { 351 error("rdssl_hash_complete %p %p\n", hash_info, data); 352 return; 353 } 354 ret = CryptGetHashParam(info->hHash, 355 HP_HASHVAL, 356 NULL, 357 &dwDataLen, 358 0); 359 if (!ret) 360 { 361 dwErr = GetLastError(); 362 error("CryptGetHashParam failed with %lx\n", dwErr); 363 return; 364 } 365 ret = CryptGetHashParam(info->hHash, 366 HP_HASHVAL, 367 (BYTE *)data, 368 &dwDataLen, 369 0); 370 if (!ret) 371 { 372 dwErr = GetLastError(); 373 error("CryptGetHashParam failed with %lx\n", dwErr); 374 } 375 } 376 377 /*****************************************************************************/ 378 void* 379 rdssl_sha1_info_create(void) 380 { 381 return rdssl_hash_info_create(CALG_SHA1); 382 } 383 384 /*****************************************************************************/ 385 void 386 rdssl_sha1_info_delete(void* sha1_info) 387 { 388 rdssl_hash_info_delete(sha1_info); 389 } 390 391 /*****************************************************************************/ 392 void 393 rdssl_sha1_clear(void* sha1_info) 394 { 395 rdssl_hash_clear(sha1_info, CALG_SHA1); 396 } 397 398 /*****************************************************************************/ 399 void 400 rdssl_sha1_transform(void* sha1_info, char* data, int len) 401 { 402 rdssl_hash_transform(sha1_info, data, len); 403 } 404 405 /*****************************************************************************/ 406 void 407 rdssl_sha1_complete(void* sha1_info, char* data) 408 { 409 rdssl_hash_complete(sha1_info, data); 410 } 411 412 /*****************************************************************************/ 413 void* 414 rdssl_md5_info_create(void) 415 { 416 return rdssl_hash_info_create(CALG_MD5); 417 } 418 419 /*****************************************************************************/ 420 void 421 rdssl_md5_info_delete(void* md5_info) 422 { 423 rdssl_hash_info_delete(md5_info); 424 } 425 426 /*****************************************************************************/ 427 void 428 rdssl_md5_clear(void* md5_info) 429 { 430 rdssl_hash_clear(md5_info, CALG_MD5); 431 } 432 433 /*****************************************************************************/ 434 void 435 rdssl_md5_transform(void* md5_info, char* data, int len) 436 { 437 rdssl_hash_transform(md5_info, data, len); 438 } 439 440 /*****************************************************************************/ 441 void 442 rdssl_md5_complete(void* md5_info, char* data) 443 { 444 rdssl_hash_complete(md5_info, data); 445 } 446 447 /*****************************************************************************/ 448 void 449 rdssl_hmac_md5(char* key, int keylen, char* data, int len, char* output) 450 { 451 HCRYPTPROV hCryptProv; 452 HCRYPTKEY hKey; 453 HCRYPTKEY hHash; 454 BOOL ret; 455 DWORD dwErr, dwDataLen; 456 HMAC_INFO info; 457 BYTE * blob; 458 PUBLICKEYSTRUC *desc; 459 DWORD * keySize; 460 BYTE * keyBuf; 461 BYTE sum[16]; 462 463 if (!key || !keylen || !data || !len ||!output) 464 { 465 error("rdssl_hmac_md5 %p %d %p %d %p\n", key, keylen, data, len, output); 466 return; 467 } 468 blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen, 0); 469 desc = (PUBLICKEYSTRUC *)blob; 470 keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC)); 471 keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD); 472 if (!blob) 473 { 474 error("rdssl_hmac_md5 %d no memory\n"); 475 return; 476 } 477 ret = CryptAcquireContext(&hCryptProv, 478 L"MSTSC", 479 MS_ENHANCED_PROV, 480 PROV_RSA_FULL, 481 0); 482 if (!ret) 483 { 484 dwErr = GetLastError(); 485 if (dwErr == NTE_BAD_KEYSET) 486 { 487 ret = CryptAcquireContext(&hCryptProv, 488 L"MSTSC", 489 MS_ENHANCED_PROV, 490 PROV_RSA_FULL, 491 CRYPT_NEWKEYSET); 492 } 493 } 494 if (!ret) 495 { 496 dwErr = GetLastError(); 497 g_free(blob); 498 error("CryptAcquireContext failed with %lx\n", dwErr); 499 return; 500 } 501 desc->aiKeyAlg = CALG_RC4; 502 desc->bType = PLAINTEXTKEYBLOB; 503 desc->bVersion = CUR_BLOB_VERSION; 504 desc->reserved = 0; 505 if (keylen > 64) 506 { 507 HCRYPTKEY hHash; 508 ret = CryptCreateHash(hCryptProv, 509 CALG_MD5, 510 0, 511 0, 512 &hHash); 513 if (!ret) 514 { 515 dwErr = GetLastError(); 516 g_free(blob); 517 error("CryptCreateHash failed with %lx\n", dwErr); 518 return; 519 } 520 ret = CryptHashData(hHash, 521 (BYTE *)key, 522 keylen, 523 0); 524 if (!ret) 525 { 526 dwErr = GetLastError(); 527 g_free(blob); 528 error("CryptHashData failed with %lx\n", dwErr); 529 return; 530 } 531 ret = CryptGetHashParam(hHash, 532 HP_HASHVAL, 533 NULL, 534 &dwDataLen, 535 0); 536 if (!ret) 537 { 538 dwErr = GetLastError(); 539 g_free(blob); 540 error("CryptGetHashParam failed with %lx\n", dwErr); 541 return; 542 } 543 ret = CryptGetHashParam(hHash, 544 HP_HASHVAL, 545 sum, 546 &dwDataLen, 547 0); 548 if (!ret) 549 { 550 dwErr = GetLastError(); 551 g_free(blob); 552 error("CryptGetHashParam failed with %lx\n", dwErr); 553 return; 554 } 555 keylen = dwDataLen; 556 key = (char *)sum; 557 } 558 *keySize = keylen; 559 memcpy(keyBuf, key, keylen); 560 ret = CryptImportKey(hCryptProv, 561 blob, 562 sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen, 563 0, 564 CRYPT_EXPORTABLE, 565 &hKey); 566 g_free(blob); 567 if (!ret) 568 { 569 dwErr = GetLastError(); 570 error("CryptImportKey failed with %lx\n", dwErr); 571 return; 572 } 573 ret = CryptCreateHash(hCryptProv, 574 CALG_HMAC, 575 hKey, 576 0, 577 &hHash); 578 if (!ret) 579 { 580 dwErr = GetLastError(); 581 error("CryptCreateHash failed with %lx\n", dwErr); 582 return; 583 } 584 info.HashAlgid = CALG_MD5; 585 info.cbInnerString = 0; 586 info.cbOuterString = 0; 587 ret = CryptSetHashParam(hHash, 588 HP_HMAC_INFO, 589 (BYTE *)&info, 590 0); 591 if (!ret) 592 { 593 dwErr = GetLastError(); 594 error("CryptSetHashParam failed with %lx\n", dwErr); 595 return; 596 } 597 ret = CryptHashData(hHash, 598 (BYTE *)data, 599 len, 600 0); 601 if (!ret) 602 { 603 dwErr = GetLastError(); 604 error("CryptHashData failed with %lx\n", dwErr); 605 return; 606 } 607 ret = CryptGetHashParam(hHash, 608 HP_HASHVAL, 609 NULL, 610 &dwDataLen, 611 0); 612 if (!ret) 613 { 614 dwErr = GetLastError(); 615 error("CryptGetHashParam failed with %lx\n", dwErr); 616 return; 617 } 618 ret = CryptGetHashParam(hHash, 619 HP_HASHVAL, 620 (BYTE *)output, 621 &dwDataLen, 622 0); 623 if (!ret) 624 { 625 dwErr = GetLastError(); 626 error("CryptGetHashParam failed with %lx\n", dwErr); 627 return; 628 } 629 CryptDestroyHash(hHash); 630 ret = CryptReleaseContext(hCryptProv, 0); 631 } 632 633 /*****************************************************************************/ 634 /*****************************************************************************/ 635 /* big number stuff */ 636 /******************* SHORT COPYRIGHT NOTICE************************* 637 This source code is part of the BigDigits multiple-precision 638 arithmetic library Version 1.0 originally written by David Ireland, 639 copyright (c) 2001 D.I. Management Services Pty Limited, all rights 640 reserved. It is provided "as is" with no warranties. You may use 641 this software under the terms of the full copyright notice 642 "bigdigitsCopyright.txt" that should have been included with 643 this library. To obtain a copy send an email to 644 <code@di-mgt.com.au> or visit <www.di-mgt.com.au/crypto.html>. 645 This notice must be retained in any copy. 646 ****************** END OF COPYRIGHT NOTICE*************************/ 647 /************************* COPYRIGHT NOTICE************************* 648 This source code is part of the BigDigits multiple-precision 649 arithmetic library Version 1.0 originally written by David Ireland, 650 copyright (c) 2001 D.I. Management Services Pty Limited, all rights 651 reserved. You are permitted to use compiled versions of this code as 652 part of your own executable files and to distribute unlimited copies 653 of such executable files for any purposes including commercial ones 654 provided you keep the copyright notices intact in the source code 655 and that you ensure that the following characters remain in any 656 object or executable files you distribute: 657 658 "Contains multiple-precision arithmetic code originally written 659 by David Ireland, copyright (c) 2001 by D.I. Management Services 660 Pty Limited <www.di-mgt.com.au>, and is used with permission." 661 662 David Ireland and DI Management Services Pty Limited make no 663 representations concerning either the merchantability of this 664 software or the suitability of this software for any particular 665 purpose. It is provided "as is" without express or implied warranty 666 of any kind. 667 668 Please forward any comments and bug reports to <code@di-mgt.com.au>. 669 The latest version of the source code can be downloaded from 670 www.di-mgt.com.au/crypto.html. 671 ****************** END OF COPYRIGHT NOTICE*************************/ 672 673 typedef unsigned int DIGIT_T; 674 #define HIBITMASK 0x80000000 675 #define MAX_DIG_LEN 51 676 #define MAX_DIGIT 0xffffffff 677 #define BITS_PER_DIGIT 32 678 #define MAX_HALF_DIGIT 0xffff 679 #define B_J (MAX_HALF_DIGIT + 1) 680 #define LOHALF(x) ((DIGIT_T)((x) & 0xffff)) 681 #define HIHALF(x) ((DIGIT_T)((x) >> 16 & 0xffff)) 682 #define TOHIGH(x) ((DIGIT_T)((x) << 16)) 683 684 #define mpNEXTBITMASK(mask, n) \ 685 { \ 686 if (mask == 1) \ 687 { \ 688 mask = HIBITMASK; \ 689 n--; \ 690 } \ 691 else \ 692 { \ 693 mask >>= 1; \ 694 } \ 695 } 696 697 /*****************************************************************************/ 698 static DIGIT_T 699 mpAdd(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits) 700 { 701 /* Calculates w = u + v 702 where w, u, v are multiprecision integers of ndigits each 703 Returns carry if overflow. Carry = 0 or 1. 704 705 Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A. */ 706 DIGIT_T k; 707 unsigned int j; 708 709 /* Step A1. Initialise */ 710 k = 0; 711 for (j = 0; j < ndigits; j++) 712 { 713 /* Step A2. Add digits w_j = (u_j + v_j + k) 714 Set k = 1 if carry (overflow) occurs */ 715 w[j] = u[j] + k; 716 if (w[j] < k) 717 { 718 k = 1; 719 } 720 else 721 { 722 k = 0; 723 } 724 w[j] += v[j]; 725 if (w[j] < v[j]) 726 { 727 k++; 728 } 729 } /* Step A3. Loop on j */ 730 return k; /* w_n = k */ 731 } 732 733 /*****************************************************************************/ 734 static void 735 mpSetDigit(DIGIT_T* a, DIGIT_T d, unsigned int ndigits) 736 { /* Sets a = d where d is a single digit */ 737 unsigned int i; 738 739 for (i = 1; i < ndigits; i++) 740 { 741 a[i] = 0; 742 } 743 a[0] = d; 744 } 745 746 /*****************************************************************************/ 747 static int 748 mpCompare(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits) 749 { 750 /* Returns sign of (a - b) */ 751 if (ndigits == 0) 752 { 753 return 0; 754 } 755 while (ndigits--) 756 { 757 if (a[ndigits] > b[ndigits]) 758 { 759 return 1; /* GT */ 760 } 761 if (a[ndigits] < b[ndigits]) 762 { 763 return -1; /* LT */ 764 } 765 } 766 return 0; /* EQ */ 767 } 768 769 /*****************************************************************************/ 770 static void 771 mpSetZero(DIGIT_T* a, unsigned int ndigits) 772 { /* Sets a = 0 */ 773 unsigned int i; 774 775 for (i = 0; i < ndigits; i++) 776 { 777 a[i] = 0; 778 } 779 } 780 781 /*****************************************************************************/ 782 static void 783 mpSetEqual(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits) 784 { /* Sets a = b */ 785 unsigned int i; 786 787 for (i = 0; i < ndigits; i++) 788 { 789 a[i] = b[i]; 790 } 791 } 792 793 /*****************************************************************************/ 794 static unsigned int 795 mpSizeof(DIGIT_T* a, unsigned int ndigits) 796 { /* Returns size of significant digits in a */ 797 while (ndigits--) 798 { 799 if (a[ndigits] != 0) 800 { 801 return (++ndigits); 802 } 803 } 804 return 0; 805 } 806 807 /*****************************************************************************/ 808 static DIGIT_T 809 mpShiftLeft(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits) 810 { /* Computes a = b << x */ 811 unsigned int i; 812 unsigned int y; 813 DIGIT_T mask; 814 DIGIT_T carry; 815 DIGIT_T nextcarry; 816 817 /* Check input - NB unspecified result */ 818 if (x >= BITS_PER_DIGIT) 819 { 820 return 0; 821 } 822 /* Construct mask */ 823 mask = HIBITMASK; 824 for (i = 1; i < x; i++) 825 { 826 mask = (mask >> 1) | mask; 827 } 828 if (x == 0) 829 { 830 mask = 0x0; 831 } 832 y = BITS_PER_DIGIT - x; 833 carry = 0; 834 for (i = 0; i < ndigits; i++) 835 { 836 nextcarry = (b[i] & mask) >> y; 837 a[i] = b[i] << x | carry; 838 carry = nextcarry; 839 } 840 return carry; 841 } 842 843 /*****************************************************************************/ 844 static DIGIT_T 845 mpShiftRight(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits) 846 { /* Computes a = b >> x */ 847 unsigned int i; 848 unsigned int y; 849 DIGIT_T mask; 850 DIGIT_T carry; 851 DIGIT_T nextcarry; 852 853 /* Check input - NB unspecified result */ 854 if (x >= BITS_PER_DIGIT) 855 { 856 return 0; 857 } 858 /* Construct mask */ 859 mask = 0x1; 860 for (i = 1; i < x; i++) 861 { 862 mask = (mask << 1) | mask; 863 } 864 if (x == 0) 865 { 866 mask = 0x0; 867 } 868 y = BITS_PER_DIGIT - x; 869 carry = 0; 870 i = ndigits; 871 while (i--) 872 { 873 nextcarry = (b[i] & mask) << y; 874 a[i] = b[i] >> x | carry; 875 carry = nextcarry; 876 } 877 return carry; 878 } 879 880 /*****************************************************************************/ 881 static void 882 spMultSub(DIGIT_T* uu, DIGIT_T qhat, DIGIT_T v1, DIGIT_T v0) 883 { 884 /* Compute uu = uu - q(v1v0) 885 where uu = u3u2u1u0, u3 = 0 886 and u_n, v_n are all half-digits 887 even though v1, v2 are passed as full digits. */ 888 DIGIT_T p0; 889 DIGIT_T p1; 890 DIGIT_T t; 891 892 p0 = qhat * v0; 893 p1 = qhat * v1; 894 t = p0 + TOHIGH(LOHALF(p1)); 895 uu[0] -= t; 896 if (uu[0] > MAX_DIGIT - t) 897 { 898 uu[1]--; /* Borrow */ 899 } 900 uu[1] -= HIHALF(p1); 901 } 902 903 /*****************************************************************************/ 904 static int 905 spMultiply(DIGIT_T* p, DIGIT_T x, DIGIT_T y) 906 { /* Computes p = x * y */ 907 /* Ref: Arbitrary Precision Computation 908 http://numbers.computation.free.fr/Constants/constants.html 909 910 high p1 p0 low 911 +--------+--------+--------+--------+ 912 | x1*y1 | x0*y0 | 913 +--------+--------+--------+--------+ 914 +-+--------+--------+ 915 |1| (x0*y1 + x1*y1) | 916 +-+--------+--------+ 917 ^carry from adding (x0*y1+x1*y1) together 918 +-+ 919 |1|< carry from adding LOHALF t 920 +-+ to high half of p0 */ 921 DIGIT_T x0; 922 DIGIT_T y0; 923 DIGIT_T x1; 924 DIGIT_T y1; 925 DIGIT_T t; 926 DIGIT_T u; 927 DIGIT_T carry; 928 929 /* Split each x,y into two halves 930 x = x0 + B * x1 931 y = y0 + B * y1 932 where B = 2^16, half the digit size 933 Product is 934 xy = x0y0 + B(x0y1 + x1y0) + B^2(x1y1) */ 935 936 x0 = LOHALF(x); 937 x1 = HIHALF(x); 938 y0 = LOHALF(y); 939 y1 = HIHALF(y); 940 941 /* Calc low part - no carry */ 942 p[0] = x0 * y0; 943 944 /* Calc middle part */ 945 t = x0 * y1; 946 u = x1 * y0; 947 t += u; 948 if (t < u) 949 { 950 carry = 1; 951 } 952 else 953 { 954 carry = 0; 955 } 956 /* This carry will go to high half of p[1] 957 + high half of t into low half of p[1] */ 958 carry = TOHIGH(carry) + HIHALF(t); 959 960 /* Add low half of t to high half of p[0] */ 961 t = TOHIGH(t); 962 p[0] += t; 963 if (p[0] < t) 964 { 965 carry++; 966 } 967 968 p[1] = x1 * y1; 969 p[1] += carry; 970 971 return 0; 972 } 973 974 /*****************************************************************************/ 975 static DIGIT_T 976 spDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, DIGIT_T v) 977 { /* Computes quotient q = u / v, remainder r = u mod v 978 where u is a double digit 979 and q, v, r are single precision digits. 980 Returns high digit of quotient (max value is 1) 981 Assumes normalised such that v1 >= b/2 982 where b is size of HALF_DIGIT 983 i.e. the most significant bit of v should be one 984 985 In terms of half-digits in Knuth notation: 986 (q2q1q0) = (u4u3u2u1u0) / (v1v0) 987 (r1r0) = (u4u3u2u1u0) mod (v1v0) 988 for m = 2, n = 2 where u4 = 0 989 q2 is either 0 or 1. 990 We set q = (q1q0) and return q2 as "overflow' */ 991 DIGIT_T qhat; 992 DIGIT_T rhat; 993 DIGIT_T t; 994 DIGIT_T v0; 995 DIGIT_T v1; 996 DIGIT_T u0; 997 DIGIT_T u1; 998 DIGIT_T u2; 999 DIGIT_T u3; 1000 DIGIT_T uu[2]; 1001 DIGIT_T q2; 1002 1003 /* Check for normalisation */ 1004 if (!(v & HIBITMASK)) 1005 { 1006 *q = *r = 0; 1007 return MAX_DIGIT; 1008 } 1009 1010 /* Split up into half-digits */ 1011 v0 = LOHALF(v); 1012 v1 = HIHALF(v); 1013 u0 = LOHALF(u[0]); 1014 u1 = HIHALF(u[0]); 1015 u2 = LOHALF(u[1]); 1016 u3 = HIHALF(u[1]); 1017 1018 /* Do three rounds of Knuth Algorithm D Vol 2 p272 */ 1019 1020 /* ROUND 1. Set j = 2 and calculate q2 */ 1021 /* Estimate qhat = (u4u3)/v1 = 0 or 1 1022 then set (u4u3u2) -= qhat(v1v0) 1023 where u4 = 0. */ 1024 qhat = u3 / v1; 1025 if (qhat > 0) 1026 { 1027 rhat = u3 - qhat * v1; 1028 t = TOHIGH(rhat) | u2; 1029 if (qhat * v0 > t) 1030 { 1031 qhat--; 1032 } 1033 } 1034 uu[1] = 0; /* (u4) */ 1035 uu[0] = u[1]; /* (u3u2) */ 1036 if (qhat > 0) 1037 { 1038 /* (u4u3u2) -= qhat(v1v0) where u4 = 0 */ 1039 spMultSub(uu, qhat, v1, v0); 1040 if (HIHALF(uu[1]) != 0) 1041 { /* Add back */ 1042 qhat--; 1043 uu[0] += v; 1044 uu[1] = 0; 1045 } 1046 } 1047 q2 = qhat; 1048 /* ROUND 2. Set j = 1 and calculate q1 */ 1049 /* Estimate qhat = (u3u2) / v1 1050 then set (u3u2u1) -= qhat(v1v0) */ 1051 t = uu[0]; 1052 qhat = t / v1; 1053 rhat = t - qhat * v1; 1054 /* Test on v0 */ 1055 t = TOHIGH(rhat) | u1; 1056 if ((qhat == B_J) || (qhat * v0 > t)) 1057 { 1058 qhat--; 1059 rhat += v1; 1060 t = TOHIGH(rhat) | u1; 1061 if ((rhat < B_J) && (qhat * v0 > t)) 1062 { 1063 qhat--; 1064 } 1065 } 1066 /* Multiply and subtract 1067 (u3u2u1)' = (u3u2u1) - qhat(v1v0) */ 1068 uu[1] = HIHALF(uu[0]); /* (0u3) */ 1069 uu[0] = TOHIGH(LOHALF(uu[0])) | u1; /* (u2u1) */ 1070 spMultSub(uu, qhat, v1, v0); 1071 if (HIHALF(uu[1]) != 0) 1072 { /* Add back */ 1073 qhat--; 1074 uu[0] += v; 1075 uu[1] = 0; 1076 } 1077 /* q1 = qhat */ 1078 *q = TOHIGH(qhat); 1079 /* ROUND 3. Set j = 0 and calculate q0 */ 1080 /* Estimate qhat = (u2u1) / v1 1081 then set (u2u1u0) -= qhat(v1v0) */ 1082 t = uu[0]; 1083 qhat = t / v1; 1084 rhat = t - qhat * v1; 1085 /* Test on v0 */ 1086 t = TOHIGH(rhat) | u0; 1087 if ((qhat == B_J) || (qhat * v0 > t)) 1088 { 1089 qhat--; 1090 rhat += v1; 1091 t = TOHIGH(rhat) | u0; 1092 if ((rhat < B_J) && (qhat * v0 > t)) 1093 { 1094 qhat--; 1095 } 1096 } 1097 /* Multiply and subtract 1098 (u2u1u0)" = (u2u1u0)' - qhat(v1v0) */ 1099 uu[1] = HIHALF(uu[0]); /* (0u2) */ 1100 uu[0] = TOHIGH(LOHALF(uu[0])) | u0; /* (u1u0) */ 1101 spMultSub(uu, qhat, v1, v0); 1102 if (HIHALF(uu[1]) != 0) 1103 { /* Add back */ 1104 qhat--; 1105 uu[0] += v; 1106 uu[1] = 0; 1107 } 1108 /* q0 = qhat */ 1109 *q |= LOHALF(qhat); 1110 /* Remainder is in (u1u0) i.e. uu[0] */ 1111 *r = uu[0]; 1112 return q2; 1113 } 1114 1115 /*****************************************************************************/ 1116 static int 1117 QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2) 1118 { /* Returns true if Qhat is too big 1119 i.e. if (Qhat * Vn-2) > (b.Rhat + Uj+n-2) */ 1120 DIGIT_T t[2]; 1121 1122 spMultiply(t, qhat, vn2); 1123 if (t[1] < rhat) 1124 { 1125 return 0; 1126 } 1127 else if (t[1] > rhat) 1128 { 1129 return 1; 1130 } 1131 else if (t[0] > ujn2) 1132 { 1133 return 1; 1134 } 1135 return 0; 1136 } 1137 1138 /*****************************************************************************/ 1139 static DIGIT_T 1140 mpShortDiv(DIGIT_T* q, DIGIT_T* u, DIGIT_T v, unsigned int ndigits) 1141 { 1142 /* Calculates quotient q = u div v 1143 Returns remainder r = u mod v 1144 where q, u are multiprecision integers of ndigits each 1145 and d, v are single precision digits. 1146 1147 Makes no assumptions about normalisation. 1148 1149 Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625 */ 1150 unsigned int j; 1151 unsigned int shift; 1152 DIGIT_T t[2]; 1153 DIGIT_T r; 1154 DIGIT_T bitmask; 1155 DIGIT_T overflow; 1156 DIGIT_T* uu; 1157 1158 if (ndigits == 0) 1159 { 1160 return 0; 1161 } 1162 if (v == 0) 1163 { 1164 return 0; /* Divide by zero error */ 1165 } 1166 /* Normalise first */ 1167 /* Requires high bit of V 1168 to be set, so find most signif. bit then shift left, 1169 i.e. d = 2^shift, u' = u * d, v' = v * d. */ 1170 bitmask = HIBITMASK; 1171 for (shift = 0; shift < BITS_PER_DIGIT; shift++) 1172 { 1173 if (v & bitmask) 1174 { 1175 break; 1176 } 1177 bitmask >>= 1; 1178 } 1179 v <<= shift; 1180 overflow = mpShiftLeft(q, u, shift, ndigits); 1181 uu = q; 1182 /* Step S1 - modified for extra digit. */ 1183 r = overflow; /* New digit Un */ 1184 j = ndigits; 1185 while (j--) 1186 { 1187 /* Step S2. */ 1188 t[1] = r; 1189 t[0] = uu[j]; 1190 overflow = spDivide(&q[j], &r, t, v); 1191 } 1192 /* Unnormalise */ 1193 r >>= shift; 1194 return r; 1195 } 1196 1197 /*****************************************************************************/ 1198 static DIGIT_T 1199 mpMultSub(DIGIT_T wn, DIGIT_T* w, DIGIT_T* v, DIGIT_T q, unsigned int n) 1200 { /* Compute w = w - qv 1201 where w = (WnW[n-1]...W[0]) 1202 return modified Wn. */ 1203 DIGIT_T k; 1204 DIGIT_T t[2]; 1205 unsigned int i; 1206 1207 if (q == 0) /* No change */ 1208 { 1209 return wn; 1210 } 1211 k = 0; 1212 for (i = 0; i < n; i++) 1213 { 1214 spMultiply(t, q, v[i]); 1215 w[i] -= k; 1216 if (w[i] > MAX_DIGIT - k) 1217 { 1218 k = 1; 1219 } 1220 else 1221 { 1222 k = 0; 1223 } 1224 w[i] -= t[0]; 1225 if (w[i] > MAX_DIGIT - t[0]) 1226 { 1227 k++; 1228 } 1229 k += t[1]; 1230 } 1231 /* Cope with Wn not stored in array w[0..n-1] */ 1232 wn -= k; 1233 return wn; 1234 } 1235 1236 /*****************************************************************************/ 1237 static int 1238 mpDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, unsigned int udigits, 1239 DIGIT_T* v, unsigned int vdigits) 1240 { /* Computes quotient q = u / v and remainder r = u mod v 1241 where q, r, u are multiple precision digits 1242 all of udigits and the divisor v is vdigits. 1243 1244 Ref: Knuth Vol 2 Ch 4.3.1 p 272 Algorithm D. 1245 1246 Do without extra storage space, i.e. use r[] for 1247 normalised u[], unnormalise v[] at end, and cope with 1248 extra digit Uj+n added to u after normalisation. 1249 1250 WARNING: this trashes q and r first, so cannot do 1251 u = u / v or v = u mod v. */ 1252 unsigned int shift; 1253 int n; 1254 int m; 1255 int j; 1256 int qhatOK; 1257 int cmp; 1258 DIGIT_T bitmask; 1259 DIGIT_T overflow; 1260 DIGIT_T qhat; 1261 DIGIT_T rhat; 1262 DIGIT_T t[2]; 1263 DIGIT_T* uu; 1264 DIGIT_T* ww; 1265 1266 /* Clear q and r */ 1267 mpSetZero(q, udigits); 1268 mpSetZero(r, udigits); 1269 /* Work out exact sizes of u and v */ 1270 n = (int)mpSizeof(v, vdigits); 1271 m = (int)mpSizeof(u, udigits); 1272 m -= n; 1273 /* Catch special cases */ 1274 if (n == 0) 1275 { 1276 return -1; /* Error: divide by zero */ 1277 } 1278 if (n == 1) 1279 { /* Use short division instead */ 1280 r[0] = mpShortDiv(q, u, v[0], udigits); 1281 return 0; 1282 } 1283 if (m < 0) 1284 { /* v > u, so just set q = 0 and r = u */ 1285 mpSetEqual(r, u, udigits); 1286 return 0; 1287 } 1288 if (m == 0) 1289 { /* u and v are the same length */ 1290 cmp = mpCompare(u, v, (unsigned int)n); 1291 if (cmp < 0) 1292 { /* v > u, as above */ 1293 mpSetEqual(r, u, udigits); 1294 return 0; 1295 } 1296 else if (cmp == 0) 1297 { /* v == u, so set q = 1 and r = 0 */ 1298 mpSetDigit(q, 1, udigits); 1299 return 0; 1300 } 1301 } 1302 /* In Knuth notation, we have: 1303 Given 1304 u = (Um+n-1 ... U1U0) 1305 v = (Vn-1 ... V1V0) 1306 Compute 1307 q = u/v = (QmQm-1 ... Q0) 1308 r = u mod v = (Rn-1 ... R1R0) */ 1309 /* Step D1. Normalise */ 1310 /* Requires high bit of Vn-1 1311 to be set, so find most signif. bit then shift left, 1312 i.e. d = 2^shift, u' = u * d, v' = v * d. */ 1313 bitmask = HIBITMASK; 1314 for (shift = 0; shift < BITS_PER_DIGIT; shift++) 1315 { 1316 if (v[n - 1] & bitmask) 1317 { 1318 break; 1319 } 1320 bitmask >>= 1; 1321 } 1322 /* Normalise v in situ - NB only shift non-zero digits */ 1323 overflow = mpShiftLeft(v, v, shift, n); 1324 /* Copy normalised dividend u*d into r */ 1325 overflow = mpShiftLeft(r, u, shift, n + m); 1326 uu = r; /* Use ptr to keep notation constant */ 1327 t[0] = overflow; /* New digit Um+n */ 1328 /* Step D2. Initialise j. Set j = m */ 1329 for (j = m; j >= 0; j--) 1330 { 1331 /* Step D3. Calculate Qhat = (b.Uj+n + Uj+n-1)/Vn-1 */ 1332 qhatOK = 0; 1333 t[1] = t[0]; /* This is Uj+n */ 1334 t[0] = uu[j+n-1]; 1335 overflow = spDivide(&qhat, &rhat, t, v[n - 1]); 1336 /* Test Qhat */ 1337 if (overflow) 1338 { /* Qhat = b */ 1339 qhat = MAX_DIGIT; 1340 rhat = uu[j + n - 1]; 1341 rhat += v[n - 1]; 1342 if (rhat < v[n - 1]) /* Overflow */ 1343 { 1344 qhatOK = 1; 1345 } 1346 } 1347 if (!qhatOK && QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2])) 1348 { /* Qhat.Vn-2 > b.Rhat + Uj+n-2 */ 1349 qhat--; 1350 rhat += v[n - 1]; 1351 if (!(rhat < v[n - 1])) 1352 { 1353 if (QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2])) 1354 { 1355 qhat--; 1356 } 1357 } 1358 } 1359 /* Step D4. Multiply and subtract */ 1360 ww = &uu[j]; 1361 overflow = mpMultSub(t[1], ww, v, qhat, (unsigned int)n); 1362 /* Step D5. Test remainder. Set Qj = Qhat */ 1363 q[j] = qhat; 1364 if (overflow) 1365 { /* Step D6. Add back if D4 was negative */ 1366 q[j]--; 1367 overflow = mpAdd(ww, ww, v, (unsigned int)n); 1368 } 1369 t[0] = uu[j + n - 1]; /* Uj+n on next round */ 1370 } /* Step D7. Loop on j */ 1371 /* Clear high digits in uu */ 1372 for (j = n; j < m+n; j++) 1373 { 1374 uu[j] = 0; 1375 } 1376 /* Step D8. Unnormalise. */ 1377 mpShiftRight(r, r, shift, n); 1378 mpShiftRight(v, v, shift, n); 1379 return 0; 1380 } 1381 1382 /*****************************************************************************/ 1383 static int 1384 mpModulo(DIGIT_T* r, DIGIT_T* u, unsigned int udigits, 1385 DIGIT_T* v, unsigned int vdigits) 1386 { 1387 /* Calculates r = u mod v 1388 where r, v are multiprecision integers of length vdigits 1389 and u is a multiprecision integer of length udigits. 1390 r may overlap v. 1391 1392 Note that r here is only vdigits long, 1393 whereas in mpDivide it is udigits long. 1394 1395 Use remainder from mpDivide function. */ 1396 /* Double-length temp variable for divide fn */ 1397 DIGIT_T qq[MAX_DIG_LEN * 2]; 1398 /* Use a double-length temp for r to allow overlap of r and v */ 1399 DIGIT_T rr[MAX_DIG_LEN * 2]; 1400 1401 /* rr[2n] = u[2n] mod v[n] */ 1402 mpDivide(qq, rr, u, udigits, v, vdigits); 1403 mpSetEqual(r, rr, vdigits); 1404 mpSetZero(rr, udigits); 1405 mpSetZero(qq, udigits); 1406 return 0; 1407 } 1408 1409 /*****************************************************************************/ 1410 static int 1411 mpMultiply(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits) 1412 { 1413 /* Computes product w = u * v 1414 where u, v are multiprecision integers of ndigits each 1415 and w is a multiprecision integer of 2*ndigits 1416 Ref: Knuth Vol 2 Ch 4.3.1 p 268 Algorithm M. */ 1417 DIGIT_T k; 1418 DIGIT_T t[2]; 1419 unsigned int i; 1420 unsigned int j; 1421 unsigned int m; 1422 unsigned int n; 1423 1424 n = ndigits; 1425 m = n; 1426 /* Step M1. Initialise */ 1427 for (i = 0; i < 2 * m; i++) 1428 { 1429 w[i] = 0; 1430 } 1431 for (j = 0; j < n; j++) 1432 { 1433 /* Step M2. Zero multiplier? */ 1434 if (v[j] == 0) 1435 { 1436 w[j + m] = 0; 1437 } 1438 else 1439 { 1440 /* Step M3. Initialise i */ 1441 k = 0; 1442 for (i = 0; i < m; i++) 1443 { 1444 /* Step M4. Multiply and add */ 1445 /* t = u_i * v_j + w_(i+j) + k */ 1446 spMultiply(t, u[i], v[j]); 1447 t[0] += k; 1448 if (t[0] < k) 1449 { 1450 t[1]++; 1451 } 1452 t[0] += w[i + j]; 1453 if (t[0] < w[i+j]) 1454 { 1455 t[1]++; 1456 } 1457 w[i + j] = t[0]; 1458 k = t[1]; 1459 } 1460 /* Step M5. Loop on i, set w_(j+m) = k */ 1461 w[j + m] = k; 1462 } 1463 } /* Step M6. Loop on j */ 1464 return 0; 1465 } 1466 1467 /*****************************************************************************/ 1468 static int 1469 mpModMult(DIGIT_T* a, DIGIT_T* x, DIGIT_T* y, 1470 DIGIT_T* m, unsigned int ndigits) 1471 { /* Computes a = (x * y) mod m */ 1472 /* Double-length temp variable */ 1473 DIGIT_T p[MAX_DIG_LEN * 2]; 1474 1475 /* Calc p[2n] = x * y */ 1476 mpMultiply(p, x, y, ndigits); 1477 /* Then modulo */ 1478 mpModulo(a, p, ndigits * 2, m, ndigits); 1479 mpSetZero(p, ndigits * 2); 1480 return 0; 1481 } 1482 1483 /*****************************************************************************/ 1484 int 1485 rdssl_mod_exp(char* out, int out_len, char* in, int in_len, 1486 char* mod, int mod_len, char* exp, int exp_len) 1487 { 1488 /* Computes y = x ^ e mod m */ 1489 /* Binary left-to-right method */ 1490 DIGIT_T mask; 1491 DIGIT_T* e; 1492 DIGIT_T* x; 1493 DIGIT_T* y; 1494 DIGIT_T* m; 1495 unsigned int n; 1496 int max_size; 1497 char* l_out; 1498 char* l_in; 1499 char* l_mod; 1500 char* l_exp; 1501 1502 if (in_len > out_len || in_len == 0 || 1503 out_len == 0 || mod_len == 0 || exp_len == 0) 1504 { 1505 return 0; 1506 } 1507 max_size = out_len; 1508 if (in_len > max_size) 1509 { 1510 max_size = in_len; 1511 } 1512 if (mod_len > max_size) 1513 { 1514 max_size = mod_len; 1515 } 1516 if (exp_len > max_size) 1517 { 1518 max_size = exp_len; 1519 } 1520 l_out = (char*)g_malloc(max_size, 1); 1521 l_in = (char*)g_malloc(max_size, 1); 1522 l_mod = (char*)g_malloc(max_size, 1); 1523 l_exp = (char*)g_malloc(max_size, 1); 1524 memcpy(l_in, in, in_len); 1525 memcpy(l_mod, mod, mod_len); 1526 memcpy(l_exp, exp, exp_len); 1527 e = (DIGIT_T*)l_exp; 1528 x = (DIGIT_T*)l_in; 1529 y = (DIGIT_T*)l_out; 1530 m = (DIGIT_T*)l_mod; 1531 /* Find second-most significant bit in e */ 1532 n = mpSizeof(e, max_size / 4); 1533 for (mask = HIBITMASK; mask > 0; mask >>= 1) 1534 { 1535 if (e[n - 1] & mask) 1536 { 1537 break; 1538 } 1539 } 1540 mpNEXTBITMASK(mask, n); 1541 /* Set y = x */ 1542 mpSetEqual(y, x, max_size / 4); 1543 /* For bit j = k - 2 downto 0 step -1 */ 1544 while (n) 1545 { 1546 mpModMult(y, y, y, m, max_size / 4); /* Square */ 1547 if (e[n - 1] & mask) 1548 { 1549 mpModMult(y, y, x, m, max_size / 4); /* Multiply */ 1550 } 1551 /* Move to next bit */ 1552 mpNEXTBITMASK(mask, n); 1553 } 1554 memcpy(out, l_out, out_len); 1555 g_free(l_out); 1556 g_free(l_in); 1557 g_free(l_mod); 1558 g_free(l_exp); 1559 return out_len; 1560 } 1561 1562 static uint8 g_ppk_n[72] = 1563 { 1564 0x3D, 0x3A, 0x5E, 0xBD, 0x72, 0x43, 0x3E, 0xC9, 1565 0x4D, 0xBB, 0xC1, 0x1E, 0x4A, 0xBA, 0x5F, 0xCB, 1566 0x3E, 0x88, 0x20, 0x87, 0xEF, 0xF5, 0xC1, 0xE2, 1567 0xD7, 0xB7, 0x6B, 0x9A, 0xF2, 0x52, 0x45, 0x95, 1568 0xCE, 0x63, 0x65, 0x6B, 0x58, 0x3A, 0xFE, 0xEF, 1569 0x7C, 0xE7, 0xBF, 0xFE, 0x3D, 0xF6, 0x5C, 0x7D, 1570 0x6C, 0x5E, 0x06, 0x09, 0x1A, 0xF5, 0x61, 0xBB, 1571 0x20, 0x93, 0x09, 0x5F, 0x05, 0x6D, 0xEA, 0x87, 1572 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 1573 }; 1574 1575 static uint8 g_ppk_d[108] = 1576 { 1577 0x87, 0xA7, 0x19, 0x32, 0xDA, 0x11, 0x87, 0x55, 1578 0x58, 0x00, 0x16, 0x16, 0x25, 0x65, 0x68, 0xF8, 1579 0x24, 0x3E, 0xE6, 0xFA, 0xE9, 0x67, 0x49, 0x94, 1580 0xCF, 0x92, 0xCC, 0x33, 0x99, 0xE8, 0x08, 0x60, 1581 0x17, 0x9A, 0x12, 0x9F, 0x24, 0xDD, 0xB1, 0x24, 1582 0x99, 0xC7, 0x3A, 0xB8, 0x0A, 0x7B, 0x0D, 0xDD, 1583 0x35, 0x07, 0x79, 0x17, 0x0B, 0x51, 0x9B, 0xB3, 1584 0xC7, 0x10, 0x01, 0x13, 0xE7, 0x3F, 0xF3, 0x5F, 1585 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 1586 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 1587 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 1588 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 1589 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 1590 0x00, 0x00, 0x00, 0x00 1591 }; 1592 1593 int 1594 rdssl_sign_ok(char* e_data, int e_len, char* n_data, int n_len, 1595 char* sign_data, int sign_len, char* sign_data2, int sign_len2, char* testkey) 1596 { 1597 char* key; 1598 char* md5_final; 1599 void* md5; 1600 1601 if ((e_len != 4) || (n_len != 64) || (sign_len != 64) || (sign_len2 != 64)) 1602 { 1603 return 1; 1604 } 1605 md5 = rdssl_md5_info_create(); 1606 if (!md5) 1607 { 1608 return 1; 1609 } 1610 key = (char*)xmalloc(176); 1611 md5_final = (char*)xmalloc(64); 1612 // copy the test key 1613 memcpy(key, testkey, 176); 1614 // replace e and n 1615 memcpy(key + 32, e_data, 4); 1616 memcpy(key + 36, n_data, 64); 1617 rdssl_md5_clear(md5); 1618 // the first 108 bytes 1619 rdssl_md5_transform(md5, key, 108); 1620 // set the whole thing with 0xff 1621 memset(md5_final, 0xff, 64); 1622 // digest 16 bytes 1623 rdssl_md5_complete(md5, md5_final); 1624 // set non 0xff array items 1625 md5_final[16] = 0; 1626 md5_final[62] = 1; 1627 md5_final[63] = 0; 1628 // encrypt 1629 rdssl_mod_exp(sign_data, 64, md5_final, 64, (char*)g_ppk_n, 64, 1630 (char*)g_ppk_d, 64); 1631 // cleanup 1632 rdssl_md5_info_delete(md5); 1633 xfree(key); 1634 xfree(md5_final); 1635 return memcmp(sign_data, sign_data2, sign_len2); 1636 } 1637 1638 /*****************************************************************************/ 1639 PCCERT_CONTEXT rdssl_cert_read(uint8 * data, uint32 len) 1640 { 1641 PCCERT_CONTEXT res; 1642 if (!data || !len) 1643 { 1644 error("rdssl_cert_read %p %ld\n", data, len); 1645 return NULL; 1646 } 1647 res = CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, data, len); 1648 if (!res) 1649 { 1650 error("CertCreateCertificateContext call failed with %lx\n", GetLastError()); 1651 } 1652 return res; 1653 } 1654 1655 /*****************************************************************************/ 1656 void rdssl_cert_free(PCCERT_CONTEXT context) 1657 { 1658 if (context) 1659 CertFreeCertificateContext(context); 1660 } 1661 1662 /*****************************************************************************/ 1663 uint8 *rdssl_cert_to_rkey(PCCERT_CONTEXT cert, uint32 * key_len) 1664 { 1665 HCRYPTPROV hCryptProv; 1666 HCRYPTKEY hKey; 1667 BOOL ret; 1668 BYTE * rkey; 1669 DWORD dwSize, dwErr; 1670 ret = CryptAcquireContext(&hCryptProv, 1671 NULL, 1672 MS_ENHANCED_PROV, 1673 PROV_RSA_FULL, 1674 0); 1675 if (!ret) 1676 { 1677 dwErr = GetLastError(); 1678 if (dwErr == NTE_BAD_KEYSET) 1679 { 1680 ret = CryptAcquireContext(&hCryptProv, 1681 L"MSTSC", 1682 MS_ENHANCED_PROV, 1683 PROV_RSA_FULL, 1684 CRYPT_NEWKEYSET); 1685 } 1686 } 1687 if (!ret) 1688 { 1689 dwErr = GetLastError(); 1690 error("CryptAcquireContext call failed with %lx\n", dwErr); 1691 return NULL; 1692 } 1693 ret = CryptImportPublicKeyInfoEx(hCryptProv, 1694 X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, 1695 &(cert->pCertInfo->SubjectPublicKeyInfo), 1696 0, 1697 0, 1698 NULL, 1699 &hKey); 1700 if (!ret) 1701 { 1702 dwErr = GetLastError(); 1703 CryptReleaseContext(hCryptProv, 0); 1704 error("CryptImportPublicKeyInfoEx call failed with %lx\n", dwErr); 1705 return NULL; 1706 } 1707 ret = CryptExportKey(hKey, 1708 0, 1709 PUBLICKEYBLOB, 1710 0, 1711 NULL, 1712 &dwSize); 1713 if (!ret) 1714 { 1715 dwErr = GetLastError(); 1716 CryptDestroyKey(hKey); 1717 CryptReleaseContext(hCryptProv, 0); 1718 error("CryptExportKey call failed with %lx\n", dwErr); 1719 return NULL; 1720 } 1721 rkey = g_malloc(dwSize, 0); 1722 ret = CryptExportKey(hKey, 1723 0, 1724 PUBLICKEYBLOB, 1725 0, 1726 rkey, 1727 &dwSize); 1728 if (!ret) 1729 { 1730 dwErr = GetLastError(); 1731 g_free(rkey); 1732 CryptDestroyKey(hKey); 1733 CryptReleaseContext(hCryptProv, 0); 1734 error("CryptExportKey call failed with %lx\n", dwErr); 1735 return NULL; 1736 } 1737 CryptDestroyKey(hKey); 1738 CryptReleaseContext(hCryptProv, 0); 1739 return rkey; 1740 } 1741 1742 /*****************************************************************************/ 1743 RD_BOOL rdssl_certs_ok(PCCERT_CONTEXT server_cert, PCCERT_CONTEXT cacert) 1744 { 1745 /* FIXME should we check for expired certificates??? */ 1746 DWORD dwFlags = CERT_STORE_SIGNATURE_FLAG; /* CERT_STORE_TIME_VALIDITY_FLAG */ 1747 BOOL ret = CertVerifySubjectCertificateContext(server_cert, 1748 cacert, 1749 &dwFlags); 1750 if (!ret) 1751 { 1752 error("CertVerifySubjectCertificateContext call failed with %lx\n", GetLastError()); 1753 } 1754 if (dwFlags) 1755 { 1756 error("CertVerifySubjectCertificateContext check failed %lx\n", dwFlags); 1757 } 1758 return (dwFlags == 0); 1759 } 1760 1761 /*****************************************************************************/ 1762 int rdssl_rkey_get_exp_mod(uint8 * rkey, uint8 * exponent, uint32 max_exp_len, uint8 * modulus, 1763 uint32 max_mod_len) 1764 { 1765 RSAPUBKEY *desc = (RSAPUBKEY *)(rkey + sizeof(PUBLICKEYSTRUC)); 1766 if (!rkey || !exponent || !max_exp_len || !modulus || !max_mod_len) 1767 { 1768 error("rdssl_rkey_get_exp_mod %p %p %ld %p %ld\n", rkey, exponent, max_exp_len, modulus, max_mod_len); 1769 return -1; 1770 } 1771 memcpy (exponent, &desc->pubexp, max_exp_len); 1772 memcpy (modulus, rkey + sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY), max_mod_len); 1773 return 0; 1774 } 1775 1776 /*****************************************************************************/ 1777 void rdssl_rkey_free(uint8 * rkey) 1778 { 1779 if (!rkey) 1780 { 1781 error("rdssl_rkey_free rkey is null\n"); 1782 return; 1783 } 1784 g_free(rkey); 1785 } 1786