xref: /reactos/base/applications/mstsc/ssl_calls.c (revision c2c66aff)
1 /*
2    This program is free software; you can redistribute it and/or modify
3    it under the terms of the GNU General Public License as published by
4    the Free Software Foundation; either version 2 of the License, or
5    (at your option) any later version.
6 
7    This program is distributed in the hope that it will be useful,
8    but WITHOUT ANY WARRANTY; without even the implied warranty of
9    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
10    GNU General Public License for more details.
11 
12    You should have received a copy of the GNU General Public License along
13    with this program; if not, write to the Free Software Foundation, Inc.,
14    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
15 
16    xrdp: A Remote Desktop Protocol server.
17    Copyright (C) Jay Sorg 2004-2005
18 
19    ssl calls
20 
21 */
22 
23 #include "precomp.h"
24 
25 /*****************************************************************************/
/*
 * Allocates `size` bytes through CryptMemAlloc.
 *
 * size: number of bytes requested; a negative value is rejected.
 * zero: when non-zero, the new buffer is cleared before it is returned.
 *
 * Returns the buffer, or NULL when the size is negative or the allocation
 * fails. Release the buffer with g_free().
 */
static void * g_malloc(int size, int zero)
{
    void * p;

    if (size < 0)
    {
        return NULL;
    }
    p = CryptMemAlloc(size);
    /* Only clear the buffer when the allocation actually succeeded. */
    if (p && zero)
    {
        memset(p, 0, size);
    }
    return p;
}
37 
38 /*****************************************************************************/
/* Hands a buffer obtained from g_malloc() back to CryptMemFree. */
static void g_free(void * in)
{
    CryptMemFree(in);
}
43 
44 struct rc4_state
45 {
46     HCRYPTPROV hCryptProv;
47     HCRYPTKEY hKey;
48 };
49 /*****************************************************************************/
50 void*
rdssl_rc4_info_create(void)51 rdssl_rc4_info_create(void)
52 {
53     struct rc4_state *info = g_malloc(sizeof(struct rc4_state), 1);
54     BOOL ret;
55     DWORD dwErr;
56     if (!info)
57     {
58         error("rdssl_rc4_info_create no memory\n");
59         return NULL;
60     }
61     ret = CryptAcquireContext(&info->hCryptProv,
62                               L"MSTSC",
63                               MS_ENHANCED_PROV,
64                               PROV_RSA_FULL,
65                               0);
66     if (!ret)
67     {
68         dwErr = GetLastError();
69         if (dwErr == NTE_BAD_KEYSET)
70         {
71             ret = CryptAcquireContext(&info->hCryptProv,
72                                       L"MSTSC",
73                                       MS_ENHANCED_PROV,
74                                       PROV_RSA_FULL,
75                                       CRYPT_NEWKEYSET);
76         }
77     }
78     if (!ret)
79     {
80         dwErr = GetLastError();
81         error("CryptAcquireContext failed with %lx\n", dwErr);
82         g_free(info);
83         return NULL;
84     }
85     return info;
86 }
87 
88 /*****************************************************************************/
89 void
rdssl_rc4_info_delete(void * rc4_info)90 rdssl_rc4_info_delete(void* rc4_info)
91 {
92     struct rc4_state *info = rc4_info;
93     BOOL ret = TRUE;
94     DWORD dwErr;
95     if (!info)
96     {
97         //error("rdssl_rc4_info_delete rc4_info is null\n");
98         return;
99     }
100     if (info->hKey)
101     {
102         ret = CryptDestroyKey(info->hKey);
103         if (!ret)
104         {
105             dwErr = GetLastError();
106             error("CryptDestroyKey failed with %lx\n", dwErr);
107         }
108     }
109     if (info->hCryptProv)
110     {
111         ret = CryptReleaseContext(info->hCryptProv, 0);
112         if (!ret)
113         {
114             dwErr = GetLastError();
115             error("CryptReleaseContext failed with %lx\n", dwErr);
116         }
117     }
118     g_free(rc4_info);
119 }
120 
121 /*****************************************************************************/
122 void
rdssl_rc4_set_key(void * rc4_info,char * key,int len)123 rdssl_rc4_set_key(void* rc4_info, char* key, int len)
124 {
125     struct rc4_state *info = rc4_info;
126     BOOL ret;
127     DWORD dwErr;
128     BYTE * blob;
129     PUBLICKEYSTRUC *desc;
130     DWORD * keySize;
131     BYTE * keyBuf;
132     if (!rc4_info || !key || !len || !info->hCryptProv)
133     {
134         error("rdssl_rc4_set_key %p %p %d\n", rc4_info, key, len);
135         return;
136     }
137     blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len, 0);
138     if (!blob)
139     {
140         error("rdssl_rc4_set_key no memory\n");
141         return;
142     }
143     desc = (PUBLICKEYSTRUC *)blob;
144     keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC));
145     keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD);
146     desc->aiKeyAlg = CALG_RC4;
147     desc->bType = PLAINTEXTKEYBLOB;
148     desc->bVersion = CUR_BLOB_VERSION;
149     desc->reserved = 0;
150     *keySize = len;
151     memcpy(keyBuf, key, len);
152     if (info->hKey)
153     {
154         CryptDestroyKey(info->hKey);
155         info->hKey = 0;
156     }
157     ret = CryptImportKey(info->hCryptProv,
158                          blob,
159                          sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + len,
160                          0,
161                          CRYPT_EXPORTABLE,
162                          &info->hKey);
163     g_free(blob);
164     if (!ret)
165     {
166         dwErr = GetLastError();
167         error("CryptImportKey failed with %lx\n", dwErr);
168     }
169 }
170 
171 /*****************************************************************************/
172 void
rdssl_rc4_crypt(void * rc4_info,char * in_data,char * out_data,int len)173 rdssl_rc4_crypt(void* rc4_info, char* in_data, char* out_data, int len)
174 {
175     struct rc4_state *info = rc4_info;
176     BOOL ret;
177     DWORD dwErr;
178     BYTE * intermediate_data;
179     DWORD dwLen = len;
180     if (!rc4_info || !in_data || !out_data || !len || !info->hKey)
181     {
182         error("rdssl_rc4_crypt %p %p %p %d\n", rc4_info, in_data, out_data, len);
183         return;
184     }
185     intermediate_data = g_malloc(len, 0);
186     if (!intermediate_data)
187     {
188         error("rdssl_rc4_set_key no memory\n");
189         return;
190     }
191     memcpy(intermediate_data, in_data, len);
192     ret = CryptEncrypt(info->hKey,
193                        0,
194                        FALSE,
195                        0,
196                        intermediate_data,
197                        &dwLen,
198                        dwLen);
199     if (!ret)
200     {
201         dwErr = GetLastError();
202         g_free(intermediate_data);
203         error("CryptEncrypt failed with %lx\n", dwErr);
204         return;
205     }
206     memcpy(out_data, intermediate_data, len);
207     g_free(intermediate_data);
208 }
209 
210 struct hash_context
211 {
212     HCRYPTPROV hCryptProv;
213     HCRYPTKEY hHash;
214 };
215 
216 /*****************************************************************************/
217 void*
rdssl_hash_info_create(ALG_ID id)218 rdssl_hash_info_create(ALG_ID id)
219 {
220     struct hash_context *info = g_malloc(sizeof(struct hash_context), 1);
221     BOOL ret;
222     DWORD dwErr;
223     if (!info)
224     {
225         error("rdssl_hash_info_create %d no memory\n", id);
226         return NULL;
227     }
228     ret = CryptAcquireContext(&info->hCryptProv,
229                               L"MSTSC",
230                               MS_ENHANCED_PROV,
231                               PROV_RSA_FULL,
232                               0);
233     if (!ret)
234     {
235         dwErr = GetLastError();
236         if (dwErr == NTE_BAD_KEYSET)
237         {
238             ret = CryptAcquireContext(&info->hCryptProv,
239                                       L"MSTSC",
240                                       MS_ENHANCED_PROV,
241                                       PROV_RSA_FULL,
242                                       CRYPT_NEWKEYSET);
243         }
244     }
245     if (!ret)
246     {
247         dwErr = GetLastError();
248         g_free(info);
249         error("CryptAcquireContext failed with %lx\n", dwErr);
250         return NULL;
251     }
252     ret = CryptCreateHash(info->hCryptProv,
253                           id,
254                           0,
255                           0,
256                           &info->hHash);
257     if (!ret)
258     {
259         dwErr = GetLastError();
260         CryptReleaseContext(info->hCryptProv, 0);
261         g_free(info);
262         error("CryptCreateHash failed with %lx\n", dwErr);
263         return NULL;
264     }
265     return info;
266 }
267 
268 /*****************************************************************************/
269 void
rdssl_hash_info_delete(void * hash_info)270 rdssl_hash_info_delete(void* hash_info)
271 {
272     struct hash_context *info = hash_info;
273     if (!info)
274     {
275         //error("ssl_hash_info_delete hash_info is null\n");
276         return;
277     }
278     if (info->hHash)
279     {
280         CryptDestroyHash(info->hHash);
281     }
282     if (info->hCryptProv)
283     {
284         CryptReleaseContext(info->hCryptProv, 0);
285     }
286     g_free(hash_info);
287 }
288 
289 /*****************************************************************************/
290 void
rdssl_hash_clear(void * hash_info,ALG_ID id)291 rdssl_hash_clear(void* hash_info, ALG_ID id)
292 {
293     struct hash_context *info = hash_info;
294     BOOL ret;
295     DWORD dwErr;
296     if (!info || !info->hHash || !info->hCryptProv)
297     {
298         error("rdssl_hash_clear %p\n", info);
299         return;
300     }
301     ret = CryptDestroyHash(info->hHash);
302     if (!ret)
303     {
304         dwErr = GetLastError();
305         error("CryptDestroyHash failed with %lx\n", dwErr);
306         return;
307     }
308     ret = CryptCreateHash(info->hCryptProv,
309                           id,
310                           0,
311                           0,
312                           &info->hHash);
313     if (!ret)
314     {
315         dwErr = GetLastError();
316         error("CryptCreateHash failed with %lx\n", dwErr);
317     }
318 }
319 
320 void
rdssl_hash_transform(void * hash_info,char * data,int len)321 rdssl_hash_transform(void* hash_info, char* data, int len)
322 {
323     struct hash_context *info = hash_info;
324     BOOL ret;
325     DWORD dwErr;
326     if (!info || !info->hHash || !info->hCryptProv || !data || !len)
327     {
328         error("rdssl_hash_transform %p %p %d\n", hash_info, data, len);
329         return;
330     }
331     ret = CryptHashData(info->hHash,
332                         (BYTE *)data,
333                         len,
334                         0);
335     if (!ret)
336     {
337         dwErr = GetLastError();
338         error("CryptHashData failed with %lx\n", dwErr);
339     }
340 }
341 
342 /*****************************************************************************/
343 void
rdssl_hash_complete(void * hash_info,char * data)344 rdssl_hash_complete(void* hash_info, char* data)
345 {
346     struct hash_context *info = hash_info;
347     BOOL ret;
348     DWORD dwErr, dwDataLen;
349     if (!info || !info->hHash || !info->hCryptProv || !data)
350     {
351         error("rdssl_hash_complete %p %p\n", hash_info, data);
352         return;
353     }
354     ret = CryptGetHashParam(info->hHash,
355                             HP_HASHVAL,
356                             NULL,
357                             &dwDataLen,
358                             0);
359     if (!ret)
360     {
361         dwErr = GetLastError();
362         error("CryptGetHashParam failed with %lx\n", dwErr);
363         return;
364     }
365     ret = CryptGetHashParam(info->hHash,
366                             HP_HASHVAL,
367                             (BYTE *)data,
368                             &dwDataLen,
369                             0);
370     if (!ret)
371     {
372         dwErr = GetLastError();
373         error("CryptGetHashParam failed with %lx\n", dwErr);
374     }
375 }
376 
377 /*****************************************************************************/
378 void*
rdssl_sha1_info_create(void)379 rdssl_sha1_info_create(void)
380 {
381     return rdssl_hash_info_create(CALG_SHA1);
382 }
383 
384 /*****************************************************************************/
/* Releases a context created by rdssl_sha1_info_create(). */
void
rdssl_sha1_info_delete(void* sha1_info)
{
    rdssl_hash_info_delete(sha1_info);
}
390 
391 /*****************************************************************************/
392 void
rdssl_sha1_clear(void * sha1_info)393 rdssl_sha1_clear(void* sha1_info)
394 {
395     rdssl_hash_clear(sha1_info, CALG_SHA1);
396 }
397 
398 /*****************************************************************************/
/* Adds `len` bytes of `data` to the SHA-1 computation. */
void
rdssl_sha1_transform(void* sha1_info, char* data, int len)
{
    rdssl_hash_transform(sha1_info, data, len);
}
404 
405 /*****************************************************************************/
/* Copies the SHA-1 hash value of the context into `data`. */
void
rdssl_sha1_complete(void* sha1_info, char* data)
{
    rdssl_hash_complete(sha1_info, data);
}
411 
412 /*****************************************************************************/
413 void*
rdssl_md5_info_create(void)414 rdssl_md5_info_create(void)
415 {
416     return rdssl_hash_info_create(CALG_MD5);
417 }
418 
419 /*****************************************************************************/
/* Releases a context created by rdssl_md5_info_create(). */
void
rdssl_md5_info_delete(void* md5_info)
{
    rdssl_hash_info_delete(md5_info);
}
425 
426 /*****************************************************************************/
427 void
rdssl_md5_clear(void * md5_info)428 rdssl_md5_clear(void* md5_info)
429 {
430     rdssl_hash_clear(md5_info, CALG_MD5);
431 }
432 
433 /*****************************************************************************/
/* Adds `len` bytes of `data` to the MD5 computation. */
void
rdssl_md5_transform(void* md5_info, char* data, int len)
{
    rdssl_hash_transform(md5_info, data, len);
}
439 
440 /*****************************************************************************/
/* Copies the MD5 hash value of the context into `data`. */
void
rdssl_md5_complete(void* md5_info, char* data)
{
    rdssl_hash_complete(md5_info, data);
}
446 
447 /*****************************************************************************/
448 void
rdssl_hmac_md5(char * key,int keylen,char * data,int len,char * output)449 rdssl_hmac_md5(char* key, int keylen, char* data, int len, char* output)
450 {
451     HCRYPTPROV hCryptProv;
452     HCRYPTKEY hKey;
453     HCRYPTKEY hHash;
454     BOOL ret;
455     DWORD dwErr, dwDataLen;
456     HMAC_INFO info;
457     BYTE * blob;
458     PUBLICKEYSTRUC *desc;
459     DWORD * keySize;
460     BYTE * keyBuf;
461     BYTE sum[16];
462 
463     if (!key || !keylen || !data || !len ||!output)
464     {
465         error("rdssl_hmac_md5 %p %d %p %d %p\n", key, keylen, data, len, output);
466         return;
467     }
468     blob = g_malloc(sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen, 0);
469     desc = (PUBLICKEYSTRUC *)blob;
470     keySize = (DWORD *)(blob + sizeof(PUBLICKEYSTRUC));
471     keyBuf = blob + sizeof(PUBLICKEYSTRUC) + sizeof(DWORD);
472     if (!blob)
473     {
474         error("rdssl_hmac_md5 %d no memory\n");
475         return;
476     }
477     ret = CryptAcquireContext(&hCryptProv,
478                               L"MSTSC",
479                               MS_ENHANCED_PROV,
480                               PROV_RSA_FULL,
481                               0);
482     if (!ret)
483     {
484         dwErr = GetLastError();
485         if (dwErr == NTE_BAD_KEYSET)
486         {
487             ret = CryptAcquireContext(&hCryptProv,
488                                       L"MSTSC",
489                                       MS_ENHANCED_PROV,
490                                       PROV_RSA_FULL,
491                                       CRYPT_NEWKEYSET);
492         }
493     }
494     if (!ret)
495     {
496         dwErr = GetLastError();
497         g_free(blob);
498         error("CryptAcquireContext failed with %lx\n", dwErr);
499         return;
500     }
501     desc->aiKeyAlg = CALG_RC4;
502     desc->bType = PLAINTEXTKEYBLOB;
503     desc->bVersion = CUR_BLOB_VERSION;
504     desc->reserved = 0;
505     if (keylen > 64)
506     {
507         HCRYPTKEY hHash;
508         ret = CryptCreateHash(hCryptProv,
509                               CALG_MD5,
510                               0,
511                               0,
512                               &hHash);
513         if (!ret)
514         {
515             dwErr = GetLastError();
516             g_free(blob);
517             error("CryptCreateHash failed with %lx\n", dwErr);
518             return;
519         }
520         ret = CryptHashData(hHash,
521                             (BYTE *)key,
522                             keylen,
523                             0);
524         if (!ret)
525         {
526             dwErr = GetLastError();
527             g_free(blob);
528             error("CryptHashData failed with %lx\n", dwErr);
529             return;
530         }
531         ret = CryptGetHashParam(hHash,
532                                 HP_HASHVAL,
533                                 NULL,
534                                 &dwDataLen,
535                                 0);
536         if (!ret)
537         {
538             dwErr = GetLastError();
539             g_free(blob);
540             error("CryptGetHashParam failed with %lx\n", dwErr);
541             return;
542         }
543         ret = CryptGetHashParam(hHash,
544                                 HP_HASHVAL,
545                                 sum,
546                                 &dwDataLen,
547                                 0);
548         if (!ret)
549         {
550             dwErr = GetLastError();
551             g_free(blob);
552             error("CryptGetHashParam failed with %lx\n", dwErr);
553             return;
554         }
555         keylen = dwDataLen;
556         key = (char *)sum;
557     }
558     *keySize = keylen;
559     memcpy(keyBuf, key, keylen);
560     ret = CryptImportKey(hCryptProv,
561                          blob,
562                          sizeof(PUBLICKEYSTRUC) + sizeof(DWORD) + keylen,
563                          0,
564                          CRYPT_EXPORTABLE,
565                          &hKey);
566     g_free(blob);
567     if (!ret)
568     {
569         dwErr = GetLastError();
570         error("CryptImportKey failed with %lx\n", dwErr);
571         return;
572     }
573     ret = CryptCreateHash(hCryptProv,
574                           CALG_HMAC,
575                           hKey,
576                           0,
577                           &hHash);
578     if (!ret)
579     {
580         dwErr = GetLastError();
581         error("CryptCreateHash failed with %lx\n", dwErr);
582         return;
583     }
584     info.HashAlgid = CALG_MD5;
585     info.cbInnerString = 0;
586     info.cbOuterString = 0;
587     ret = CryptSetHashParam(hHash,
588                             HP_HMAC_INFO,
589                             (BYTE *)&info,
590                             0);
591     if (!ret)
592     {
593         dwErr = GetLastError();
594         error("CryptSetHashParam failed with %lx\n", dwErr);
595         return;
596     }
597     ret = CryptHashData(hHash,
598                         (BYTE *)data,
599                         len,
600                         0);
601     if (!ret)
602     {
603         dwErr = GetLastError();
604         error("CryptHashData failed with %lx\n", dwErr);
605         return;
606     }
607     ret = CryptGetHashParam(hHash,
608                             HP_HASHVAL,
609                             NULL,
610                             &dwDataLen,
611                             0);
612     if (!ret)
613     {
614         dwErr = GetLastError();
615         error("CryptGetHashParam failed with %lx\n", dwErr);
616         return;
617     }
618     ret = CryptGetHashParam(hHash,
619                             HP_HASHVAL,
620                             (BYTE *)output,
621                             &dwDataLen,
622                             0);
623     if (!ret)
624     {
625         dwErr = GetLastError();
626         error("CryptGetHashParam failed with %lx\n", dwErr);
627         return;
628     }
629     CryptDestroyHash(hHash);
630     ret = CryptReleaseContext(hCryptProv, 0);
631 }
632 
633 /*****************************************************************************/
634 /*****************************************************************************/
635 /* big number stuff */
636 /******************* SHORT COPYRIGHT NOTICE*************************
637 This source code is part of the BigDigits multiple-precision
638 arithmetic library Version 1.0 originally written by David Ireland,
639 copyright (c) 2001 D.I. Management Services Pty Limited, all rights
640 reserved. It is provided "as is" with no warranties. You may use
641 this software under the terms of the full copyright notice
642 "bigdigitsCopyright.txt" that should have been included with
643 this library. To obtain a copy send an email to
644 <code@di-mgt.com.au> or visit <www.di-mgt.com.au/crypto.html>.
645 This notice must be retained in any copy.
646 ****************** END OF COPYRIGHT NOTICE*************************/
647 /************************* COPYRIGHT NOTICE*************************
648 This source code is part of the BigDigits multiple-precision
649 arithmetic library Version 1.0 originally written by David Ireland,
650 copyright (c) 2001 D.I. Management Services Pty Limited, all rights
651 reserved. You are permitted to use compiled versions of this code as
652 part of your own executable files and to distribute unlimited copies
653 of such executable files for any purposes including commercial ones
654 provided you keep the copyright notices intact in the source code
655 and that you ensure that the following characters remain in any
656 object or executable files you distribute:
657 
658 "Contains multiple-precision arithmetic code originally written
659 by David Ireland, copyright (c) 2001 by D.I. Management Services
660 Pty Limited <www.di-mgt.com.au>, and is used with permission."
661 
662 David Ireland and DI Management Services Pty Limited make no
663 representations concerning either the merchantability of this
664 software or the suitability of this software for any particular
665 purpose. It is provided "as is" without express or implied warranty
666 of any kind.
667 
668 Please forward any comments and bug reports to <code@di-mgt.com.au>.
669 The latest version of the source code can be downloaded from
670 www.di-mgt.com.au/crypto.html.
671 ****************** END OF COPYRIGHT NOTICE*************************/
672 
/* One digit of a multiple-precision number. The routines below treat a
   number as an array of digits with the least significant digit at index 0. */
typedef unsigned int DIGIT_T;

/* Digit geometry. */
#define BITS_PER_DIGIT 32
#define MAX_DIGIT 0xffffffff
#define HIBITMASK 0x80000000
#define MAX_DIG_LEN 51

/* Half-digit helpers: a digit is split into two 16-bit halves. */
#define MAX_HALF_DIGIT 0xffff
#define B_J (MAX_HALF_DIGIT + 1)
#define LOHALF(x) ((DIGIT_T)((x) & MAX_HALF_DIGIT))
#define HIHALF(x) ((DIGIT_T)(((x) >> 16) & MAX_HALF_DIGIT))
#define TOHIGH(x) ((DIGIT_T)((x) << 16))

/* Moves a single-bit mask one position down; when the mask is already at
   bit 0 it wraps to the top bit and n is decremented. */
#define mpNEXTBITMASK(mask, n) \
{ \
  if (mask != 1) \
  { \
    mask >>= 1; \
  } \
  else \
  { \
    mask = HIBITMASK; \
    n--; \
  } \
}
696 
697 /*****************************************************************************/
698 static DIGIT_T
mpAdd(DIGIT_T * w,DIGIT_T * u,DIGIT_T * v,unsigned int ndigits)699 mpAdd(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits)
700 {
701   /* Calculates w = u + v
702      where w, u, v are multiprecision integers of ndigits each
703      Returns carry if overflow. Carry = 0 or 1.
704 
705      Ref: Knuth Vol 2 Ch 4.3.1 p 266 Algorithm A. */
706   DIGIT_T k;
707   unsigned int j;
708 
709   /* Step A1. Initialise */
710   k = 0;
711   for (j = 0; j < ndigits; j++)
712   {
713     /* Step A2. Add digits w_j = (u_j + v_j + k)
714        Set k = 1 if carry (overflow) occurs */
715     w[j] = u[j] + k;
716     if (w[j] < k)
717     {
718       k = 1;
719     }
720     else
721     {
722       k = 0;
723     }
724     w[j] += v[j];
725     if (w[j] < v[j])
726     {
727       k++;
728     }
729   } /* Step A3. Loop on j */
730   return k; /* w_n = k */
731 }
732 
733 /*****************************************************************************/
734 static void
mpSetDigit(DIGIT_T * a,DIGIT_T d,unsigned int ndigits)735 mpSetDigit(DIGIT_T* a, DIGIT_T d, unsigned int ndigits)
736 { /* Sets a = d where d is a single digit */
737   unsigned int i;
738 
739   for (i = 1; i < ndigits; i++)
740   {
741     a[i] = 0;
742   }
743   a[0] = d;
744 }
745 
746 /*****************************************************************************/
747 static int
mpCompare(DIGIT_T * a,DIGIT_T * b,unsigned int ndigits)748 mpCompare(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits)
749 {
750   /* Returns sign of (a - b) */
751   if (ndigits == 0)
752   {
753     return 0;
754   }
755   while (ndigits--)
756   {
757     if (a[ndigits] > b[ndigits])
758     {
759       return 1; /* GT */
760     }
761     if (a[ndigits] < b[ndigits])
762     {
763       return -1; /* LT */
764     }
765   }
766   return 0; /* EQ */
767 }
768 
769 /*****************************************************************************/
770 static void
mpSetZero(DIGIT_T * a,unsigned int ndigits)771 mpSetZero(DIGIT_T* a, unsigned int ndigits)
772 { /* Sets a = 0 */
773   unsigned int i;
774 
775   for (i = 0; i < ndigits; i++)
776   {
777     a[i] = 0;
778   }
779 }
780 
781 /*****************************************************************************/
782 static void
mpSetEqual(DIGIT_T * a,DIGIT_T * b,unsigned int ndigits)783 mpSetEqual(DIGIT_T* a, DIGIT_T* b, unsigned int ndigits)
784 {  /* Sets a = b */
785   unsigned int i;
786 
787   for (i = 0; i < ndigits; i++)
788   {
789     a[i] = b[i];
790   }
791 }
792 
793 /*****************************************************************************/
794 static unsigned int
mpSizeof(DIGIT_T * a,unsigned int ndigits)795 mpSizeof(DIGIT_T* a, unsigned int ndigits)
796 {  /* Returns size of significant digits in a */
797   while (ndigits--)
798   {
799     if (a[ndigits] != 0)
800     {
801       return (++ndigits);
802     }
803   }
804   return 0;
805 }
806 
807 /*****************************************************************************/
808 static DIGIT_T
mpShiftLeft(DIGIT_T * a,DIGIT_T * b,unsigned int x,unsigned int ndigits)809 mpShiftLeft(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits)
810 { /* Computes a = b << x */
811   unsigned int i;
812   unsigned int y;
813   DIGIT_T mask;
814   DIGIT_T carry;
815   DIGIT_T nextcarry;
816 
817   /* Check input - NB unspecified result */
818   if (x >= BITS_PER_DIGIT)
819   {
820     return 0;
821   }
822   /* Construct mask */
823   mask = HIBITMASK;
824   for (i = 1; i < x; i++)
825   {
826     mask = (mask >> 1) | mask;
827   }
828   if (x == 0)
829   {
830     mask = 0x0;
831   }
832   y = BITS_PER_DIGIT - x;
833   carry = 0;
834   for (i = 0; i < ndigits; i++)
835   {
836     nextcarry = (b[i] & mask) >> y;
837     a[i] = b[i] << x | carry;
838     carry = nextcarry;
839   }
840   return carry;
841 }
842 
843 /*****************************************************************************/
844 static DIGIT_T
mpShiftRight(DIGIT_T * a,DIGIT_T * b,unsigned int x,unsigned int ndigits)845 mpShiftRight(DIGIT_T* a, DIGIT_T* b, unsigned int x, unsigned int ndigits)
846 { /* Computes a = b >> x */
847   unsigned int i;
848   unsigned int y;
849   DIGIT_T mask;
850   DIGIT_T carry;
851   DIGIT_T nextcarry;
852 
853   /* Check input  - NB unspecified result */
854   if (x >= BITS_PER_DIGIT)
855   {
856     return 0;
857   }
858   /* Construct mask */
859   mask = 0x1;
860   for (i = 1; i < x; i++)
861   {
862     mask = (mask << 1) | mask;
863   }
864   if (x == 0)
865   {
866     mask = 0x0;
867   }
868   y = BITS_PER_DIGIT - x;
869   carry = 0;
870   i = ndigits;
871   while (i--)
872   {
873     nextcarry = (b[i] & mask) << y;
874     a[i] = b[i] >> x | carry;
875     carry = nextcarry;
876   }
877   return carry;
878 }
879 
880 /*****************************************************************************/
881 static void
spMultSub(DIGIT_T * uu,DIGIT_T qhat,DIGIT_T v1,DIGIT_T v0)882 spMultSub(DIGIT_T* uu, DIGIT_T qhat, DIGIT_T v1, DIGIT_T v0)
883 {
884   /* Compute uu = uu - q(v1v0)
885      where uu = u3u2u1u0, u3 = 0
886      and u_n, v_n are all half-digits
887      even though v1, v2 are passed as full digits. */
888   DIGIT_T p0;
889   DIGIT_T p1;
890   DIGIT_T t;
891 
892   p0 = qhat * v0;
893   p1 = qhat * v1;
894   t = p0 + TOHIGH(LOHALF(p1));
895   uu[0] -= t;
896   if (uu[0] > MAX_DIGIT - t)
897   {
898     uu[1]--; /* Borrow */
899   }
900   uu[1] -= HIHALF(p1);
901 }
902 
903 /*****************************************************************************/
904 static int
spMultiply(DIGIT_T * p,DIGIT_T x,DIGIT_T y)905 spMultiply(DIGIT_T* p, DIGIT_T x, DIGIT_T y)
906 { /* Computes p = x * y */
907   /* Ref: Arbitrary Precision Computation
908      http://numbers.computation.free.fr/Constants/constants.html
909 
910          high    p1                p0     low
911         +--------+--------+--------+--------+
912         |      x1*y1      |      x0*y0      |
913         +--------+--------+--------+--------+
914                +-+--------+--------+
915                |1| (x0*y1 + x1*y1) |
916                +-+--------+--------+
917                 ^carry from adding (x0*y1+x1*y1) together
918                         +-+
919                         |1|< carry from adding LOHALF t
920                         +-+  to high half of p0 */
921   DIGIT_T x0;
922   DIGIT_T y0;
923   DIGIT_T x1;
924   DIGIT_T y1;
925   DIGIT_T t;
926   DIGIT_T u;
927   DIGIT_T carry;
928 
929   /* Split each x,y into two halves
930      x = x0 + B * x1
931      y = y0 + B * y1
932      where B = 2^16, half the digit size
933      Product is
934         xy = x0y0 + B(x0y1 + x1y0) + B^2(x1y1) */
935 
936   x0 = LOHALF(x);
937   x1 = HIHALF(x);
938   y0 = LOHALF(y);
939   y1 = HIHALF(y);
940 
941   /* Calc low part - no carry */
942   p[0] = x0 * y0;
943 
944   /* Calc middle part */
945   t = x0 * y1;
946   u = x1 * y0;
947   t += u;
948   if (t < u)
949   {
950     carry = 1;
951   }
952   else
953   {
954     carry = 0;
955   }
956   /* This carry will go to high half of p[1]
957      + high half of t into low half of p[1] */
958   carry = TOHIGH(carry) + HIHALF(t);
959 
960   /* Add low half of t to high half of p[0] */
961   t = TOHIGH(t);
962   p[0] += t;
963   if (p[0] < t)
964   {
965     carry++;
966   }
967 
968   p[1] = x1 * y1;
969   p[1] += carry;
970 
971   return 0;
972 }
973 
974 /*****************************************************************************/
975 static DIGIT_T
spDivide(DIGIT_T * q,DIGIT_T * r,DIGIT_T * u,DIGIT_T v)976 spDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, DIGIT_T v)
977 { /* Computes quotient q = u / v, remainder r = u mod v
978      where u is a double digit
979      and q, v, r are single precision digits.
980      Returns high digit of quotient (max value is 1)
981      Assumes normalised such that v1 >= b/2
982      where b is size of HALF_DIGIT
983      i.e. the most significant bit of v should be one
984 
985      In terms of half-digits in Knuth notation:
986      (q2q1q0) = (u4u3u2u1u0) / (v1v0)
987      (r1r0) = (u4u3u2u1u0) mod (v1v0)
988      for m = 2, n = 2 where u4 = 0
989      q2 is either 0 or 1.
990      We set q = (q1q0) and return q2 as "overflow' */
991   DIGIT_T qhat;
992   DIGIT_T rhat;
993   DIGIT_T t;
994   DIGIT_T v0;
995   DIGIT_T v1;
996   DIGIT_T u0;
997   DIGIT_T u1;
998   DIGIT_T u2;
999   DIGIT_T u3;
1000   DIGIT_T uu[2];
1001   DIGIT_T q2;
1002 
1003   /* Check for normalisation */
1004   if (!(v & HIBITMASK))
1005   {
1006     *q = *r = 0;
1007     return MAX_DIGIT;
1008   }
1009 
1010   /* Split up into half-digits */
1011   v0 = LOHALF(v);
1012   v1 = HIHALF(v);
1013   u0 = LOHALF(u[0]);
1014   u1 = HIHALF(u[0]);
1015   u2 = LOHALF(u[1]);
1016   u3 = HIHALF(u[1]);
1017 
1018   /* Do three rounds of Knuth Algorithm D Vol 2 p272 */
1019 
1020   /* ROUND 1. Set j = 2 and calculate q2 */
1021   /* Estimate qhat = (u4u3)/v1  = 0 or 1
1022      then set (u4u3u2) -= qhat(v1v0)
1023      where u4 = 0. */
1024   qhat = u3 / v1;
1025   if (qhat > 0)
1026   {
1027     rhat = u3 - qhat * v1;
1028     t = TOHIGH(rhat) | u2;
1029     if (qhat * v0 > t)
1030     {
1031       qhat--;
1032     }
1033   }
1034   uu[1] = 0; /* (u4) */
1035   uu[0] = u[1]; /* (u3u2) */
1036   if (qhat > 0)
1037   {
1038     /* (u4u3u2) -= qhat(v1v0) where u4 = 0 */
1039     spMultSub(uu, qhat, v1, v0);
1040     if (HIHALF(uu[1]) != 0)
1041     { /* Add back */
1042       qhat--;
1043       uu[0] += v;
1044       uu[1] = 0;
1045     }
1046   }
1047   q2 = qhat;
1048   /* ROUND 2. Set j = 1 and calculate q1 */
1049   /* Estimate qhat = (u3u2) / v1
1050      then set (u3u2u1) -= qhat(v1v0) */
1051   t = uu[0];
1052   qhat = t / v1;
1053   rhat = t - qhat * v1;
1054   /* Test on v0 */
1055   t = TOHIGH(rhat) | u1;
1056   if ((qhat == B_J) || (qhat * v0 > t))
1057   {
1058     qhat--;
1059     rhat += v1;
1060     t = TOHIGH(rhat) | u1;
1061     if ((rhat < B_J) && (qhat * v0 > t))
1062     {
1063       qhat--;
1064     }
1065   }
1066   /* Multiply and subtract
1067      (u3u2u1)' = (u3u2u1) - qhat(v1v0) */
1068   uu[1] = HIHALF(uu[0]); /* (0u3) */
1069   uu[0] = TOHIGH(LOHALF(uu[0])) | u1; /* (u2u1) */
1070   spMultSub(uu, qhat, v1, v0);
1071   if (HIHALF(uu[1]) != 0)
1072   { /* Add back */
1073     qhat--;
1074     uu[0] += v;
1075     uu[1] = 0;
1076   }
1077   /* q1 = qhat */
1078   *q = TOHIGH(qhat);
1079   /* ROUND 3. Set j = 0 and calculate q0 */
1080   /* Estimate qhat = (u2u1) / v1
1081      then set (u2u1u0) -= qhat(v1v0) */
1082   t = uu[0];
1083   qhat = t / v1;
1084   rhat = t - qhat * v1;
1085   /* Test on v0 */
1086   t = TOHIGH(rhat) | u0;
1087   if ((qhat == B_J) || (qhat * v0 > t))
1088   {
1089     qhat--;
1090     rhat += v1;
1091     t = TOHIGH(rhat) | u0;
1092     if ((rhat < B_J) && (qhat * v0 > t))
1093     {
1094       qhat--;
1095     }
1096   }
1097   /* Multiply and subtract
1098      (u2u1u0)" = (u2u1u0)' - qhat(v1v0) */
1099   uu[1] = HIHALF(uu[0]); /* (0u2) */
1100   uu[0] = TOHIGH(LOHALF(uu[0])) | u0; /* (u1u0) */
1101   spMultSub(uu, qhat, v1, v0);
1102   if (HIHALF(uu[1]) != 0)
1103   { /* Add back */
1104     qhat--;
1105     uu[0] += v;
1106     uu[1] = 0;
1107   }
1108   /* q0 = qhat */
1109   *q |= LOHALF(qhat);
1110   /* Remainder is in (u1u0) i.e. uu[0] */
1111   *r = uu[0];
1112   return q2;
1113 }
1114 
1115 /*****************************************************************************/
1116 static int
QhatTooBig(DIGIT_T qhat,DIGIT_T rhat,DIGIT_T vn2,DIGIT_T ujn2)1117 QhatTooBig(DIGIT_T qhat, DIGIT_T rhat, DIGIT_T vn2, DIGIT_T ujn2)
1118 { /* Returns true if Qhat is too big
1119      i.e. if (Qhat * Vn-2) > (b.Rhat + Uj+n-2) */
1120   DIGIT_T t[2];
1121 
1122   spMultiply(t, qhat, vn2);
1123   if (t[1] < rhat)
1124   {
1125     return 0;
1126   }
1127   else if (t[1] > rhat)
1128   {
1129     return 1;
1130   }
1131   else if (t[0] > ujn2)
1132   {
1133     return 1;
1134   }
1135   return 0;
1136 }
1137 
1138 /*****************************************************************************/
1139 static DIGIT_T
mpShortDiv(DIGIT_T * q,DIGIT_T * u,DIGIT_T v,unsigned int ndigits)1140 mpShortDiv(DIGIT_T* q, DIGIT_T* u, DIGIT_T v, unsigned int ndigits)
1141 {
1142   /* Calculates quotient q = u div v
1143      Returns remainder r = u mod v
1144      where q, u are multiprecision integers of ndigits each
1145      and d, v are single precision digits.
1146 
1147      Makes no assumptions about normalisation.
1148 
1149      Ref: Knuth Vol 2 Ch 4.3.1 Exercise 16 p625 */
1150   unsigned int j;
1151   unsigned int shift;
1152   DIGIT_T t[2];
1153   DIGIT_T r;
1154   DIGIT_T bitmask;
1155   DIGIT_T overflow;
1156   DIGIT_T* uu;
1157 
1158   if (ndigits == 0)
1159   {
1160     return 0;
1161   }
1162   if (v == 0)
1163   {
1164     return 0; /* Divide by zero error */
1165   }
1166   /* Normalise first */
1167   /* Requires high bit of V
1168      to be set, so find most signif. bit then shift left,
1169      i.e. d = 2^shift, u' = u * d, v' = v * d. */
1170   bitmask = HIBITMASK;
1171   for (shift = 0; shift < BITS_PER_DIGIT; shift++)
1172   {
1173     if (v & bitmask)
1174     {
1175       break;
1176     }
1177     bitmask >>= 1;
1178   }
1179   v <<= shift;
1180   overflow = mpShiftLeft(q, u, shift, ndigits);
1181   uu = q;
1182   /* Step S1 - modified for extra digit. */
1183   r = overflow; /* New digit Un */
1184   j = ndigits;
1185   while (j--)
1186   {
1187     /* Step S2. */
1188     t[1] = r;
1189     t[0] = uu[j];
1190     overflow = spDivide(&q[j], &r, t, v);
1191   }
1192   /* Unnormalise */
1193   r >>= shift;
1194   return r;
1195 }
1196 
1197 /*****************************************************************************/
1198 static DIGIT_T
mpMultSub(DIGIT_T wn,DIGIT_T * w,DIGIT_T * v,DIGIT_T q,unsigned int n)1199 mpMultSub(DIGIT_T wn, DIGIT_T* w, DIGIT_T* v, DIGIT_T q, unsigned int n)
1200 { /* Compute w = w - qv
1201      where w = (WnW[n-1]...W[0])
1202      return modified Wn. */
1203   DIGIT_T k;
1204   DIGIT_T t[2];
1205   unsigned int i;
1206 
1207   if (q == 0) /* No change */
1208   {
1209     return wn;
1210   }
1211   k = 0;
1212   for (i = 0; i < n; i++)
1213   {
1214     spMultiply(t, q, v[i]);
1215     w[i] -= k;
1216     if (w[i] > MAX_DIGIT - k)
1217     {
1218       k = 1;
1219     }
1220     else
1221     {
1222       k = 0;
1223     }
1224     w[i] -= t[0];
1225     if (w[i] > MAX_DIGIT - t[0])
1226     {
1227       k++;
1228     }
1229     k += t[1];
1230   }
1231   /* Cope with Wn not stored in array w[0..n-1] */
1232   wn -= k;
1233   return wn;
1234 }
1235 
1236 /*****************************************************************************/
1237 static int
mpDivide(DIGIT_T * q,DIGIT_T * r,DIGIT_T * u,unsigned int udigits,DIGIT_T * v,unsigned int vdigits)1238 mpDivide(DIGIT_T* q, DIGIT_T* r, DIGIT_T* u, unsigned int udigits,
1239          DIGIT_T* v, unsigned int vdigits)
1240 { /* Computes quotient q = u / v and remainder r = u mod v
1241      where q, r, u are multiple precision digits
1242      all of udigits and the divisor v is vdigits.
1243 
1244      Ref: Knuth Vol 2 Ch 4.3.1 p 272 Algorithm D.
1245 
1246      Do without extra storage space, i.e. use r[] for
1247      normalised u[], unnormalise v[] at end, and cope with
1248      extra digit Uj+n added to u after normalisation.
1249 
1250      WARNING: this trashes q and r first, so cannot do
1251      u = u / v or v = u mod v. */
1252   unsigned int shift;
1253   int n;
1254   int m;
1255   int j;
1256   int qhatOK;
1257   int cmp;
1258   DIGIT_T bitmask;
1259   DIGIT_T overflow;
1260   DIGIT_T qhat;
1261   DIGIT_T rhat;
1262   DIGIT_T t[2];
1263   DIGIT_T* uu;
1264   DIGIT_T* ww;
1265 
1266   /* Clear q and r */
1267   mpSetZero(q, udigits);
1268   mpSetZero(r, udigits);
1269   /* Work out exact sizes of u and v */
1270   n = (int)mpSizeof(v, vdigits);
1271   m = (int)mpSizeof(u, udigits);
1272   m -= n;
1273   /* Catch special cases */
1274   if (n == 0)
1275   {
1276     return -1;	/* Error: divide by zero */
1277   }
1278   if (n == 1)
1279   { /* Use short division instead */
1280     r[0] = mpShortDiv(q, u, v[0], udigits);
1281     return 0;
1282   }
1283   if (m < 0)
1284   { /* v > u, so just set q = 0 and r = u */
1285     mpSetEqual(r, u, udigits);
1286     return 0;
1287   }
1288   if (m == 0)
1289   { /* u and v are the same length */
1290     cmp = mpCompare(u, v, (unsigned int)n);
1291     if (cmp < 0)
1292     { /* v > u, as above */
1293       mpSetEqual(r, u, udigits);
1294       return 0;
1295     }
1296     else if (cmp == 0)
1297     { /* v == u, so set q = 1 and r = 0 */
1298       mpSetDigit(q, 1, udigits);
1299       return 0;
1300     }
1301   }
1302   /* In Knuth notation, we have:
1303      Given
1304      u = (Um+n-1 ... U1U0)
1305      v = (Vn-1 ... V1V0)
1306      Compute
1307      q = u/v = (QmQm-1 ... Q0)
1308      r = u mod v = (Rn-1 ... R1R0) */
1309   /* Step D1. Normalise */
1310   /* Requires high bit of Vn-1
1311      to be set, so find most signif. bit then shift left,
1312      i.e. d = 2^shift, u' = u * d, v' = v * d. */
1313   bitmask = HIBITMASK;
1314   for (shift = 0; shift < BITS_PER_DIGIT; shift++)
1315   {
1316     if (v[n - 1] & bitmask)
1317     {
1318       break;
1319     }
1320     bitmask >>= 1;
1321   }
1322   /* Normalise v in situ - NB only shift non-zero digits */
1323   overflow = mpShiftLeft(v, v, shift, n);
1324   /* Copy normalised dividend u*d into r */
1325   overflow = mpShiftLeft(r, u, shift, n + m);
1326   uu = r; /* Use ptr to keep notation constant */
1327   t[0] = overflow; /* New digit Um+n */
1328   /* Step D2. Initialise j. Set j = m */
1329   for (j = m; j >= 0; j--)
1330   {
1331     /* Step D3. Calculate Qhat = (b.Uj+n + Uj+n-1)/Vn-1 */
1332     qhatOK = 0;
1333     t[1] = t[0]; /* This is Uj+n */
1334     t[0] = uu[j+n-1];
1335     overflow = spDivide(&qhat, &rhat, t, v[n - 1]);
1336     /* Test Qhat */
1337     if (overflow)
1338     { /* Qhat = b */
1339       qhat = MAX_DIGIT;
1340       rhat = uu[j + n - 1];
1341       rhat += v[n - 1];
1342       if (rhat < v[n - 1]) /* Overflow */
1343       {
1344         qhatOK = 1;
1345       }
1346     }
1347     if (!qhatOK && QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2]))
1348     { /* Qhat.Vn-2 > b.Rhat + Uj+n-2 */
1349       qhat--;
1350       rhat += v[n - 1];
1351       if (!(rhat < v[n - 1]))
1352       {
1353         if (QhatTooBig(qhat, rhat, v[n - 2], uu[j + n - 2]))
1354         {
1355           qhat--;
1356         }
1357       }
1358     }
1359     /* Step D4. Multiply and subtract */
1360     ww = &uu[j];
1361     overflow = mpMultSub(t[1], ww, v, qhat, (unsigned int)n);
1362     /* Step D5. Test remainder. Set Qj = Qhat */
1363     q[j] = qhat;
1364     if (overflow)
1365     { /* Step D6. Add back if D4 was negative */
1366       q[j]--;
1367       overflow = mpAdd(ww, ww, v, (unsigned int)n);
1368     }
1369     t[0] = uu[j + n - 1]; /* Uj+n on next round */
1370   } /* Step D7. Loop on j */
1371   /* Clear high digits in uu */
1372   for (j = n; j < m+n; j++)
1373   {
1374     uu[j] = 0;
1375   }
1376   /* Step D8. Unnormalise. */
1377   mpShiftRight(r, r, shift, n);
1378   mpShiftRight(v, v, shift, n);
1379   return 0;
1380 }
1381 
1382 /*****************************************************************************/
1383 static int
mpModulo(DIGIT_T * r,DIGIT_T * u,unsigned int udigits,DIGIT_T * v,unsigned int vdigits)1384 mpModulo(DIGIT_T* r, DIGIT_T* u, unsigned int udigits,
1385          DIGIT_T* v, unsigned int vdigits)
1386 {
1387   /* Calculates r = u mod v
1388      where r, v are multiprecision integers of length vdigits
1389      and u is a multiprecision integer of length udigits.
1390      r may overlap v.
1391 
1392      Note that r here is only vdigits long,
1393      whereas in mpDivide it is udigits long.
1394 
1395      Use remainder from mpDivide function. */
1396   /* Double-length temp variable for divide fn */
1397   DIGIT_T qq[MAX_DIG_LEN * 2];
1398   /* Use a double-length temp for r to allow overlap of r and v */
1399   DIGIT_T rr[MAX_DIG_LEN * 2];
1400 
1401   /* rr[2n] = u[2n] mod v[n] */
1402   mpDivide(qq, rr, u, udigits, v, vdigits);
1403   mpSetEqual(r, rr, vdigits);
1404   mpSetZero(rr, udigits);
1405   mpSetZero(qq, udigits);
1406   return 0;
1407 }
1408 
1409 /*****************************************************************************/
1410 static int
mpMultiply(DIGIT_T * w,DIGIT_T * u,DIGIT_T * v,unsigned int ndigits)1411 mpMultiply(DIGIT_T* w, DIGIT_T* u, DIGIT_T* v, unsigned int ndigits)
1412 {
1413   /* Computes product w = u * v
1414      where u, v are multiprecision integers of ndigits each
1415      and w is a multiprecision integer of 2*ndigits
1416      Ref: Knuth Vol 2 Ch 4.3.1 p 268 Algorithm M. */
1417   DIGIT_T k;
1418   DIGIT_T t[2];
1419   unsigned int i;
1420   unsigned int j;
1421   unsigned int m;
1422   unsigned int n;
1423 
1424   n = ndigits;
1425   m = n;
1426   /* Step M1. Initialise */
1427   for (i = 0; i < 2 * m; i++)
1428   {
1429     w[i] = 0;
1430   }
1431   for (j = 0; j < n; j++)
1432   {
1433     /* Step M2. Zero multiplier? */
1434     if (v[j] == 0)
1435     {
1436       w[j + m] = 0;
1437     }
1438     else
1439     {
1440       /* Step M3. Initialise i */
1441       k = 0;
1442       for (i = 0; i < m; i++)
1443       {
1444         /* Step M4. Multiply and add */
1445         /* t = u_i * v_j + w_(i+j) + k */
1446         spMultiply(t, u[i], v[j]);
1447         t[0] += k;
1448         if (t[0] < k)
1449         {
1450           t[1]++;
1451         }
1452         t[0] += w[i + j];
1453         if (t[0] < w[i+j])
1454         {
1455           t[1]++;
1456         }
1457         w[i + j] = t[0];
1458         k = t[1];
1459       }
1460       /* Step M5. Loop on i, set w_(j+m) = k */
1461       w[j + m] = k;
1462     }
1463   } /* Step M6. Loop on j */
1464   return 0;
1465 }
1466 
1467 /*****************************************************************************/
1468 static int
mpModMult(DIGIT_T * a,DIGIT_T * x,DIGIT_T * y,DIGIT_T * m,unsigned int ndigits)1469 mpModMult(DIGIT_T* a, DIGIT_T* x, DIGIT_T* y,
1470           DIGIT_T* m, unsigned int ndigits)
1471 { /* Computes a = (x * y) mod m */
1472   /* Double-length temp variable */
1473   DIGIT_T p[MAX_DIG_LEN * 2];
1474 
1475   /* Calc p[2n] = x * y */
1476   mpMultiply(p, x, y, ndigits);
1477   /* Then modulo */
1478   mpModulo(a, p, ndigits * 2, m, ndigits);
1479   mpSetZero(p, ndigits * 2);
1480   return 0;
1481 }
1482 
1483 /*****************************************************************************/
1484 int
rdssl_mod_exp(char * out,int out_len,char * in,int in_len,char * mod,int mod_len,char * exp,int exp_len)1485 rdssl_mod_exp(char* out, int out_len, char* in, int in_len,
1486               char* mod, int mod_len, char* exp, int exp_len)
1487 {
1488   /* Computes y = x ^ e mod m */
1489   /* Binary left-to-right method */
1490   DIGIT_T mask;
1491   DIGIT_T* e;
1492   DIGIT_T* x;
1493   DIGIT_T* y;
1494   DIGIT_T* m;
1495   unsigned int n;
1496   int max_size;
1497   char* l_out;
1498   char* l_in;
1499   char* l_mod;
1500   char* l_exp;
1501 
1502   if (in_len > out_len || in_len == 0 ||
1503       out_len == 0 || mod_len == 0 || exp_len == 0)
1504   {
1505     return 0;
1506   }
1507   max_size = out_len;
1508   if (in_len > max_size)
1509   {
1510     max_size = in_len;
1511   }
1512   if (mod_len > max_size)
1513   {
1514     max_size = mod_len;
1515   }
1516   if (exp_len > max_size)
1517   {
1518     max_size = exp_len;
1519   }
1520   l_out = (char*)g_malloc(max_size, 1);
1521   l_in = (char*)g_malloc(max_size, 1);
1522   l_mod = (char*)g_malloc(max_size, 1);
1523   l_exp = (char*)g_malloc(max_size, 1);
1524   memcpy(l_in, in, in_len);
1525   memcpy(l_mod, mod, mod_len);
1526   memcpy(l_exp, exp, exp_len);
1527   e = (DIGIT_T*)l_exp;
1528   x = (DIGIT_T*)l_in;
1529   y = (DIGIT_T*)l_out;
1530   m = (DIGIT_T*)l_mod;
1531   /* Find second-most significant bit in e */
1532   n = mpSizeof(e, max_size / 4);
1533   for (mask = HIBITMASK; mask > 0; mask >>= 1)
1534   {
1535     if (e[n - 1] & mask)
1536     {
1537       break;
1538     }
1539   }
1540   mpNEXTBITMASK(mask, n);
1541   /* Set y = x */
1542   mpSetEqual(y, x, max_size / 4);
1543   /* For bit j = k - 2 downto 0 step -1 */
1544   while (n)
1545   {
1546     mpModMult(y, y, y, m, max_size / 4); /* Square */
1547     if (e[n - 1] & mask)
1548     {
1549       mpModMult(y, y, x, m, max_size / 4); /* Multiply */
1550     }
1551     /* Move to next bit */
1552     mpNEXTBITMASK(mask, n);
1553   }
1554   memcpy(out, l_out, out_len);
1555   g_free(l_out);
1556   g_free(l_in);
1557   g_free(l_mod);
1558   g_free(l_exp);
1559   return out_len;
1560 }
1561 
/* Value passed (first 64 bytes) as the modulus to rdssl_mod_exp() by
   rdssl_sign_ok().  The array holds 64 data bytes followed by 8 zero bytes.
   NOTE(review): presumably the modulus of the fixed key used to check
   proprietary server signatures, least-significant byte first -- confirm. */
static uint8 g_ppk_n[72] =
{
    0x3D, 0x3A, 0x5E, 0xBD, 0x72, 0x43, 0x3E, 0xC9,
    0x4D, 0xBB, 0xC1, 0x1E, 0x4A, 0xBA, 0x5F, 0xCB,
    0x3E, 0x88, 0x20, 0x87, 0xEF, 0xF5, 0xC1, 0xE2,
    0xD7, 0xB7, 0x6B, 0x9A, 0xF2, 0x52, 0x45, 0x95,
    0xCE, 0x63, 0x65, 0x6B, 0x58, 0x3A, 0xFE, 0xEF,
    0x7C, 0xE7, 0xBF, 0xFE, 0x3D, 0xF6, 0x5C, 0x7D,
    0x6C, 0x5E, 0x06, 0x09, 0x1A, 0xF5, 0x61, 0xBB,
    0x20, 0x93, 0x09, 0x5F, 0x05, 0x6D, 0xEA, 0x87,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
};
1574 
/* Value passed (first 64 bytes) as the exponent to rdssl_mod_exp() by
   rdssl_sign_ok().  The array holds 64 data bytes followed by 44 zero bytes.
   NOTE(review): presumably the private exponent matching g_ppk_n,
   least-significant byte first -- confirm. */
static uint8 g_ppk_d[108] =
{
    0x87, 0xA7, 0x19, 0x32, 0xDA, 0x11, 0x87, 0x55,
    0x58, 0x00, 0x16, 0x16, 0x25, 0x65, 0x68, 0xF8,
    0x24, 0x3E, 0xE6, 0xFA, 0xE9, 0x67, 0x49, 0x94,
    0xCF, 0x92, 0xCC, 0x33, 0x99, 0xE8, 0x08, 0x60,
    0x17, 0x9A, 0x12, 0x9F, 0x24, 0xDD, 0xB1, 0x24,
    0x99, 0xC7, 0x3A, 0xB8, 0x0A, 0x7B, 0x0D, 0xDD,
    0x35, 0x07, 0x79, 0x17, 0x0B, 0x51, 0x9B, 0xB3,
    0xC7, 0x10, 0x01, 0x13, 0xE7, 0x3F, 0xF3, 0x5F,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00
};
1592 
1593 int
rdssl_sign_ok(char * e_data,int e_len,char * n_data,int n_len,char * sign_data,int sign_len,char * sign_data2,int sign_len2,char * testkey)1594 rdssl_sign_ok(char* e_data, int e_len, char* n_data, int n_len,
1595     char* sign_data, int sign_len, char* sign_data2, int sign_len2, char* testkey)
1596 {
1597     char* key;
1598     char* md5_final;
1599     void* md5;
1600 
1601     if ((e_len != 4) || (n_len != 64) || (sign_len != 64) || (sign_len2 != 64))
1602     {
1603         return 1;
1604     }
1605     md5 = rdssl_md5_info_create();
1606     if (!md5)
1607     {
1608         return 1;
1609     }
1610     key = (char*)xmalloc(176);
1611     md5_final = (char*)xmalloc(64);
1612     // copy the test key
1613     memcpy(key, testkey, 176);
1614     // replace e and n
1615     memcpy(key + 32, e_data, 4);
1616     memcpy(key + 36, n_data, 64);
1617     rdssl_md5_clear(md5);
1618     // the first 108 bytes
1619     rdssl_md5_transform(md5, key, 108);
1620     // set the whole thing with 0xff
1621     memset(md5_final, 0xff, 64);
1622     // digest 16 bytes
1623     rdssl_md5_complete(md5, md5_final);
1624     // set non 0xff array items
1625     md5_final[16] = 0;
1626     md5_final[62] = 1;
1627     md5_final[63] = 0;
1628     // encrypt
1629     rdssl_mod_exp(sign_data, 64, md5_final, 64, (char*)g_ppk_n, 64,
1630         (char*)g_ppk_d, 64);
1631     // cleanup
1632     rdssl_md5_info_delete(md5);
1633     xfree(key);
1634     xfree(md5_final);
1635     return memcmp(sign_data, sign_data2, sign_len2);
1636 }
1637 
1638 /*****************************************************************************/
rdssl_cert_read(uint8 * data,uint32 len)1639 PCCERT_CONTEXT rdssl_cert_read(uint8 * data, uint32 len)
1640 {
1641     PCCERT_CONTEXT res;
1642     if (!data || !len)
1643     {
1644         error("rdssl_cert_read %p %ld\n", data, len);
1645         return NULL;
1646     }
1647     res = CertCreateCertificateContext(X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, data, len);
1648     if (!res)
1649     {
1650         error("CertCreateCertificateContext call failed with %lx\n", GetLastError());
1651     }
1652     return res;
1653 }
1654 
1655 /*****************************************************************************/
/* Frees a certificate context; a NULL context is ignored. */
void rdssl_cert_free(PCCERT_CONTEXT context)
{
    if (!context)
    {
        return;
    }
    CertFreeCertificateContext(context);
}
1661 
1662 /*****************************************************************************/
rdssl_cert_to_rkey(PCCERT_CONTEXT cert,uint32 * key_len)1663 uint8 *rdssl_cert_to_rkey(PCCERT_CONTEXT cert, uint32 * key_len)
1664 {
1665     HCRYPTPROV hCryptProv;
1666     HCRYPTKEY hKey;
1667     BOOL ret;
1668     BYTE * rkey;
1669     DWORD dwSize, dwErr;
1670     ret = CryptAcquireContext(&hCryptProv,
1671                               NULL,
1672                               MS_ENHANCED_PROV,
1673                               PROV_RSA_FULL,
1674                               0);
1675     if (!ret)
1676     {
1677         dwErr = GetLastError();
1678         if (dwErr == NTE_BAD_KEYSET)
1679         {
1680             ret = CryptAcquireContext(&hCryptProv,
1681                                       L"MSTSC",
1682                                       MS_ENHANCED_PROV,
1683                                       PROV_RSA_FULL,
1684                                       CRYPT_NEWKEYSET);
1685         }
1686     }
1687     if (!ret)
1688     {
1689         dwErr = GetLastError();
1690         error("CryptAcquireContext call failed with %lx\n", dwErr);
1691         return NULL;
1692     }
1693     ret = CryptImportPublicKeyInfoEx(hCryptProv,
1694                                      X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
1695                                      &(cert->pCertInfo->SubjectPublicKeyInfo),
1696                                      0,
1697                                      0,
1698                                      NULL,
1699                                      &hKey);
1700     if (!ret)
1701     {
1702         dwErr = GetLastError();
1703         CryptReleaseContext(hCryptProv, 0);
1704         error("CryptImportPublicKeyInfoEx call failed with %lx\n", dwErr);
1705         return NULL;
1706     }
1707     ret = CryptExportKey(hKey,
1708                          0,
1709                          PUBLICKEYBLOB,
1710                          0,
1711                          NULL,
1712                          &dwSize);
1713     if (!ret)
1714     {
1715         dwErr = GetLastError();
1716         CryptDestroyKey(hKey);
1717         CryptReleaseContext(hCryptProv, 0);
1718         error("CryptExportKey call failed with %lx\n", dwErr);
1719         return NULL;
1720     }
1721     rkey = g_malloc(dwSize, 0);
1722     ret = CryptExportKey(hKey,
1723                          0,
1724                          PUBLICKEYBLOB,
1725                          0,
1726                          rkey,
1727                          &dwSize);
1728     if (!ret)
1729     {
1730         dwErr = GetLastError();
1731         g_free(rkey);
1732         CryptDestroyKey(hKey);
1733         CryptReleaseContext(hCryptProv, 0);
1734         error("CryptExportKey call failed with %lx\n", dwErr);
1735         return NULL;
1736     }
1737     CryptDestroyKey(hKey);
1738     CryptReleaseContext(hCryptProv, 0);
1739     return rkey;
1740 }
1741 
1742 /*****************************************************************************/
rdssl_certs_ok(PCCERT_CONTEXT server_cert,PCCERT_CONTEXT cacert)1743 RD_BOOL rdssl_certs_ok(PCCERT_CONTEXT server_cert, PCCERT_CONTEXT cacert)
1744 {
1745     /* FIXME should we check for expired certificates??? */
1746     DWORD dwFlags = CERT_STORE_SIGNATURE_FLAG; /* CERT_STORE_TIME_VALIDITY_FLAG */
1747     BOOL ret = CertVerifySubjectCertificateContext(server_cert,
1748                                                    cacert,
1749                                                    &dwFlags);
1750     if (!ret)
1751     {
1752         error("CertVerifySubjectCertificateContext call failed with %lx\n", GetLastError());
1753     }
1754     if (dwFlags)
1755     {
1756         error("CertVerifySubjectCertificateContext check failed %lx\n", dwFlags);
1757     }
1758     return (dwFlags == 0);
1759 }
1760 
1761 /*****************************************************************************/
/* Copies the public exponent and modulus out of a PUBLICKEYBLOB
   (PUBLICKEYSTRUC, then RSAPUBKEY, then the modulus bytes).

   rkey:                   key blob, e.g. from rdssl_cert_to_rkey
   exponent/max_exp_len:   receives the exponent; bytes beyond the 4-byte
                           pubexp field are zero-filled
   modulus/max_mod_len:    receives the modulus; at most bitlen / 8 bytes are
                           copied and the rest of the buffer is zero-filled.
                           A modulus longer than max_mod_len is truncated.

   Returns 0 on success, -1 (after logging) when an argument is NULL or a
   length is zero. */
int rdssl_rkey_get_exp_mod(uint8 * rkey, uint8 * exponent, uint32 max_exp_len, uint8 * modulus,
    uint32 max_mod_len)
{
    RSAPUBKEY *desc;
    uint32 exp_len;
    uint32 mod_len;

    if (!rkey || !exponent || !max_exp_len || !modulus || !max_mod_len)
    {
        error("rdssl_rkey_get_exp_mod %p %p %lu %p %lu\n", (void *)rkey, (void *)exponent,
              (unsigned long)max_exp_len, (void *)modulus, (unsigned long)max_mod_len);
        return -1;
    }
    /* Only form pointers into the blob once rkey is known to be non-NULL */
    desc = (RSAPUBKEY *)(rkey + sizeof(PUBLICKEYSTRUC));

    /* Never read past the pubexp field: the modulus starts right after it */
    exp_len = (uint32)sizeof(desc->pubexp);
    if (exp_len > max_exp_len)
    {
        exp_len = max_exp_len;
    }
    memset(exponent, 0, max_exp_len);
    memcpy(exponent, &desc->pubexp, exp_len);

    /* Never read more modulus bytes than the blob declares */
    mod_len = (uint32)(desc->bitlen / 8);
    if (mod_len > max_mod_len)
    {
        mod_len = max_mod_len;
    }
    memset(modulus, 0, max_mod_len);
    memcpy(modulus, rkey + sizeof(PUBLICKEYSTRUC) + sizeof(RSAPUBKEY), mod_len);
    return 0;
}
1775 
1776 /*****************************************************************************/
/* Releases a key blob with g_free.
   A NULL pointer is reported through error() and otherwise ignored. */
void rdssl_rkey_free(uint8 * rkey)
{
    if (rkey)
    {
        g_free(rkey);
        return;
    }
    error("rdssl_rkey_free rkey is null\n");
}
1786