1 /*
2 This program is free software; you can redistribute it and/or modify
3 it under the terms of the GNU General Public License as published by
4 the Free Software Foundation; either version 2 of the License, or
5 (at your option) any later version.
6
7 This program is distributed in the hope that it will be useful,
8 but WITHOUT ANY WARRANTY; without even the implied warranty of
9 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 GNU General Public License for more details.
11
12 You should have received a copy of the GNU General Public License along
13 with this program; if not, write to the Free Software Foundation, Inc.,
14 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
15
16 xrdp: A Remote Desktop Protocol server.
17 Copyright (C) Jay Sorg 2004-2005
18
19 ssl calls
20
21 */
22
23 #include "precomp.h"
24
25 /*****************************************************************************/
/*
 * Allocates size bytes through CryptMemAlloc. When zero is non-zero the
 * block is cleared. Returns NULL if the allocation fails; the clearing
 * step is skipped in that case instead of writing through NULL.
 */
static void * g_malloc(int size, int zero)
{
    void * p;

    p = CryptMemAlloc(size);
    if (p != NULL && zero)
    {
        memset(p, 0, size);
    }
    return p;
}
37
38 /*****************************************************************************/
/* Returns a block obtained from g_malloc to CryptoAPI's allocator. */
static void
g_free(void * in)
{
    CryptMemFree(in);
}
43
/*
 * State behind the opaque RC4 handle returned by rdssl_rc4_info_create.
 * The struct is allocated zero-filled, so a zero handle means
 * "not acquired / no key".
 */
struct rc4_state
{
    HCRYPTPROV hCryptProv; /* provider acquired in rdssl_rc4_info_create */
    HCRYPTKEY hKey;        /* key imported by rdssl_rc4_set_key, 0 if none */
};
49 /*****************************************************************************/
50 void*
rdssl_rc4_info_create(void)51 rdssl_rc4_info_create(void)
52 {
53 struct rc4_state *info = g_malloc(sizeof(struct rc4_state), 1);
54 BOOL ret;
55 DWORD dwErr;
56 if (!info)
57 {
58 error("rdssl_rc4_info_create no memory\n");
59 return NULL;
60 }
61 ret = CryptAcquireContext(&info->hCryptProv,
62 L"MSTSC",
63 MS_ENHANCED_PROV,
64 PROV_RSA_FULL,
65 0);
66 if (!ret)
67 {
68 dwErr = GetLastError();
69 if (dwErr == NTE_BAD_KEYSET)
70 {
71 ret = CryptAcquireContext(&info->hCryptProv,
72 L"MSTSC",
73 MS_ENHANCED_PROV,
74 PROV_RSA_FULL,
75 CRYPT_NEWKEYSET);
76 }
77 }
78 if (!ret)
79 {
80 dwErr = GetLastError();
81 error("CryptAcquireContext failed with %lx\n", dwErr);
82 g_free(info);
83 return NULL;
84 }
85 return info;
86 }
87
88 /*****************************************************************************/
89 void
rdssl_rc4_info_delete(void * rc4_info)90 rdssl_rc4_info_delete(void* rc4_info)
91 {
92 struct rc4_state *info = rc4_info;
93 BOOL ret = TRUE;
94 DWORD dwErr;
95 if (!info)
96 {
97 //error("rdssl_rc4_info_delete rc4_info is null\n");
98 return;
99 }
100 if (info->hKey)
101 {
102 ret = CryptDestroyKey(info->hKey);
103 if (!ret)
104 {
105 dwErr = GetLastError();
106 error("CryptDestroyKey failed with %lx\n", dwErr);
107 }
108 }
109 if (info->hCryptProv)
110 {
111 ret = CryptReleaseContext(info->hCryptProv, 0);
112 if (!ret)
113 {
114 dwErr = GetLastError();
115 error("CryptReleaseContext failed with %lx\n", dwErr);
116 }
117 }
118 g_free(rc4_info);
119 }
120
121 /*****************************************************************************/
122 void
rdssl_rc4_set_key(void * rc4_info,char * key,int len)123 rdssl_rc4_set_key(void* rc4_info, char* key, int len)
124 {
125 struct rc4_state *info = rc4_info;
126 BOOL ret;
127 DWORD dwErr;
128 BYTE * blob;
129 PUBLICKEYSTRUC *desc;
130 DWORD * keySize;
131 BYTE * keyBuf;
132 if (!rc4_info || !key || !len || !info->hCryptProv)
133 {
134 error("rdssl_rc4_set_key %p %p %d\n", rc4_info, key, len);
135 return;
136 }
137 blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len, 0);
138 if (!blob)
139 {
140 error("rdssl_rc4_set_key no memory\n");
141 return;
142 }
143 desc = (PUBLICKEYSTRUC *)blob;
144 keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC));
145 keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD);
146 desc->aiKeyAlg = CALG_RC4;
147 desc->bType = PLAINTEXTKEYBLOB;
148 desc->bVersion = CUR_BLOB_VERSION;
149 desc->reserved = 0;
150 *keySize = len;
151 memcpy(keyBuf, key, len);
152 if (info->hKey)
153 {
154 CryptDestroyKey(info->hKey);
155 info->hKey = 0;
156 }
157 ret = CryptImportKey(info->hCryptProv,
158 blob,
159 sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len,
160 0,
161 CRYPT_EXPORTABLE,
162 &info->hKey);
163 g_free(blob);
164 if (!ret)
165 {
166 dwErr = GetLastError();
167 error("CryptImportKey failed with %lx\n", dwErr);
168 }
169 }
170
171 /*****************************************************************************/
172 void
rdssl_rc4_crypt(void * rc4_info,char * in_data,char * out_data,int len)173 rdssl_rc4_crypt(void* rc4_info, char* in_data, char* out_data, int len)
174 {
175 struct rc4_state *info = rc4_info;
176 BOOL ret;
177 DWORD dwErr;
178 BYTE * intermediate_data;
179 DWORD dwLen = len;
180 if (!rc4_info || !in_data || !out_data || !len || !info->hKey)
181 {
182 error("rdssl_rc4_crypt %p %p %p %d\n", rc4_info, in_data, out_data, len);
183 return;
184 }
185 intermediate_data = g_malloc(len, 0);
186 if (!intermediate_data)
187 {
188 error("rdssl_rc4_set_key no memory\n");
189 return;
190 }
191 memcpy(intermediate_data, in_data, len);
192 ret = CryptEncrypt(info->hKey,
193 0,
194 FALSE,
195 0,
196 intermediate_data,
197 &dwLen,
198 dwLen);
199 if (!ret)
200 {
201 dwErr = GetLastError();
202 g_free(intermediate_data);
203 error("CryptEncrypt failed with %lx\n", dwErr);
204 return;
205 }
206 memcpy(out_data, intermediate_data, len);
207 g_free(intermediate_data);
208 }
209
210 struct hash_context
211 {
212 HCRYPTPROV hCryptProv;
213 HCRYPTKEY hHash;
214 };
215
216 /*****************************************************************************/
217 void*
rdssl_hash_info_create(ALG_ID id)218 rdssl_hash_info_create(ALG_ID id)
219 {
220 struct hash_context *info = g_malloc(sizeof(struct hash_context), 1);
221 BOOL ret;
222 DWORD dwErr;
223 if (!info)
224 {
225 error("rdssl_hash_info_create %d no memory\n", id);
226 return NULL;
227 }
228 ret = CryptAcquireContext(&info->hCryptProv,
229 L"MSTSC",
230 MS_ENHANCED_PROV,
231 PROV_RSA_FULL,
232 0);
233 if (!ret)
234 {
235 dwErr = GetLastError();
236 if (dwErr == NTE_BAD_KEYSET)
237 {
238 ret = CryptAcquireContext(&info->hCryptProv,
239 L"MSTSC",
240 MS_ENHANCED_PROV,
241 PROV_RSA_FULL,
242 CRYPT_NEWKEYSET);
243 }
244 }
245 if (!ret)
246 {
247 dwErr = GetLastError();
248 g_free(info);
249 error("CryptAcquireContext failed with %lx\n", dwErr);
250 return NULL;
251 }
252 ret = CryptCreateHash(info->hCryptProv,
253 id,
254 0,
255 0,
256 &info->hHash);
257 if (!ret)
258 {
259 dwErr = GetLastError();
260 CryptReleaseContext(info->hCryptProv, 0);
261 g_free(info);
262 error("CryptCreateHash failed with %lx\n", dwErr);
263 return NULL;
264 }
265 return info;
266 }
267
268 /*****************************************************************************/
269 void
rdssl_hash_info_delete(void * hash_info)270 rdssl_hash_info_delete(void* hash_info)
271 {
272 struct hash_context *info = hash_info;
273 if (!info)
274 {
275 //error("ssl_hash_info_delete hash_info is null\n");
276 return;
277 }
278 if (info->hHash)
279 {
280 CryptDestroyHash(info->hHash);
281 }
282 if (info->hCryptProv)
283 {
284 CryptReleaseContext(info->hCryptProv, 0);
285 }
286 g_free(hash_info);
287 }
288
289 /*****************************************************************************/
290 void
rdssl_hash_clear(void * hash_info,ALG_ID id)291 rdssl_hash_clear(void* hash_info, ALG_ID id)
292 {
293 struct hash_context *info = hash_info;
294 BOOL ret;
295 DWORD dwErr;
296 if (!info || !info->hHash || !info->hCryptProv)
297 {
298 error("rdssl_hash_clear %p\n", info);
299 return;
300 }
301 ret = CryptDestroyHash(info->hHash);
302 if (!ret)
303 {
304 dwErr = GetLastError();
305 error("CryptDestroyHash failed with %lx\n", dwErr);
306 return;
307 }
308 ret = CryptCreateHash(info->hCryptProv,
309 id,
310 0,
311 0,
312 &info->hHash);
313 if (!ret)
314 {
315 dwErr = GetLastError();
316 error("CryptCreateHash failed with %lx\n", dwErr);
317 }
318 }
319
320 void
rdssl_hash_transform(void * hash_info,char * data,int len)321 rdssl_hash_transform(void* hash_info, char* data, int len)
322 {
323 struct hash_context *info = hash_info;
324 BOOL ret;
325 DWORD dwErr;
326 if (!info || !info->hHash || !info->hCryptProv || !data || !len)
327 {
328 error("rdssl_hash_transform %p %p %d\n", hash_info, data, len);
329 return;
330 }
331 ret = CryptHashData(info->hHash,
332 (BYTE *)data,
333 len,
334 0);
335 if (!ret)
336 {
337 dwErr = GetLastError();
338 error("CryptHashData failed with %lx\n", dwErr);
339 }
340 }
341
342 /*****************************************************************************/
343 void
rdssl_hash_complete(void * hash_info,char * data)344 rdssl_hash_complete(void* hash_info, char* data)
345 {
346 struct hash_context *info = hash_info;
347 BOOL ret;
348 DWORD dwErr, dwDataLen;
349 if (!info || !info->hHash || !info->hCryptProv || !data)
350 {
351 error("rdssl_hash_complete %p %p\n", hash_info, data);
352 return;
353 }
354 ret = CryptGetHashParam(info->hHash,
355 HP_HASHVAL,
356 NULL,
357 &dwDataLen,
358 0);
359 if (!ret)
360 {
361 dwErr = GetLastError();
362 error("CryptGetHashParam failed with %lx\n", dwErr);
363 return;
364 }
365 ret = CryptGetHashParam(info->hHash,
366 HP_HASHVAL,
367 (BYTE *)data,
368 &dwDataLen,
369 0);
370 if (!ret)
371 {
372 dwErr = GetLastError();
373 error("CryptGetHashParam failed with %lx\n", dwErr);
374 }
375 }
376
377 /*****************************************************************************/
378 void*
rdssl_sha1_info_create(void)379 rdssl_sha1_info_create(void)
380 {
381 return rdssl_hash_info_create(CALG_SHA1);
382 }
383
384 /*****************************************************************************/
/* Releases a SHA-1 context; forwards to rdssl_hash_info_delete. */
void rdssl_sha1_info_delete(void* sha1_info)
{
    rdssl_hash_info_delete(sha1_info);
}
390
391 /*****************************************************************************/
392 void
rdssl_sha1_clear(void * sha1_info)393 rdssl_sha1_clear(void* sha1_info)
394 {
395 rdssl_hash_clear(sha1_info, CALG_SHA1);
396 }
397
398 /*****************************************************************************/
/* Feeds data[0..len) into the running SHA-1 hash. */
void rdssl_sha1_transform(void* sha1_info, char* data, int len)
{
    rdssl_hash_transform(sha1_info, data, len);
}
404
405 /*****************************************************************************/
/* Writes the SHA-1 value of the data hashed so far into data. */
void rdssl_sha1_complete(void* sha1_info, char* data)
{
    rdssl_hash_complete(sha1_info, data);
}
411
412 /*****************************************************************************/
413 void*
rdssl_md5_info_create(void)414 rdssl_md5_info_create(void)
415 {
416 return rdssl_hash_info_create(CALG_MD5);
417 }
418
419 /*****************************************************************************/
/* Releases an MD5 context; forwards to rdssl_hash_info_delete. */
void rdssl_md5_info_delete(void* md5_info)
{
    rdssl_hash_info_delete(md5_info);
}
425
426 /*****************************************************************************/
427 void
rdssl_md5_clear(void * md5_info)428 rdssl_md5_clear(void* md5_info)
429 {
430 rdssl_hash_clear(md5_info, CALG_MD5);
431 }
432
433 /*****************************************************************************/
/* Feeds data[0..len) into the running MD5 hash. */
void rdssl_md5_transform(void* md5_info, char* data, int len)
{
    rdssl_hash_transform(md5_info, data, len);
}
439
440 /*****************************************************************************/
/* Writes the MD5 value of the data hashed so far into data. */
void rdssl_md5_complete(void* md5_info, char* data)
{
    rdssl_hash_complete(md5_info, data);
}
446
447 /*****************************************************************************/
448 void
rdssl_hmac_md5(char * key,int keylen,char * data,int len,char * output)449 rdssl_hmac_md5(char* key, int keylen, char* data, int len, char* output)
450 {
451 HCRYPTPROV hCryptProv;
452 HCRYPTKEY hKey;
453 HCRYPTKEY hHash;
454 BOOL ret;
455 DWORD dwErr, dwDataLen;
456 HMAC_INFO info;
457 BYTE * blob;
458 PUBLICKEYSTRUC *desc;
459 DWORD * keySize;
460 BYTE * keyBuf;
461 BYTE sum[16];
462
463 if (!key || !keylen || !data || !len ||!output)
464 {
465 error("rdssl_hmac_md5 %p %d %p %d %p\n", key, keylen, data, len, output);
466 return;
467 }
468 blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen, 0);
469 desc = (PUBLICKEYSTRUC *)blob;
470 keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC));
471 keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD);
472 if (!blob)
473 {
474 error("rdssl_hmac_md5 %d no memory\n");
475 return;
476 }
477 ret = CryptAcquireContext(&hCryptProv,
478 L"MSTSC",
479 MS_ENHANCED_PROV,
480 PROV_RSA_FULL,
481 0);
482 if (!ret)
483 {
484 dwErr = GetLastError();
485 if (dwErr == NTE_BAD_KEYSET)
486 {
487 ret = CryptAcquireContext(&hCryptProv,
488 L"MSTSC",
489 MS_ENHANCED_PROV,
490 PROV_RSA_FULL,
491 CRYPT_NEWKEYSET);
492 }
493 }
494 if (!ret)
495 {
496 dwErr = GetLastError();
497 g_free(blob);
498 error("CryptAcquireContext failed with %lx\n", dwErr);
499 return;
500 }
501 desc->aiKeyAlg = CALG_RC4;
502 desc->bType = PLAINTEXTKEYBLOB;
503 desc->bVersion = CUR_BLOB_VERSION;
504 desc->reserved = 0;
505 if (keylen > 64)
506 {
507 HCRYPTKEY hHash;
508 ret = CryptCreateHash(hCryptProv,
509 CALG_MD5,
510 0,
511 0,
512 &hHash);
513 if (!ret)
514 {
515 dwErr = GetLastError();
516 g_free(blob);
517 error("CryptCreateHash failed with %lx\n", dwErr);
518 return;
519 }
520 ret = CryptHashData(hHash,
521 (BYTE *)key,
522 keylen,
523 0);
524 if (!ret)
525 {
526 dwErr = GetLastError();
527 g_free(blob);
528 error("CryptHashData failed with %lx\n", dwErr);
529 return;
530 }
531 ret = CryptGetHashParam(hHash,
532 HP_HASHVAL,
533 NULL,
534 &dwDataLen,
535 0);
536 if (!ret)
537 {
538 dwErr = GetLastError();
539 g_free(blob);
540 error("CryptGetHashParam failed with %lx\n", dwErr);
541 return;
542 }
543 ret = CryptGetHashParam(hHash,
544 HP_HASHVAL,
545 sum,
546 &dwDataLen,
547 0);
548 if (!ret)
549 {
550 dwErr = GetLastError();
551 g_free(blob);
552 error("CryptGetHashParam failed with %lx\n", dwErr);
553 return;
554 }
555 keylen = dwDataLen;
556 key = (char *)sum;
557 }
558 *keySize = keylen;
559 memcpy(keyBuf, key, keylen);
560 ret = CryptImportKey(hCryptProv,
561 blob,
562 sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen,
563 0,
564 CRYPT_EXPORTABLE,
565 &hKey);
566 g_free(blob);
567 if (!ret)
568 {
569 dwErr = GetLastError();
570 error("CryptImportKey failed with %lx\n", dwErr);
571 return;
572 }
573 ret = CryptCreateHash(hCryptProv,
574 CALG_HMAC,
575 hKey,
576 0,
577 &hHash);
578 if (!ret)
579 {
580 dwErr = GetLastError();
581 error("CryptCreateHash failed with %lx\n", dwErr);
582 return;
583 }
584 info.HashAlgid = CALG_MD5;
585 info.cbInnerString = 0;
586 info.cbOuterString = 0;
587 ret = CryptSetHashParam(hHash,
588 HP_HMAC_INFO,
589 (BYTE *)&info,
590 0);
591 if (!ret)
592 {
593 dwErr = GetLastError();
594 error("CryptSetHashParam failed with %lx\n", dwErr);
595 return;
596 }
597 ret = CryptHashData(hHash,
598 (BYTE *)data,
599 len,
600 0);
601 if (!ret)
602 {
603 dwErr = GetLastError();
604 error("CryptHashData failed with %lx\n", dwErr);
605 return;
606 }
607 ret = CryptGetHashParam(hHash,
608 HP_HASHVAL,
609 NULL,
610 &dwDataLen,
611 0);
612 if (!ret)
613 {
614 dwErr = GetLastError();
615 error("CryptGetHashParam failed with %lx\n", dwErr);
616 return;
617 }
618 ret = CryptGetHashParam(hHash,
619 HP_HASHVAL,
620 (BYTE *)output,
621 &dwDataLen,
622 0);
623 if (!ret)
624 {
625 dwErr = GetLastError();
626 error("CryptGetHashParam failed with %lx\n", dwErr);
627 return;
628 }
629 CryptDestroyHash(hHash);
630 ret = CryptReleaseContext(hCryptProv, 0);
631 }
632
633 /*****************************************************************************/
634 /*****************************************************************************/
635 /* big number stuff */
636 /******************* SHORT COPYRIGHT NOTICE*************************
637 This source code is part of the BigDigits multiple-precision
638 arithmetic library Version 1.0 originally written by David Ireland,
639 copyright (c) 2001 D.I. Management Services Pty Limited, all rights
640 reserved. It is provided "as is" with no warranties. You may use
641 this software under the terms of the full copyright notice
642 "bigdigitsCopyright.txt" that should have been included with
643 this library. To obtain a copy send an email to
644 <code@di-mgt.com.au> or visit <www.di-mgt.com.au/crypto.html>.
645 This notice must be retained in any copy.
646 ****************** END OF COPYRIGHT NOTICE*************************/
647 /************************* COPYRIGHT NOTICE*************************
648 This source code is part of the BigDigits multiple-precision
649 arithmetic library Version 1.0 originally written by David Ireland,
650 copyright (c) 2001 D.I. Management Services Pty Limited, all rights
651 reserved. You are permitted to use compiled versions of this code as
652 part of your own executable files and to distribute unlimited copies
653 of such executable files for any purposes including commercial ones
654 provided you keep the copyright notices intact in the source code
655 and that you ensure that the following characters remain in any
656 object or executable files you distribute:
657
658 "Contains multiple-precision arithmetic code originally written
659 by David Ireland, copyright (c) 2001 by D.I. Management Services
660 Pty Limited <www.di-mgt.com.au>, and is used with permission."
661
662 David Ireland and DI Management Services Pty Limited make no
663 representations concerning either the merchantability of this
664 software or the suitability of this software for any particular
665 purpose. It is provided "as is" without express or implied warranty
666 of any kind.
667
668 Please forward any comments and bug reports to <code@di-mgt.com.au>.
669 The latest version of the source code can be downloaded from
670 www.di-mgt.com.au/crypto.html.
671 ****************** END OF COPYRIGHT NOTICE*************************/
672
/* One "digit" of the multiple-precision integers below. The code assumes
   it is exactly BITS_PER_DIGIT (32) bits wide. */
typedef unsigned int DIGIT_T;
#define HIBITMASK 0x80000000       /* top bit of a digit */
/* NOTE(review): presumably the maximum digit count used by the big-number
   code in this file — confirm against its users. */
#define MAX_DIG_LEN 51
#define MAX_DIGIT 0xffffffff       /* every bit of a digit set */
#define BITS_PER_DIGIT 32
#define MAX_HALF_DIGIT 0xffff      /* every bit of a 16-bit half-digit set */
#define B_J (MAX_HALF_DIGIT + 1)   /* half-digit base, 2^16 */
#define LOHALF(x) ((DIGIT_T)((x) & 0xffff))        /* low 16 bits */
#define HIHALF(x) ((DIGIT_T)((x) >> 16 & 0xffff))  /* high 16 bits */
#define TOHIGH(x) ((DIGIT_T)((x) << 16))           /* low half moved up */

/* Moves a single-bit mask one position right. When mask is 1 it wraps
   around to HIBITMASK and n is decremented.
   Wrapped in do { } while (0) so that "mpNEXTBITMASK(m, n);" is a single
   statement, safe inside an unbraced if/else. Arguments are parenthesised
   and must be modifiable lvalues. */
#define mpNEXTBITMASK(mask, n) \
    do \
    { \
        if ((mask) == 1) \
        { \
            (mask) = HIBITMASK; \
            (n)--; \
        } \
        else \
        { \
            (mask) >>= 1; \
        } \
    } while (0)
696
697 /*****************************************************************************/
698 static DIGIT_T
mpAdd(DIGIT_T * w,DIGIT_T * u,DIGIT_T * v,unsigned int ndigits)699 mpAdd(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits)
700 {
701 /* Calculates w = u + v
702 where w, u, v are multiprecision integers of ndigits each
703 Returns carry if overflow. Carry = 0 or 1.
704
705 Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A. */
706 DIGIT_T k;
707 unsigned int j;
708
709 /* Step A1. Initialise */
710 k = 0;
711 for (j = 0; j < ndigits; j++)
712 {
713 /* Step A2. Add digits w_j = (u_j + v_j + k)
714 Set k = 1 if carry (overflow) occurs */
715 w[j] = u[j] + k;
716 if (w[j] < k)
717 {
718 k = 1;
719 }
720 else
721 {
722 k = 0;
723 }
724 w[j] += v[j];
725 if (w[j] < v[j])
726 {
727 k++;
728 }
729 } /* Step A3. Loop on j */
730 return k; /* w_n = k */
731 }
732
733 /*****************************************************************************/
734 static void
mpSetDigit(DIGIT_T * a,DIGIT_T d,unsigned int ndigits)735 mpSetDigit(DIGIT_T* a, DIGIT_T d, unsigned int ndigits)
736 { /* Sets a = d where d is a single digit */
737 unsigned int i;
738
739 for (i = 1; i < ndigits; i++)
740 {
741 a[i] = 0;
742 }
743 a[0] = d;
744 }
745
746 /*****************************************************************************/
747 static int
mpCompare(DIGIT_T * a,DIGIT_T * b,unsigned int ndigits)748 mpCompare(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits)
749 {
750 /* Returns sign of (a - b) */
751 if (ndigits == 0)
752 {
753 return 0;
754 }
755 while (ndigits--)
756 {
757 if (a[ndigits] > b[ndigits])
758 {
759 return 1; /* GT */
760 }
761 if (a[ndigits] < b[ndigits])
762 {
763 return -1; /* LT */
764 }
765 }
766 return 0; /* EQ */
767 }
768
769 /*****************************************************************************/
770 static void
mpSetZero(DIGIT_T * a,unsigned int ndigits)771 mpSetZero(DIGIT_T* a, unsigned int ndigits)
772 { /* Sets a = 0 */
773 unsigned int i;
774
775 for (i = 0; i < ndigits; i++)
776 {
777 a[i] = 0;
778 }
779 }
780
781 /*****************************************************************************/
782 static void
mpSetEqual(DIGIT_T * a,DIGIT_T * b,unsigned int ndigits)783 mpSetEqual(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits)
784 { /* Sets a = b */
785 unsigned int i;
786
787 for (i = 0; i < ndigits; i++)
788 {
789 a[i] = b[i];
790 }
791 }
792
793 /*****************************************************************************/
794 static unsigned int
mpSizeof(DIGIT_T * a,unsigned int ndigits)795 mpSizeof(DIGIT_T* a, unsigned int ndigits)
796 { /* Returns size of significant digits in a */
797 while (ndigits--)
798 {
799 if (a[ndigits] != 0)
800 {
801 return (++ndigits);
802 }
803 }
804 return 0;
805 }
806
807 /*****************************************************************************/
808 static DIGIT_T
mpShiftLeft(DIGIT_T * a,DIGIT_T * b,unsigned int x,unsigned int ndigits)809 mpShiftLeft(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits)
810 { /* Computes a = b << x */
811 unsigned int i;
812 unsigned int y;
813 DIGIT_T mask;
814 DIGIT_T carry;
815 DIGIT_T nextcarry;
816
817 /* Check input - NB unspecified result */
818 if (x >= BITS_PER_DIGIT)
819 {
820 return 0;
821 }
822 /* Construct mask */
823 mask = HIBITMASK;
824 for (i = 1; i < x; i++)
825 {
826 mask = (mask >> 1) | mask;
827 }
828 if (x == 0)
829 {
830 mask = 0x0;
831 }
832 y = BITS_PER_DIGIT - x;
833 carry = 0;
834 for (i = 0; i < ndigits; i++)
835 {
836 nextcarry = (b[i] & mask) >> y;
837 a[i] = b[i] << x | carry;
838 carry = nextcarry;
839 }
840 return carry;
841 }
842
843 /*****************************************************************************/
844 static DIGIT_T
mpShiftRight(DIGIT_T * a,DIGIT_T * b,unsigned int x,unsigned int ndigits)845 mpShiftRight(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits)
846 { /* Computes a = b >> x */
847 unsigned int i;
848 unsigned int y;
849 DIGIT_T mask;
850 DIGIT_T carry;
851 DIGIT_T nextcarry;
852
853 /* Check input - NB unspecified result */
854 if (x >= BITS_PER_DIGIT)
855 {
856 return 0;
857 }
858 /* Construct mask */
859 mask = 0x1;
860 for (i = 1; i < x; i++)
861 {
862 mask = (mask << 1) | mask;
863 }
864 if (x == 0)
865 {
866 mask = 0x0;
867 }
868 y = BITS_PER_DIGIT - x;
869 carry = 0;
870 i = ndigits;
871 while (i--)
872 {
873 nextcarry = (b[i] & mask) << y;
874 a[i] = b[i] >> x | carry;
875 carry = nextcarry;
876 }
877 return carry;
878 }
879
880 /*****************************************************************************/
881 static void
spMultSub(DIGIT_T * uu,DIGIT_T qhat,DIGIT_T v1,DIGIT_T v0)882 spMultSub(DIGIT_T* uu, DIGIT_T qhat, DIGIT_T v1, DIGIT_T v0)
883 {
884 /* Compute uu = uu - q(v1v0)
885 where uu = u3u2u1u0, u3 = 0
886 and u_n, v_n are all half-digits
887 even though v1, v2 are passed as full digits. */
888 DIGIT_T p0;
889 DIGIT_T p1;
890 DIGIT_T t;
891
892 p0 = qhat * v0;
893 p1 = qhat * v1;
894 t = p0 + TOHIGH(LOHALF(p1));
895 uu[0] -= t;
896 if (uu[0] > MAX_DIGIT - t)
897 {
898 uu[1]--; /* Borrow */
899 }
900 uu[1] -= HIHALF(p1);
901 }
902
903 /*****************************************************************************/
/*
 * Computes the double-digit product (p[1]:p[0]) = x * y of two single
 * digits by schoolbook multiplication on 16-bit half-digits.
 * Always returns 0.
 */
static int
spMultiply(DIGIT_T* p, DIGIT_T x, DIGIT_T y)
{ /* Computes p = x * y */
    /* Ref: Arbitrary Precision Computation
       http://numbers.computation.free.fr/Constants/constants.html

          high    p1                p0     low
       +--------+--------+--------+--------+
       |      x1*y1      |      x0*y0      |
       +--------+--------+--------+--------+
              +-+--------+--------+
              |1| (x0*y1 + x1*y0) |
              +-+--------+--------+
               ^carry from adding (x0*y1+x1*y0) together
                       +-+
                       |1|< carry from adding LOHALF t
                       +-+  to high half of p0 */
    DIGIT_T x0;
    DIGIT_T y0;
    DIGIT_T x1;
    DIGIT_T y1;
    DIGIT_T t;
    DIGIT_T u;
    DIGIT_T carry;

    /* Split each x,y into two halves
       x = x0 + B * x1
       y = y0 + B * y1
       where B = 2^16, half the digit size
       Product is
       xy = x0y0 + B(x0y1 + x1y0) + B^2(x1y1) */

    x0 = LOHALF(x);
    x1 = HIHALF(x);
    y0 = LOHALF(y);
    y1 = HIHALF(y);

    /* Calc low part - no carry (a product of two 16-bit halves fits) */
    p[0] = x0 * y0;

    /* Calc middle part: the two cross products can overflow when summed */
    t = x0 * y1;
    u = x1 * y0;
    t += u;
    if (t < u)
    {
        carry = 1;
    }
    else
    {
        carry = 0;
    }
    /* This carry will go to high half of p[1]
       + high half of t into low half of p[1] */
    carry = TOHIGH(carry) + HIHALF(t);

    /* Add low half of t to high half of p[0] */
    t = TOHIGH(t);
    p[0] += t;
    if (p[0] < t)
    {
        carry++;
    }

    p[1] = x1 * y1;
    p[1] += carry;

    return 0;
}
973
974 /*****************************************************************************/
/*
 * Divides the double digit u = (u[1]:u[0]) by the single digit v.
 * Outputs: *q = low digit of the quotient, *r = remainder.
 * Returns the high quotient digit (0 or 1).
 * v must be normalised (top bit set); otherwise *q = *r = 0 and MAX_DIGIT
 * is returned as an error marker.
 */
static DIGIT_T
spDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, DIGIT_T v)
{ /* Computes quotient q = u / v, remainder r = u mod v
     where u is a double digit
     and q, v, r are single precision digits.
     Returns high digit of quotient (max value is 1)
     Assumes normalised such that v1 >= b/2
     where b is size of HALF_DIGIT
     i.e. the most significant bit of v should be one

     In terms of half-digits in Knuth notation:
     (q2q1q0) = (u4u3u2u1u0) / (v1v0)
     (r1r0) = (u4u3u2u1u0) mod (v1v0)
     for m = 2, n = 2 where u4 = 0
     q2 is either 0 or 1.
     We set q = (q1q0) and return q2 as "overflow" */
    DIGIT_T qhat;
    DIGIT_T rhat;
    DIGIT_T t;
    DIGIT_T v0;
    DIGIT_T v1;
    DIGIT_T u0;
    DIGIT_T u1;
    DIGIT_T u2;
    DIGIT_T u3;
    DIGIT_T uu[2];
    DIGIT_T q2;

    /* Check for normalisation */
    if (!(v & HIBITMASK))
    {
        *q = *r = 0;
        return MAX_DIGIT;
    }

    /* Split up into half-digits */
    v0 = LOHALF(v);
    v1 = HIHALF(v);
    u0 = LOHALF(u[0]);
    u1 = HIHALF(u[0]);
    u2 = LOHALF(u[1]);
    u3 = HIHALF(u[1]);

    /* Do three rounds of Knuth Algorithm D Vol 2 p272 */

    /* ROUND 1. Set j = 2 and calculate q2 */
    /* Estimate qhat = (u4u3)/v1 = 0 or 1
       then set (u4u3u2) -= qhat(v1v0)
       where u4 = 0. */
    qhat = u3 / v1;
    if (qhat > 0)
    {
        rhat = u3 - qhat * v1;
        t = TOHIGH(rhat) | u2;
        if (qhat * v0 > t)
        {
            qhat--;
        }
    }
    uu[1] = 0; /* (u4) */
    uu[0] = u[1]; /* (u3u2) */
    if (qhat > 0)
    {
        /* (u4u3u2) -= qhat(v1v0) where u4 = 0 */
        spMultSub(uu, qhat, v1, v0);
        if (HIHALF(uu[1]) != 0)
        { /* Add back: the subtraction went negative, qhat was one too big */
            qhat--;
            uu[0] += v;
            uu[1] = 0;
        }
    }
    q2 = qhat;
    /* ROUND 2. Set j = 1 and calculate q1 */
    /* Estimate qhat = (u3u2) / v1
       then set (u3u2u1) -= qhat(v1v0) */
    t = uu[0];
    qhat = t / v1;
    rhat = t - qhat * v1;
    /* Test on v0: correct an estimate that is at most two too big */
    t = TOHIGH(rhat) | u1;
    if ((qhat == B_J) || (qhat * v0 > t))
    {
        qhat--;
        rhat += v1;
        t = TOHIGH(rhat) | u1;
        if ((rhat < B_J) && (qhat * v0 > t))
        {
            qhat--;
        }
    }
    /* Multiply and subtract
       (u3u2u1)' = (u3u2u1) - qhat(v1v0) */
    uu[1] = HIHALF(uu[0]); /* (0u3) */
    uu[0] = TOHIGH(LOHALF(uu[0])) | u1; /* (u2u1) */
    spMultSub(uu, qhat, v1, v0);
    if (HIHALF(uu[1]) != 0)
    { /* Add back */
        qhat--;
        uu[0] += v;
        uu[1] = 0;
    }
    /* q1 = qhat */
    *q = TOHIGH(qhat);
    /* ROUND 3. Set j = 0 and calculate q0 */
    /* Estimate qhat = (u2u1) / v1
       then set (u2u1u0) -= qhat(v1v0) */
    t = uu[0];
    qhat = t / v1;
    rhat = t - qhat * v1;
    /* Test on v0 */
    t = TOHIGH(rhat) | u0;
    if ((qhat == B_J) || (qhat * v0 > t))
    {
        qhat--;
        rhat += v1;
        t = TOHIGH(rhat) | u0;
        if ((rhat < B_J) && (qhat * v0 > t))
        {
            qhat--;
        }
    }
    /* Multiply and subtract
       (u2u1u0)" = (u2u1u0)' - qhat(v1v0) */
    uu[1] = HIHALF(uu[0]); /* (0u2) */
    uu[0] = TOHIGH(LOHALF(uu[0])) | u0; /* (u1u0) */
    spMultSub(uu, qhat, v1, v0);
    if (HIHALF(uu[1]) != 0)
    { /* Add back */
        qhat--;
        uu[0] += v;
        uu[1] = 0;
    }
    /* q0 = qhat */
    *q |= LOHALF(qhat);
    /* Remainder is in (u1u0) i.e. uu[0] */
    *r = uu[0];
    return q2;
}
1114
1115 /*****************************************************************************/
1116 static int
QhatTooBig(DIGIT_T qhat,DIGIT_T rhat,DIGIT_T vn2,DIGIT_T ujn2)1117 QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2)
1118 { /* Returns true if Qhat is too big
1119 i.e. if (Qhat * Vn-2) > (b.Rhat + Uj+n-2) */
1120 DIGIT_T t[2];
1121
1122 spMultiply(t, qhat, vn2);
1123 if (t[1] < rhat)
1124 {
1125 return 0;
1126 }
1127 else if (t[1] > rhat)
1128 {
1129 return 1;
1130 }
1131 else if (t[0] > ujn2)
1132 {
1133 return 1;
1134 }
1135 return 0;
1136 }
1137
1138 /*****************************************************************************/
1139 static DIGIT_T
mpShortDiv(DIGIT_T * q,DIGIT_T * u,DIGIT_T v,unsigned int ndigits)1140 mpShortDiv(DIGIT_T* q, DIGIT_T* u, DIGIT_T v, unsigned int ndigits)
1141 {
1142 /* Calculates quotient q = u div v
1143 Returns remainder r = u mod v
1144 where q, u are multiprecision integers of ndigits each
1145 and d, v are single precision digits.
1146
1147 Makes no assumptions about normalisation.
1148
1149 Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625 */
1150 unsigned int j;
1151 unsigned int shift;
1152 DIGIT_T t[2];
1153 DIGIT_T r;
1154 DIGIT_T bitmask;
1155 DIGIT_T overflow;
1156 DIGIT_T* uu;
1157
1158 if (ndigits == 0)
1159 {
1160 return 0;
1161 }
1162 if (v == 0)
1163 {
1164 return 0; /* Divide by zero error */
1165 }
1166 /* Normalise first */
1167 /* Requires high bit of V
1168 to be set, so find most signif. bit then shift left,
1169 i.e. d = 2^shift, u' = u * d, v' = v * d. */
1170 bitmask = HIBITMASK;
1171 for (shift = 0; shift < BITS_PER_DIGIT; shift++)
1172 {
1173 if (v & bitmask)
1174 {
1175 break;
1176 }
1177 bitmask >>= 1;
1178 }
1179 v <<= shift;
1180 overflow = mpShiftLeft(q, u, shift, ndigits);
1181 uu = q;
1182 /* Step S1 - modified for extra digit. */
1183 r = overflow; /* New digit Un */
1184 j = ndigits;
1185 while (j--)
1186 {
1187 /* Step S2. */
1188 t[1] = r;
1189 t[0] = uu[j];
1190 overflow = spDivide(&q[j], &r, t, v);
1191 }
1192 /* Unnormalise */
1193 r >>= shift;
1194 return r;
1195 }
1196
1197 /*****************************************************************************/
1198 static DIGIT_T
mpMultSub(DIGIT_T wn,DIGIT_T * w,DIGIT_T * v,DIGIT_T q,unsigned int n)1199 mpMultSub(DIGIT_T wn, DIGIT_T* w, DIGIT_T* v, DIGIT_T q, unsigned int n)
1200 { /* Compute w = w - qv
1201 where w = (WnW[n-1]...W[0])
1202 return modified Wn. */
1203 DIGIT_T k;
1204 DIGIT_T t[2];
1205 unsigned int i;
1206
1207 if (q == 0) /* No change */
1208 {
1209 return wn;
1210 }
1211 k = 0;
1212 for (i = 0; i < n; i++)
1213 {
1214 spMultiply(t, q, v[i]);
1215 w[i] -= k;
1216 if (w[i] > MAX_DIGIT - k)
1217 {
1218 k = 1;
1219 }
1220 else
1221 {
1222 k = 0;
1223 }
1224 w[i] -= t[0];
1225 if (w[i] > MAX_DIGIT - t[0])
1226 {
1227 k++;
1228 }
1229 k += t[1];
1230 }
1231 /* Cope with Wn not stored in array w[0..n-1] */
1232 wn -= k;
1233 return wn;
1234 }
1235
1236 /*****************************************************************************/
static int
mpDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, unsigned int udigits,
         DIGIT_T* v, unsigned int vdigits)
{ /* Computes quotient q = u / v and remainder r = u mod v
     where q, r, u are multiple precision digits
     all of udigits and the divisor v is vdigits.

     Ref: Knuth Vol 2 Ch 4.3.1 p 272 Algorithm D.

     Do without extra storage space, i.e. use r[] for
     normalised u[], unnormalise v[] at end, and cope with
     extra digit Uj+n added to u after normalisation.

     WARNING: this trashes q and r first, so cannot do
     u = u / v or v = u mod v.

     Returns -1 if v is zero (q and r have already been cleared to all
     zeros at that point), otherwise 0.

     NOTE: in the general case v is shifted left in place (Step D1) and
     shifted back right before returning (Step D8), so v must be writable
     even though its value is unchanged on return. */
    unsigned int shift;
    int n;
    int m;
    int j;
    int qhatOK;
    int cmp;
    DIGIT_T bitmask;
    DIGIT_T overflow;
    DIGIT_T qhat;
    DIGIT_T rhat;
    DIGIT_T t[2];
    DIGIT_T* uu;
    DIGIT_T* ww;

    /* Clear q and r */
    mpSetZero(q, udigits);
    mpSetZero(r, udigits);
    /* Work out exact sizes of u and v */
    n = (int)mpSizeof(v, vdigits);
    m = (int)mpSizeof(u, udigits);
    m -= n;
    /* Catch special cases */
    if (n == 0)
    {
        return -1; /* Error: divide by zero */
    }
    if (n == 1)
    { /* Use short division instead */
        r[0] = mpShortDiv(q, u, v[0], udigits);
        return 0;
    }
    if (m < 0)
    { /* v > u, so just set q = 0 and r = u */
        mpSetEqual(r, u, udigits);
        return 0;
    }
    if (m == 0)
    { /* u and v are the same length */
        cmp = mpCompare(u, v, (unsigned int)n);
        if (cmp < 0)
        { /* v > u, as above */
            mpSetEqual(r, u, udigits);
            return 0;
        }
        else if (cmp == 0)
        { /* v == u, so set q = 1 and r = 0 */
            mpSetDigit(q, 1, udigits);
            return 0;
        }
        /* cmp > 0: u > v with equal length, use the general algorithm */
    }
    /* In Knuth notation, we have:
       Given
       u = (Um+n-1 ... U1U0)
       v = (Vn-1 ... V1V0)
       Compute
       q = u/v = (QmQm-1 ... Q0)
       r = u mod v = (Rn-1 ... R1R0) */
    /* Step D1. Normalise */
    /* Requires high bit of Vn-1
       to be set, so find most signif. bit then shift left,
       i.e. d = 2^shift, u' = u * d, v' = v * d. */
    bitmask = HIBITMASK;
    for (shift = 0; shift < BITS_PER_DIGIT; shift++)
    {
        if (v[n - 1] & bitmask)
        {
            break;
        }
        bitmask >>= 1;
    }
    /* Normalise v in situ - NB only shift non-zero digits */
    overflow = mpShiftLeft(v, v, shift, n);
    /* Copy normalised dividend u*d into r */
    overflow = mpShiftLeft(r, u, shift, n + m);
    uu = r; /* Use ptr to keep notation constant */
    t[0] = overflow; /* New digit Um+n */
    /* Step D2. Initialise j. Set j = m */
    for (j = m; j >= 0; j--)
    {
        /* Step D3. Calculate Qhat = (b.Uj+n + Uj+n-1)/Vn-1 */
        qhatOK = 0;
        t[1] = t[0]; /* This is Uj+n */
        t[0] = uu[j+n-1];
        overflow = spDivide(&qhat, &rhat, t, v[n - 1]);
        /* Test Qhat */
        if (overflow)
        { /* Qhat = b */
            /* Quotient did not fit in one digit: clamp Qhat to b - 1
               and recompute Rhat for that value. */
            qhat = MAX_DIGIT;
            rhat = uu[j + n - 1];
            rhat += v[n - 1];
            if (rhat < v[n - 1]) /* Overflow */
            {
                /* Rhat wrapped, i.e. Rhat >= b, so the Vn-2 test
                   below cannot fire and is skipped. */
                qhatOK = 1;
            }
        }
        if (!qhatOK && QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2]))
        { /* Qhat.Vn-2 > b.Rhat + Uj+n-2 */
            qhat--;
            rhat += v[n - 1];
            /* Repeat the test only while Rhat still fits in one digit */
            if (!(rhat < v[n - 1]))
            {
                if (QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2]))
                {
                    qhat--;
                }
            }
        }
        /* Step D4. Multiply and subtract */
        ww = &uu[j];
        /* mpMultSub returns the updated Uj+n; non-zero here is treated
           as "the subtraction went negative". */
        overflow = mpMultSub(t[1], ww, v, qhat, (unsigned int)n);
        /* Step D5. Test remainder. Set Qj = Qhat */
        q[j] = qhat;
        if (overflow)
        { /* Step D6. Add back if D4 was negative */
            q[j]--;
            /* The carry out of this add cancels the D4 borrow, so it is
               not propagated any further. */
            overflow = mpAdd(ww, ww, v, (unsigned int)n);
        }
        t[0] = uu[j + n - 1]; /* Uj+n on next round */
    } /* Step D7. Loop on j */
    /* Clear high digits in uu */
    for (j = n; j < m+n; j++)
    {
        uu[j] = 0;
    }
    /* Step D8. Unnormalise. */
    mpShiftRight(r, r, shift, n);
    /* Restore the caller's v, which Step D1 shifted in place */
    mpShiftRight(v, v, shift, n);
    return 0;
}
1381
1382 /*****************************************************************************/
1383 static int
mpModulo(DIGIT_T * r,DIGIT_T * u,unsigned int udigits,DIGIT_T * v,unsigned int vdigits)1384 mpModulo(DIGIT_T* r, DIGIT_T* u, unsigned int udigits,
1385 DIGIT_T* v, unsigned int vdigits)
1386 {
1387 /* Calculates r = u mod v
1388 where r, v are multiprecision integers of length vdigits
1389 and u is a multiprecision integer of length udigits.
1390 r may overlap v.
1391
1392 Note that r here is only vdigits long,
1393 whereas in mpDivide it is udigits long.
1394
1395 Use remainder from mpDivide function. */
1396 /* Double-length temp variable for divide fn */
1397 DIGIT_T qq[MAX_DIG_LEN * 2];
1398 /* Use a double-length temp for r to allow overlap of r and v */
1399 DIGIT_T rr[MAX_DIG_LEN * 2];
1400
1401 /* rr[2n] = u[2n] mod v[n] */
1402 mpDivide(qq, rr, u, udigits, v, vdigits);
1403 mpSetEqual(r, rr, vdigits);
1404 mpSetZero(rr, udigits);
1405 mpSetZero(qq, udigits);
1406 return 0;
1407 }
1408
1409 /*****************************************************************************/
1410 static int
mpMultiply(DIGIT_T * w,DIGIT_T * u,DIGIT_T * v,unsigned int ndigits)1411 mpMultiply(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits)
1412 {
1413 /* Computes product w = u * v
1414 where u, v are multiprecision integers of ndigits each
1415 and w is a multiprecision integer of 2*ndigits
1416 Ref: Knuth Vol 2 Ch 4.3.1 p 268 Algorithm M. */
1417 DIGIT_T k;
1418 DIGIT_T t[2];
1419 unsigned int i;
1420 unsigned int j;
1421 unsigned int m;
1422 unsigned int n;
1423
1424 n = ndigits;
1425 m = n;
1426 /* Step M1. Initialise */
1427 for (i = 0; i < 2 * m; i++)
1428 {
1429 w[i] = 0;
1430 }
1431 for (j = 0; j < n; j++)
1432 {
1433 /* Step M2. Zero multiplier? */
1434 if (v[j] == 0)
1435 {
1436 w[j + m] = 0;
1437 }
1438 else
1439 {
1440 /* Step M3. Initialise i */
1441 k = 0;
1442 for (i = 0; i < m; i++)
1443 {
1444 /* Step M4. Multiply and add */
1445 /* t = u_i * v_j + w_(i+j) + k */
1446 spMultiply(t, u[i], v[j]);
1447 t[0] += k;
1448 if (t[0] < k)
1449 {
1450 t[1]++;
1451 }
1452 t[0] += w[i + j];
1453 if (t[0] < w[i+j])
1454 {
1455 t[1]++;
1456 }
1457 w[i + j] = t[0];
1458 k = t[1];
1459 }
1460 /* Step M5. Loop on i, set w_(j+m) = k */
1461 w[j + m] = k;
1462 }
1463 } /* Step M6. Loop on j */
1464 return 0;
1465 }
1466
1467 /*****************************************************************************/
1468 static int
mpModMult(DIGIT_T * a,DIGIT_T * x,DIGIT_T * y,DIGIT_T * m,unsigned int ndigits)1469 mpModMult(DIGIT_T* a, DIGIT_T* x, DIGIT_T* y,
1470 DIGIT_T* m, unsigned int ndigits)
1471 { /* Computes a = (x * y) mod m */
1472 /* Double-length temp variable */
1473 DIGIT_T p[MAX_DIG_LEN * 2];
1474
1475 /* Calc p[2n] = x * y */
1476 mpMultiply(p, x, y, ndigits);
1477 /* Then modulo */
1478 mpModulo(a, p, ndigits * 2, m, ndigits);
1479 mpSetZero(p, ndigits * 2);
1480 return 0;
1481 }
1482
1483 /*****************************************************************************/
1484 int
rdssl_mod_exp(char * out,int out_len,char * in,int in_len,char * mod,int mod_len,char * exp,int exp_len)1485 rdssl_mod_exp(char* out, int out_len, char* in, int in_len,
1486 char* mod, int mod_len, char* exp, int exp_len)
1487 {
1488 /* Computes y = x ^ e mod m */
1489 /* Binary left-to-right method */
1490 DIGIT_T mask;
1491 DIGIT_T* e;
1492 DIGIT_T* x;
1493 DIGIT_T* y;
1494 DIGIT_T* m;
1495 unsigned int n;
1496 int max_size;
1497 char* l_out;
1498 char* l_in;
1499 char* l_mod;
1500 char* l_exp;
1501
1502 if (in_len > out_len || in_len == 0 ||
1503 out_len == 0 || mod_len == 0 || exp_len == 0)
1504 {
1505 return 0;
1506 }
1507 max_size = out_len;
1508 if (in_len > max_size)
1509 {
1510 max_size = in_len;
1511 }
1512 if (mod_len > max_size)
1513 {
1514 max_size = mod_len;
1515 }
1516 if (exp_len > max_size)
1517 {
1518 max_size = exp_len;
1519 }
1520 l_out = (char*)g_malloc(max_size, 1);
1521 l_in = (char*)g_malloc(max_size, 1);
1522 l_mod = (char*)g_malloc(max_size, 1);
1523 l_exp = (char*)g_malloc(max_size, 1);
1524 memcpy(l_in, in, in_len);
1525 memcpy(l_mod, mod, mod_len);
1526 memcpy(l_exp, exp, exp_len);
1527 e = (DIGIT_T*)l_exp;
1528 x = (DIGIT_T*)l_in;
1529 y = (DIGIT_T*)l_out;
1530 m = (DIGIT_T*)l_mod;
1531 /* Find second-most significant bit in e */
1532 n = mpSizeof(e, max_size / 4);
1533 for (mask = HIBITMASK; mask > 0; mask >>= 1)
1534 {
1535 if (e[n - 1] & mask)
1536 {
1537 break;
1538 }
1539 }
1540 mpNEXTBITMASK(mask, n);
1541 /* Set y = x */
1542 mpSetEqual(y, x, max_size / 4);
1543 /* For bit j = k - 2 downto 0 step -1 */
1544 while (n)
1545 {
1546 mpModMult(y, y, y, m, max_size / 4); /* Square */
1547 if (e[n - 1] & mask)
1548 {
1549 mpModMult(y, y, x, m, max_size / 4); /* Multiply */
1550 }
1551 /* Move to next bit */
1552 mpNEXTBITMASK(mask, n);
1553 }
1554 memcpy(out, l_out, out_len);
1555 g_free(l_out);
1556 g_free(l_in);
1557 g_free(l_mod);
1558 g_free(l_exp);
1559 return out_len;
1560 }
1561
/* Modulus n of the fixed key that rdssl_sign_ok uses to recompute a
   signature (passed as the "mod" argument of rdssl_mod_exp).
   rdssl_mod_exp reinterprets these bytes as DIGIT_T digits with digit 0
   least significant, so the first byte is the lowest-order byte.
   rdssl_sign_ok passes only the first 64 bytes; the trailing 8 zero
   bytes are padding.
   NOTE(review): this appears to be the well-known Terminal Services
   proprietary-certificate signing key - confirm before relying on that. */
static uint8 g_ppk_n[72] =
{
    0x3D, 0x3A, 0x5E, 0xBD, 0x72, 0x43, 0x3E, 0xC9,
    0x4D, 0xBB, 0xC1, 0x1E, 0x4A, 0xBA, 0x5F, 0xCB,
    0x3E, 0x88, 0x20, 0x87, 0xEF, 0xF5, 0xC1, 0xE2,
    0xD7, 0xB7, 0x6B, 0x9A, 0xF2, 0x52, 0x45, 0x95,
    0xCE, 0x63, 0x65, 0x6B, 0x58, 0x3A, 0xFE, 0xEF,
    0x7C, 0xE7, 0xBF, 0xFE, 0x3D, 0xF6, 0x5C, 0x7D,
    0x6C, 0x5E, 0x06, 0x09, 0x1A, 0xF5, 0x61, 0xBB,
    0x20, 0x93, 0x09, 0x5F, 0x05, 0x6D, 0xEA, 0x87,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};
1574
/* Exponent d paired with g_ppk_n; rdssl_sign_ok passes it as the "exp"
   argument of rdssl_mod_exp. The byte order is the same as g_ppk_n
   (first byte lowest-order). Only the first 64 bytes are passed; the
   remaining 44 bytes are zero padding. */
static uint8 g_ppk_d[108] =
{
    0x87, 0xA7, 0x19, 0x32, 0xDA, 0x11, 0x87, 0x55,
    0x58, 0x00, 0x16, 0x16, 0x25, 0x65, 0x68, 0xF8,
    0x24, 0x3E, 0xE6, 0xFA, 0xE9, 0x67, 0x49, 0x94,
    0xCF, 0x92, 0xCC, 0x33, 0x99, 0xE8, 0x08, 0x60,
    0x17, 0x9A, 0x12, 0x9F, 0x24, 0xDD, 0xB1, 0x24,
    0x99, 0xC7, 0x3A, 0xB8, 0x0A, 0x7B, 0x0D, 0xDD,
    0x35, 0x07, 0x79, 0x17, 0x0B, 0x51, 0x9B, 0xB3,
    0xC7, 0x10, 0x01, 0x13, 0xE7, 0x3F, 0xF3, 0x5F,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00
};
1592
1593 int
rdssl_sign_ok(char * e_data,int e_len,char * n_data,int n_len,char * sign_data,int sign_len,char * sign_data2,int sign_len2,char * testkey)1594 rdssl_sign_ok(char* e_data, int e_len, char* n_data, int n_len,
1595 char* sign_data, int sign_len, char* sign_data2, int sign_len2, char* testkey)
1596 {
1597 char* key;
1598 char* md5_final;
1599 void* md5;
1600
1601 if ((e_len != 4) || (n_len != 64) || (sign_len != 64) || (sign_len2 != 64))
1602 {
1603 return 1;
1604 }
1605 md5 = rdssl_md5_info_create();
1606 if (!md5)
1607 {
1608 return 1;
1609 }
1610 key = (char*)xmalloc(176);
1611 md5_final = (char*)xmalloc(64);
1612 // copy the test key
1613 memcpy(key, testkey, 176);
1614 // replace e and n
1615 memcpy(key + 32, e_data, 4);
1616 memcpy(key + 36, n_data, 64);
1617 rdssl_md5_clear(md5);
1618 // the first 108 bytes
1619 rdssl_md5_transform(md5, key, 108);
1620 // set the whole thing with 0xff
1621 memset(md5_final, 0xff, 64);
1622 // digest 16 bytes
1623 rdssl_md5_complete(md5, md5_final);
1624 // set non 0xff array items
1625 md5_final[16] = 0;
1626 md5_final[62] = 1;
1627 md5_final[63] = 0;
1628 // encrypt
1629 rdssl_mod_exp(sign_data, 64, md5_final, 64, (char*)g_ppk_n, 64,
1630 (char*)g_ppk_d, 64);
1631 // cleanup
1632 rdssl_md5_info_delete(md5);
1633 xfree(key);
1634 xfree(md5_final);
1635 return memcmp(sign_data, sign_data2, sign_len2);
1636 }
1637
1638 /*****************************************************************************/
rdssl_cert_read(uint8 * data,uint32 len)1639 PCCERT_CONTEXT rdssl_cert_read(uint8 * data, uint32 len)
1640 {
1641 PCCERT_CONTEXT res;
1642 if (!data || !len)
1643 {
1644 error("rdssl_cert_read %p %ld\n", data, len);
1645 return NULL;
1646 }
1647 res = CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, data, len);
1648 if (!res)
1649 {
1650 error("CertCreateCertificateContext call failed with %lx\n", GetLastError());
1651 }
1652 return res;
1653 }
1654
1655 /*****************************************************************************/
void rdssl_cert_free(PCCERT_CONTEXT context)
{
    /* Releases a certificate context (e.g. one returned by
       rdssl_cert_read). Passing NULL is a harmless no-op. */
    if (context == NULL)
    {
        return;
    }
    CertFreeCertificateContext(context);
}
1661
1662 /*****************************************************************************/
rdssl_cert_to_rkey(PCCERT_CONTEXT cert,uint32 * key_len)1663 uint8 *rdssl_cert_to_rkey(PCCERT_CONTEXT cert, uint32 * key_len)
1664 {
1665 HCRYPTPROV hCryptProv;
1666 HCRYPTKEY hKey;
1667 BOOL ret;
1668 BYTE * rkey;
1669 DWORD dwSize, dwErr;
1670 ret = CryptAcquireContext(&hCryptProv,
1671 NULL,
1672 MS_ENHANCED_PROV,
1673 PROV_RSA_FULL,
1674 0);
1675 if (!ret)
1676 {
1677 dwErr = GetLastError();
1678 if (dwErr == NTE_BAD_KEYSET)
1679 {
1680 ret = CryptAcquireContext(&hCryptProv,
1681 L"MSTSC",
1682 MS_ENHANCED_PROV,
1683 PROV_RSA_FULL,
1684 CRYPT_NEWKEYSET);
1685 }
1686 }
1687 if (!ret)
1688 {
1689 dwErr = GetLastError();
1690 error("CryptAcquireContext call failed with %lx\n", dwErr);
1691 return NULL;
1692 }
1693 ret = CryptImportPublicKeyInfoEx(hCryptProv,
1694 X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
1695 &(cert->pCertInfo->SubjectPublicKeyInfo),
1696 0,
1697 0,
1698 NULL,
1699 &hKey);
1700 if (!ret)
1701 {
1702 dwErr = GetLastError();
1703 CryptReleaseContext(hCryptProv, 0);
1704 error("CryptImportPublicKeyInfoEx call failed with %lx\n", dwErr);
1705 return NULL;
1706 }
1707 ret = CryptExportKey(hKey,
1708 0,
1709 PUBLICKEYBLOB,
1710 0,
1711 NULL,
1712 &dwSize);
1713 if (!ret)
1714 {
1715 dwErr = GetLastError();
1716 CryptDestroyKey(hKey);
1717 CryptReleaseContext(hCryptProv, 0);
1718 error("CryptExportKey call failed with %lx\n", dwErr);
1719 return NULL;
1720 }
1721 rkey = g_malloc(dwSize, 0);
1722 ret = CryptExportKey(hKey,
1723 0,
1724 PUBLICKEYBLOB,
1725 0,
1726 rkey,
1727 &dwSize);
1728 if (!ret)
1729 {
1730 dwErr = GetLastError();
1731 g_free(rkey);
1732 CryptDestroyKey(hKey);
1733 CryptReleaseContext(hCryptProv, 0);
1734 error("CryptExportKey call failed with %lx\n", dwErr);
1735 return NULL;
1736 }
1737 CryptDestroyKey(hKey);
1738 CryptReleaseContext(hCryptProv, 0);
1739 return rkey;
1740 }
1741
1742 /*****************************************************************************/
rdssl_certs_ok(PCCERT_CONTEXT server_cert,PCCERT_CONTEXT cacert)1743 RD_BOOL rdssl_certs_ok(PCCERT_CONTEXT server_cert, PCCERT_CONTEXT cacert)
1744 {
1745 /* FIXME should we check for expired certificates??? */
1746 DWORD dwFlags = CERT_STORE_SIGNATURE_FLAG; /* CERT_STORE_TIME_VALIDITY_FLAG */
1747 BOOL ret = CertVerifySubjectCertificateContext(server_cert,
1748 cacert,
1749 &dwFlags);
1750 if (!ret)
1751 {
1752 error("CertVerifySubjectCertificateContext call failed with %lx\n", GetLastError());
1753 }
1754 if (dwFlags)
1755 {
1756 error("CertVerifySubjectCertificateContext check failed %lx\n", dwFlags);
1757 }
1758 return (dwFlags == 0);
1759 }
1760
1761 /*****************************************************************************/
int rdssl_rkey_get_exp_mod(uint8 * rkey, uint8 * exponent, uint32 max_exp_len, uint8 * modulus,
                           uint32 max_mod_len)
{
    /* Splits a PUBLICKEYBLOB (PUBLICKEYSTRUC, RSAPUBKEY, modulus) into the
       public exponent and the modulus, both copied in the blob's byte
       order.
       - exponent receives min(sizeof pubexp, max_exp_len) bytes and is
         zero-filled up to max_exp_len.
       - modulus receives bitlen / 8 bytes and is zero-filled up to
         max_mod_len.
       Returns 0 on success, or -1 on NULL/zero arguments or when the
       modulus does not fit in max_mod_len.
       NOTE(review): assumes rkey is a complete blob as produced by
       rdssl_cert_to_rkey; its total size is not passed in. */
    RSAPUBKEY *desc;
    uint32 exp_len;
    uint32 mod_len;

    if (!rkey || !exponent || !max_exp_len || !modulus || !max_mod_len)
    {
        error("rdssl_rkey_get_exp_mod %p %p %lu %p %lu\n", (void *)rkey, (void *)exponent,
              (unsigned long)max_exp_len, (void *)modulus, (unsigned long)max_mod_len);
        return -1;
    }
    /* Only form the header pointer once rkey is known to be non-NULL */
    desc = (RSAPUBKEY *)(rkey + sizeof(PUBLICKEYSTRUC));
    /* The blob holds exactly bitlen / 8 modulus bytes; copying
       max_mod_len bytes would read past the end of the allocation. */
    mod_len = desc->bitlen / 8;
    if (mod_len > max_mod_len)
    {
        error("rdssl_rkey_get_exp_mod modulus too large %lu > %lu\n",
              (unsigned long)mod_len, (unsigned long)max_mod_len);
        return -1;
    }
    /* pubexp is a single DWORD; bytes after it belong to the modulus */
    exp_len = sizeof(desc->pubexp);
    if (exp_len > max_exp_len)
    {
        exp_len = max_exp_len;
    }
    memset(exponent, 0, max_exp_len);
    memcpy(exponent, &desc->pubexp, exp_len);
    memset(modulus, 0, max_mod_len);
    memcpy(modulus, rkey + sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY), mod_len);
    return 0;
}
1775
1776 /*****************************************************************************/
void rdssl_rkey_free(uint8 * rkey)
{
    /* Releases a key blob obtained from rdssl_cert_to_rkey.
       A NULL argument is reported through error() and otherwise ignored. */
    if (rkey != NULL)
    {
        g_free(rkey);
        return;
    }
    error("rdssl_rkey_free rkey is null\n");
}
1786