1 
2 
3 #include "BigInt.h"
4 #include <ctype.h>
5 #include <string.h>
6 
7 #include "RakAlloca.h"
8 #include "RakMemoryOverride.h"
9 #include "Rand.h"
10 
11 #if defined(_MSC_VER) && !defined(_DEBUG) && _MSC_VER > 1310
12 #include <intrin.h>
13 #endif
14 
15 namespace big
16 {
17 	static const char Bits256[] = {
18 		0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
19 		5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
20 		6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
21 		6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
22 		7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
23 		7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
24 		7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
25 		7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
26 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
27 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
28 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
29 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
30 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
31 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
32 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
33 		8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8
34 	};
35 
36 	// returns the degree of the base 2 monic polynomial
37 	// (the number of bits used to represent the number)
38 	// eg, 0 0 0 0 1 0 1 1 ... => 28 out of 32 used
Degree(uint32_t v)39 	uint32_t Degree(uint32_t v)
40 	{
41 //#if defined(_MSC_VER) && !defined(_DEBUG)
42 //		unsigned long index;
43 //		return _BitScanReverse(&index, v) ? (index + 1) : 0;
44 //#else
45 		uint32_t r, t = v >> 16;
46 
47 		if (t)	r = (r = t >> 8) ? 24 + Bits256[r] : 16 + Bits256[t];
48 		else 	r = (r = v >> 8) ? 8 + Bits256[r] : Bits256[v];
49 
50 		return r;
51 //#endif
52 	}
53 
54 	// returns the number of limbs that are actually used
LimbDegree(const uint32_t * n,int limbs)55 	int LimbDegree(const uint32_t *n, int limbs)
56 	{
57 		while (limbs--)
58 			if (n[limbs])
59 				return limbs + 1;
60 
61 		return 0;
62 	}
63 
64 	// return bits used
Degree(const uint32_t * n,int limbs)65 	uint32_t Degree(const uint32_t *n, int limbs)
66 	{
67 		uint32_t limb_degree = LimbDegree(n, limbs);
68 		if (!limb_degree) return 0;
69 		--limb_degree;
70 
71 		uint32_t msl_degree = Degree(n[limb_degree]);
72 
73 		return msl_degree + limb_degree*32;
74 	}
75 
Set(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)76 	void Set(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
77 	{
78 		int min = lhs_limbs < rhs_limbs ? lhs_limbs : rhs_limbs;
79 
80 		memcpy(lhs, rhs, min*4);
81 		memset(&lhs[min], 0, (lhs_limbs - min)*4);
82 	}
Set(uint32_t * lhs,int limbs,const uint32_t * rhs)83 	void Set(uint32_t *lhs, int limbs, const uint32_t *rhs)
84 	{
85 		memcpy(lhs, rhs, limbs*4);
86 	}
Set32(uint32_t * lhs,int lhs_limbs,const uint32_t rhs)87 	void Set32(uint32_t *lhs, int lhs_limbs, const uint32_t rhs)
88 	{
89 		*lhs = rhs;
90 		memset(&lhs[1], 0, (lhs_limbs - 1)*4);
91 	}
92 
93 #if defined(__BIG_ENDIAN__)
94 
95 	// Flip the byte order as needed to make 'n' big-endian for sharing over a network
ToLittleEndian(uint32_t * n,int limbs)96 	void ToLittleEndian(uint32_t *n, int limbs)
97 	{
98 		for (int ii = 0; ii < limbs; ++ii)
99 		{
100 			swapLE(n[ii]);
101 		}
102 	}
103 
104 	// Flip the byte order as needed to make big-endian 'n' use the local byte order
FromLittleEndian(uint32_t * n,int limbs)105 	void FromLittleEndian(uint32_t *n, int limbs)
106 	{
107 		// Same operation as ToBigEndian()
108 		ToLittleEndian(n, limbs);
109 	}
110 
111 #endif // __BIG_ENDIAN__
112 
Less(int limbs,const uint32_t * lhs,const uint32_t * rhs)113 	bool Less(int limbs, const uint32_t *lhs, const uint32_t *rhs)
114 	{
115 		for (int ii = limbs-1; ii >= 0; --ii)
116 			if (lhs[ii] != rhs[ii])
117 				return lhs[ii] < rhs[ii];
118 
119 		return false;
120 	}
Greater(int limbs,const uint32_t * lhs,const uint32_t * rhs)121 	bool Greater(int limbs, const uint32_t *lhs, const uint32_t *rhs)
122 	{
123 		for (int ii = limbs-1; ii >= 0; --ii)
124 			if (lhs[ii] != rhs[ii])
125 				return lhs[ii] > rhs[ii];
126 
127 		return false;
128 	}
Equal(int limbs,const uint32_t * lhs,const uint32_t * rhs)129 	bool Equal(int limbs, const uint32_t *lhs, const uint32_t *rhs)
130 	{
131 		return 0 == memcmp(lhs, rhs, limbs*4);
132 	}
133 
Less(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)134 	bool Less(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
135 	{
136 		if (lhs_limbs > rhs_limbs)
137 			do if (lhs[--lhs_limbs] != 0) return false; while (lhs_limbs > rhs_limbs);
138 		else if (lhs_limbs < rhs_limbs)
139 			do if (rhs[--rhs_limbs] != 0) return true; while (lhs_limbs < rhs_limbs);
140 
141 		while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return lhs[lhs_limbs] < rhs[lhs_limbs];
142 		return false; // equal
143 	}
Greater(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)144 	bool Greater(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
145 	{
146 		if (lhs_limbs > rhs_limbs)
147 			do if (lhs[--lhs_limbs] != 0) return true; while (lhs_limbs > rhs_limbs);
148 		else if (lhs_limbs < rhs_limbs)
149 			do if (rhs[--rhs_limbs] != 0) return false; while (lhs_limbs < rhs_limbs);
150 
151 		while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return lhs[lhs_limbs] > rhs[lhs_limbs];
152 		return false; // equal
153 	}
Equal(const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)154 	bool Equal(const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
155 	{
156 		if (lhs_limbs > rhs_limbs)
157 			do if (lhs[--lhs_limbs] != 0) return false; while (lhs_limbs > rhs_limbs);
158 		else if (lhs_limbs < rhs_limbs)
159 			do if (rhs[--rhs_limbs] != 0) return false; while (lhs_limbs < rhs_limbs);
160 
161 		while (lhs_limbs--) if (lhs[lhs_limbs] != rhs[lhs_limbs]) return false;
162 		return true; // equal
163 	}
164 
Greater32(const uint32_t * lhs,int lhs_limbs,uint32_t rhs)165 	bool Greater32(const uint32_t *lhs, int lhs_limbs, uint32_t rhs)
166 	{
167 		if (*lhs > rhs) return true;
168 		while (--lhs_limbs)
169 			if (*++lhs) return true;
170 		return false;
171 	}
Equal32(const uint32_t * lhs,int lhs_limbs,uint32_t rhs)172 	bool Equal32(const uint32_t *lhs, int lhs_limbs, uint32_t rhs)
173 	{
174 		if (*lhs != rhs) return false;
175 		while (--lhs_limbs)
176 			if (*++lhs) return false;
177 		return true; // equal
178 	}
179 
180 	// out = in >>> shift
181 	// Precondition: 0 <= shift < 31
ShiftRight(int limbs,uint32_t * out,const uint32_t * in,int shift)182 	void ShiftRight(int limbs, uint32_t *out, const uint32_t *in, int shift)
183 	{
184 		if (!shift)
185 		{
186 			Set(out, limbs, in);
187 			return;
188 		}
189 
190 		uint32_t carry = 0;
191 
192 		for (int ii = limbs - 1; ii >= 0; --ii)
193 		{
194 			uint32_t r = in[ii];
195 
196 			out[ii] = (r >> shift) | carry;
197 
198 			carry = r << (32 - shift);
199 		}
200 	}
201 
202 	// {out, carry} = in <<< shift
203 	// Precondition: 0 <= shift < 31
ShiftLeft(int limbs,uint32_t * out,const uint32_t * in,int shift)204 	uint32_t ShiftLeft(int limbs, uint32_t *out, const uint32_t *in, int shift)
205 	{
206 		if (!shift)
207 		{
208 			Set(out, limbs, in);
209 			return 0;
210 		}
211 
212 		uint32_t carry = 0;
213 
214 		for (int ii = 0; ii < limbs; ++ii)
215 		{
216 			uint32_t r = in[ii];
217 
218 			out[ii] = (r << shift) | carry;
219 
220 			carry = r >> (32 - shift);
221 		}
222 
223 		return carry;
224 	}
225 
226 	// lhs += rhs, return carry out
227 	// precondition: lhs_limbs >= rhs_limbs
Add(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)228 	uint32_t Add(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
229 	{
230 		int ii;
231 		uint64_t r = (uint64_t)lhs[0] + rhs[0];
232 		lhs[0] = (uint32_t)r;
233 
234 		for (ii = 1; ii < rhs_limbs; ++ii)
235 		{
236 			r = ((uint64_t)lhs[ii] + rhs[ii]) + (uint32_t)(r >> 32);
237 			lhs[ii] = (uint32_t)r;
238 		}
239 
240 		for (; ii < lhs_limbs && (uint32_t)(r >>= 32) != 0; ++ii)
241 		{
242 			r += lhs[ii];
243 			lhs[ii] = (uint32_t)r;
244 		}
245 
246 		return (uint32_t)(r >> 32);
247 	}
248 
249 	// out = lhs + rhs, return carry out
250 	// precondition: lhs_limbs >= rhs_limbs
Add(uint32_t * out,const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)251 	uint32_t Add(uint32_t *out, const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
252 	{
253 		int ii;
254 		uint64_t r = (uint64_t)lhs[0] + rhs[0];
255 		out[0] = (uint32_t)r;
256 
257 		for (ii = 1; ii < rhs_limbs; ++ii)
258 		{
259 			r = ((uint64_t)lhs[ii] + rhs[ii]) + (uint32_t)(r >> 32);
260 			out[ii] = (uint32_t)r;
261 		}
262 
263 		for (; ii < lhs_limbs && (uint32_t)(r >>= 32) != 0; ++ii)
264 		{
265 			r += lhs[ii];
266 			out[ii] = (uint32_t)r;
267 		}
268 
269 		return (uint32_t)(r >> 32);
270 	}
271 
272 	// lhs += rhs, return carry out
273 	// precondition: lhs_limbs > 0
Add32(uint32_t * lhs,int lhs_limbs,uint32_t rhs)274 	uint32_t Add32(uint32_t *lhs, int lhs_limbs, uint32_t rhs)
275 	{
276 		uint32_t n = lhs[0];
277 		uint32_t r = n + rhs;
278 		lhs[0] = r;
279 
280 		if (r >= n)
281 			return 0;
282 
283 		for (int ii = 1; ii < lhs_limbs; ++ii)
284 			if (++lhs[ii])
285 				return 0;
286 
287 		return 1;
288 	}
289 
290 	// lhs -= rhs, return borrow out
291 	// precondition: lhs_limbs >= rhs_limbs
Subtract(uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)292 	int32_t Subtract(uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
293 	{
294 		int ii;
295 		int64_t r = (int64_t)lhs[0] - rhs[0];
296 		lhs[0] = (uint32_t)r;
297 
298 		for (ii = 1; ii < rhs_limbs; ++ii)
299 		{
300 			r = ((int64_t)lhs[ii] - rhs[ii]) + (int32_t)(r >> 32);
301 			lhs[ii] = (uint32_t)r;
302 		}
303 
304 		for (; ii < lhs_limbs && (int32_t)(r >>= 32) != 0; ++ii)
305 		{
306 			r += lhs[ii];
307 			lhs[ii] = (uint32_t)r;
308 		}
309 
310 		return (int32_t)(r >> 32);
311 	}
312 
313 	// out = lhs - rhs, return borrow out
314 	// precondition: lhs_limbs >= rhs_limbs
Subtract(uint32_t * out,const uint32_t * lhs,int lhs_limbs,const uint32_t * rhs,int rhs_limbs)315 	int32_t Subtract(uint32_t *out, const uint32_t *lhs, int lhs_limbs, const uint32_t *rhs, int rhs_limbs)
316 	{
317 		int ii;
318 		int64_t r = (int64_t)lhs[0] - rhs[0];
319 		out[0] = (uint32_t)r;
320 
321 		for (ii = 1; ii < rhs_limbs; ++ii)
322 		{
323 			r = ((int64_t)lhs[ii] - rhs[ii]) + (int32_t)(r >> 32);
324 			out[ii] = (uint32_t)r;
325 		}
326 
327 		for (; ii < lhs_limbs && (int32_t)(r >>= 32) != 0; ++ii)
328 		{
329 			r += lhs[ii];
330 			out[ii] = (uint32_t)r;
331 		}
332 
333 		return (int32_t)(r >> 32);
334 	}
335 
336 	// lhs -= rhs, return borrow out
337 	// precondition: lhs_limbs > 0, result limbs = lhs_limbs
Subtract32(uint32_t * lhs,int lhs_limbs,uint32_t rhs)338 	int32_t Subtract32(uint32_t *lhs, int lhs_limbs, uint32_t rhs)
339 	{
340 		uint32_t n = lhs[0];
341 		uint32_t r = n - rhs;
342 		lhs[0] = r;
343 
344 		if (r <= n)
345 			return 0;
346 
347 		for (int ii = 1; ii < lhs_limbs; ++ii)
348 			if (lhs[ii]--)
349 				return 0;
350 
351 		return -1;
352 	}
353 
354 	// lhs = -rhs
Negate(int limbs,uint32_t * lhs,const uint32_t * rhs)355 	void Negate(int limbs, uint32_t *lhs, const uint32_t *rhs)
356 	{
357 		// Propagate negations until carries stop
358 		while (limbs-- > 0 && !(*lhs++ = -(int32_t)(*rhs++)));
359 
360 		// Then just invert the remaining words
361 		while (limbs-- > 0) *lhs++ = ~(*rhs++);
362 	}
363 
364 	// n = ~n, only invert bits up to the MSB, but none above that
BitNot(uint32_t * n,int limbs)365 	void BitNot(uint32_t *n, int limbs)
366 	{
367 		limbs = LimbDegree(n, limbs);
368 		if (limbs)
369 		{
370 			uint32_t high = n[--limbs];
371 			uint32_t high_degree = 32 - Degree(high);
372 
373 			n[limbs] = ((uint32_t)(~high << high_degree) >> high_degree);
374 			while (limbs--) n[limbs] = ~n[limbs];
375 		}
376 	}
377 
378 	// n = ~n, invert all bits, even ones above MSB
LimbNot(uint32_t * n,int limbs)379 	void LimbNot(uint32_t *n, int limbs)
380 	{
381 		while (limbs--) *n++ = ~(*n);
382 	}
383 
384 	// lhs ^= rhs
Xor(int limbs,uint32_t * lhs,const uint32_t * rhs)385 	void Xor(int limbs, uint32_t *lhs, const uint32_t *rhs)
386 	{
387 		while (limbs--) *lhs++ ^= *rhs++;
388 	}
389 
390 	// Return the carry out from A += B << S
AddLeftShift32(int limbs,uint32_t * A,const uint32_t * B,uint32_t S)391     uint32_t AddLeftShift32(
392     	int limbs,		// Number of limbs in parameter A and B
393     	uint32_t *A,			// Large number
394     	const uint32_t *B,	// Large number
395     	uint32_t S)			// 32-bit number
396 	{
397 		uint64_t sum = 0;
398 		uint32_t last = 0;
399 
400 		while (limbs--)
401 		{
402 			uint32_t b = *B++;
403 
404 			sum = (uint64_t)((b << S) | (last >> (32 - S))) + *A + (uint32_t)(sum >> 32);
405 
406 			last = b;
407 			*A++ = (uint32_t)sum;
408 		}
409 
410 		return (uint32_t)(sum >> 32) + (last >> (32 - S));
411 	}
412 
413 	// Return the carry out from result = A * B
Multiply32(int limbs,uint32_t * result,const uint32_t * A,uint32_t B)414     uint32_t Multiply32(
415     	int limbs,		// Number of limbs in parameter A, result
416     	uint32_t *result,	// Large number
417     	const uint32_t *A,	// Large number
418     	uint32_t B)			// 32-bit number
419 	{
420 		uint64_t p = (uint64_t)A[0] * B;
421 		result[0] = (uint32_t)p;
422 
423 		while (--limbs)
424 		{
425 			p = (uint64_t)*(++A) * B + (uint32_t)(p >> 32);
426 			*(++result) = (uint32_t)p;
427 		}
428 
429 		return (uint32_t)(p >> 32);
430 	}
431 
432 	// Return the carry out from X = X * M + A
MultiplyAdd32(int limbs,uint32_t * X,uint32_t M,uint32_t A)433     uint32_t MultiplyAdd32(
434     	int limbs,	// Number of limbs in parameter A and B
435     	uint32_t *X,		// Large number
436     	uint32_t M,		// Large number
437     	uint32_t A)		// 32-bit number
438 	{
439 		uint64_t p = (uint64_t)X[0] * M + A;
440 		X[0] = (uint32_t)p;
441 
442 		while (--limbs)
443 		{
444 			p = (uint64_t)*(++X) * M + (uint32_t)(p >> 32);
445 			*X = (uint32_t)p;
446 		}
447 
448 		return (uint32_t)(p >> 32);
449 	}
450 
451 	// Return the carry out from A += B * M
AddMultiply32(int limbs,uint32_t * A,const uint32_t * B,uint32_t M)452     uint32_t AddMultiply32(
453     	int limbs,		// Number of limbs in parameter A and B
454     	uint32_t *A,			// Large number
455     	const uint32_t *B,	// Large number
456     	uint32_t M)			// 32-bit number
457 	{
458 		// This function is roughly 85% of the cost of exponentiation
459 #if defined(ASSEMBLY_INTEL_SYNTAX)
460 		ASSEMBLY_BLOCK // VS.NET, x86, 32-bit words
461 		{
462 			mov esi, [B]
463 			mov edi, [A]
464 			mov eax, [esi]
465 			mul [M]					; (edx,eax) = [M]*[esi]
466 			add eax, [edi]			; (edx,eax) += [edi]
467 			adc edx, 0
468 			; (edx,eax) = [B]*[M] + [A]
469 
470 			mov [edi], eax
471 			; [A] = eax
472 
473 			mov ecx, [limbs]
474 			sub ecx, 1
475 			jz loop_done
476 loop_head:
477 				lea esi, [esi + 4]	; ++B
478 				mov eax, [esi]		; eax = [B]
479 				mov ebx, edx		; ebx = last carry
480 				lea edi, [edi + 4]	; ++A
481 				mul [M]				; (edx,eax) = [M]*[esi]
482 				add eax, [edi]		; (edx,eax) += [edi]
483 				adc edx, 0
484 				add eax, ebx		; (edx,eax) += ebx
485 				adc edx, 0
486 				; (edx,eax) = [esi]*[M] + [edi] + (ebx=last carry)
487 
488 				mov [edi], eax
489 				; [A] = eax
490 
491 			sub ecx, 1
492 			jnz loop_head
493 loop_done:
494 			mov [M], edx	; Use [M] to copy the carry into C++ land
495 		}
496 
497 		return M;
498 #else
499 		// Unrolled first loop
500 		uint64_t p = B[0] * (uint64_t)M + A[0];
501 		A[0] = (uint32_t)p;
502 
503 		while (--limbs)
504 		{
505 			p = (*(++B) * (uint64_t)M + *(++A)) + (uint32_t)(p >> 32);
506 			A[0] = (uint32_t)p;
507 		}
508 
509 		return (uint32_t)(p >> 32);
510 #endif
511 	}
512 
513 	// product = x * y
SimpleMultiply(int limbs,uint32_t * product,const uint32_t * x,const uint32_t * y)514 	void SimpleMultiply(
515 		int limbs,		// Number of limbs in parameters x, y
516 		uint32_t *product,	// Large number; buffer size = limbs*2
517 		const uint32_t *x,	// Large number
518 		const uint32_t *y)	// Large number
519 	{
520 		// Roughly 25% of the cost of exponentiation
521 		product[limbs] = Multiply32(limbs, product, x, y[0]);
522 
523 		uint32_t ctr = limbs;
524 		while (--ctr)
525 		{
526 			++product;
527 			product[limbs] = AddMultiply32(limbs, product, x, (++y)[0]);
528 		}
529 	}
530 
531 	// product = low half of x * y product
SimpleMultiplyLowHalf(int limbs,uint32_t * product,const uint32_t * x,const uint32_t * y)532 	void SimpleMultiplyLowHalf(
533 		int limbs,		// Number of limbs in parameters x, y
534 		uint32_t *product,	// Large number; buffer size = limbs
535 		const uint32_t *x,	// Large number
536 		const uint32_t *y)	// Large number
537 	{
538 		Multiply32(limbs, product, x, y[0]);
539 
540 		while (--limbs)
541 		{
542 			++product;
543 			++y;
544 			AddMultiply32(limbs, product, x, y[0]);
545 		}
546 	}
547 
548 	// product = x ^ 2
SimpleSquare(int limbs,uint32_t * product,const uint32_t * x)549 	void SimpleSquare(
550 		int limbs,		// Number of limbs in parameter x
551 		uint32_t *product,	// Large number; buffer size = limbs*2
552 		const uint32_t *x)	// Large number
553 	{
554 		// Seems about 15% faster than SimpleMultiply() in practice
555 		uint32_t *cross_product = (uint32_t*)alloca(limbs*2*4);
556 
557 		// Calculate square-less and repeat-less cross products
558 		cross_product[limbs] = Multiply32(limbs - 1, cross_product + 1, x + 1, x[0]);
559 		for (int ii = 1; ii < limbs - 1; ++ii)
560 		{
561 			cross_product[limbs + ii] = AddMultiply32(limbs - ii - 1, cross_product + ii*2 + 1, x + ii + 1, x[ii]);
562 		}
563 
564 		// Calculate square products
565 		for (int ii = 0; ii < limbs; ++ii)
566 		{
567 			uint32_t xi = x[ii];
568 			uint64_t si = (uint64_t)xi * xi;
569 			product[ii*2] = (uint32_t)si;
570 			product[ii*2+1] = (uint32_t)(si >> 32);
571 		}
572 
573 		// Multiply the cross product by 2 and add it to the square products
574 		product[limbs*2 - 1] += AddLeftShift32(limbs*2 - 2, product + 1, cross_product + 1, 1);
575 	}
576 
577 	// product = xy
578 	// memory space for product may not overlap with x,y
Multiply(int limbs,uint32_t * product,const uint32_t * x,const uint32_t * y)579     void Multiply(
580     	int limbs,		// Number of limbs in x,y
581     	uint32_t *product,	// Product; buffer size = limbs*2
582     	const uint32_t *x,	// Large number; buffer size = limbs
583     	const uint32_t *y)	// Large number; buffer size = limbs
584 	{
585 		// Stop recursing under 640 bits or odd limb count
586 		if (limbs < 30 || (limbs & 1))
587 		{
588 			SimpleMultiply(limbs, product, x, y);
589 			return;
590 		}
591 
592 		// Compute high and low products
593 		Multiply(limbs/2, product, x, y);
594 		Multiply(limbs/2, product + limbs, x + limbs/2, y + limbs/2);
595 
596 		// Compute (x1 + x2), xc = carry out
597 		uint32_t *xsum = (uint32_t*)alloca((limbs/2)*4);
598 		uint32_t xcarry = Add(xsum, x, limbs/2, x + limbs/2, limbs/2);
599 
600 		// Compute (y1 + y2), yc = carry out
601 		uint32_t *ysum = (uint32_t*)alloca((limbs/2)*4);
602 		uint32_t ycarry = Add(ysum, y, limbs/2, y + limbs/2, limbs/2);
603 
604 		// Compute (x1 + x2) * (y1 + y2)
605 		uint32_t *cross_product = (uint32_t*)alloca(limbs*4);
606 		Multiply(limbs/2, cross_product, xsum, ysum);
607 
608 		// Subtract out the high and low products
609 		int32_t cross_carry = Subtract(cross_product, limbs, product, limbs);
610 		cross_carry += Subtract(cross_product, limbs, product + limbs, limbs);
611 
612 		// Fix the extra high carry bits of the result
613 		if (ycarry) cross_carry += Add(cross_product + limbs/2, limbs/2, xsum, limbs/2);
614 		if (xcarry) cross_carry += Add(cross_product + limbs/2, limbs/2, ysum, limbs/2);
615 		cross_carry += (xcarry & ycarry);
616 
617 		// Add the cross product into the result
618 		cross_carry += Add(product + limbs/2, limbs*3/2, cross_product, limbs);
619 
620 		// Add in the fixed high carry bits
621 		if (cross_carry) Add32(product + limbs*3/2, limbs/2, cross_carry);
622 	}
623 
624 	// product = x^2
625 	// memory space for product may not overlap with x
Square(int limbs,uint32_t * product,const uint32_t * x)626     void Square(
627     	int limbs,		// Number of limbs in x
628     	uint32_t *product,	// Product; buffer size = limbs*2
629     	const uint32_t *x)	// Large number; buffer size = limbs
630 	{
631 		// Stop recursing under 1280 bits or odd limb count
632 		if (limbs < 40 || (limbs & 1))
633 		{
634 			SimpleSquare(limbs, product, x);
635 			return;
636 		}
637 
638 		// Compute high and low squares
639 		Square(limbs/2, product, x);
640 		Square(limbs/2, product + limbs, x + limbs/2);
641 
642 		// Generate the cross product
643 		uint32_t *cross_product = (uint32_t*)alloca(limbs*4);
644 		Multiply(limbs/2, cross_product, x, x + limbs/2);
645 
646 		// Multiply the cross product by 2 and add it to the result
647 		uint32_t cross_carry = AddLeftShift32(limbs, product + limbs/2, cross_product, 1);
648 
649 		// Roll the carry out up to the highest limb
650 		if (cross_carry) Add32(product + limbs*3/2, limbs/2, cross_carry);
651 	}
652 
653 	// Returns the remainder of N / divisor for a 32-bit divisor
Modulus32(int limbs,const uint32_t * N,uint32_t divisor)654     uint32_t Modulus32(
655     	int limbs,		// Number of limbs in parameter N
656     	const uint32_t *N,	// Large number, buffer size = limbs
657     	uint32_t divisor)	// 32-bit number
658 	{
659 		uint32_t remainder = N[limbs-1] < divisor ? N[limbs-1] : 0;
660 		uint32_t counter = N[limbs-1] < divisor ? limbs-1 : limbs;
661 
662 		while (counter--) remainder = (uint32_t)((((uint64_t)remainder << 32) | N[counter]) % divisor);
663 
664 		return remainder;
665 	}
666 
667 	/*
668 	 * 'A' is overwritten with the quotient of the operation
669 	 * Returns the remainder of 'A' / divisor for a 32-bit divisor
670 	 *
671 	 * Does not check for divide-by-zero
672 	 */
Divide32(int limbs,uint32_t * A,uint32_t divisor)673     uint32_t Divide32(
674     	int limbs,		// Number of limbs in parameter A
675     	uint32_t *A,			// Large number, buffer size = limbs
676     	uint32_t divisor)	// 32-bit number
677 	{
678 		uint64_t r = 0;
679 		for (int ii = limbs-1; ii >= 0; --ii)
680 		{
681 			uint64_t n = (r << 32) | A[ii];
682 			A[ii] = (uint32_t)(n / divisor);
683 			r = n % divisor;
684 		}
685 
686 		return (uint32_t)r;
687 	}
688 
689 	// returns (n ^ -1) Mod 2^32
MulInverse32(uint32_t n)690 	uint32_t MulInverse32(uint32_t n)
691 	{
692 		// {u1, g1} = 2^32 / n
693 		uint32_t hb = (~(n - 1) >> 31);
694 		uint32_t u1 = -(int32_t)(0xFFFFFFFF / n + hb);
695 		uint32_t g1 = ((-(int32_t)hb) & (0xFFFFFFFF % n + 1)) - n;
696 
697 		if (!g1) {
698 			if (n != 1) return 0;
699 			else return 1;
700 		}
701 
702 		uint32_t q, u = 1, g = n;
703 
704 		for (;;) {
705 			q = g / g1;
706 			g %= g1;
707 
708 			if (!g) {
709 				if (g1 != 1) return 0;
710 				else return u1;
711 			}
712 
713 			u -= q*u1;
714 			q = g1 / g;
715 			g1 %= g;
716 
717 			if (!g1) {
718 				if (g != 1) return 0;
719 				else return u;
720 			}
721 
722 			u1 -= q*u;
723 		}
724 	}
725 
726 	/*
727 	 * Computes multiplicative inverse of given number
728 	 * Such that: result * u = 1
729 	 * Using Extended Euclid's Algorithm (GCDe)
730 	 *
731 	 * This is not always possible, so it will return false iff not possible.
732 	 */
MulInverse(int limbs,const uint32_t * u,uint32_t * result)733 	bool MulInverse(
734 		int limbs,		// Limbs in u and result
735 		const uint32_t *u,	// Large number, buffer size = limbs
736 		uint32_t *result)	// Large number, buffer size = limbs
737 	{
738 		uint32_t *u1 = (uint32_t*)alloca(limbs*4);
739 		uint32_t *u3 = (uint32_t*)alloca(limbs*4);
740 		uint32_t *v1 = (uint32_t*)alloca(limbs*4);
741 		uint32_t *v3 = (uint32_t*)alloca(limbs*4);
742 		uint32_t *t1 = (uint32_t*)alloca(limbs*4);
743 		uint32_t *t3 = (uint32_t*)alloca(limbs*4);
744 		uint32_t *q = (uint32_t*)alloca((limbs+1)*4);
745 		uint32_t *w = (uint32_t*)alloca((limbs+1)*4);
746 
747 		// Unrolled first iteration
748 		{
749 			Set32(u1, limbs, 0);
750 			Set32(v1, limbs, 1);
751 			Set(v3, limbs, u);
752 		}
753 
754 		// Unrolled second iteration
755 		if (!LimbDegree(v3, limbs))
756 			return false;
757 
758 		// {q, t3} <- R / v3
759 		Set32(w, limbs, 0);
760 		w[limbs] = 1;
761 		Divide(w, limbs+1, v3, limbs, q, t3);
762 
763 		SimpleMultiplyLowHalf(limbs, t1, q, v1);
764 		Add(t1, limbs, u1, limbs);
765 
766 		for (;;)
767 		{
768 			if (!LimbDegree(t3, limbs))
769 			{
770 				Set(result, limbs, v1);
771 				return Equal32(v3, limbs, 1);
772 			}
773 
774 			Divide(v3, limbs, t3, limbs, q, u3);
775 			SimpleMultiplyLowHalf(limbs, u1, q, t1);
776 			Add(u1, limbs, v1, limbs);
777 
778 			if (!LimbDegree(u3, limbs))
779 			{
780 				Negate(limbs, result, t1);
781 				return Equal32(t3, limbs, 1);
782 			}
783 
784 			Divide(t3, limbs, u3, limbs, q, v3);
785 			SimpleMultiplyLowHalf(limbs, v1, q, u1);
786 			Add(v1, limbs, t1, limbs);
787 
788 			if (!LimbDegree(v3, limbs))
789 			{
790 				Set(result, limbs, u1);
791 				return Equal32(u3, limbs, 1);
792 			}
793 
794 			Divide(u3, limbs, v3, limbs, q, t3);
795 			SimpleMultiplyLowHalf(limbs, t1, q, v1);
796 			Add(t1, limbs, u1, limbs);
797 
798 			if (!LimbDegree(t3, limbs))
799 			{
800 				Negate(limbs, result, v1);
801 				return Equal32(v3, limbs, 1);
802 			}
803 
804 			Divide(v3, limbs, t3, limbs, q, u3);
805 			SimpleMultiplyLowHalf(limbs, u1, q, t1);
806 			Add(u1, limbs, v1, limbs);
807 
808 			if (!LimbDegree(u3, limbs))
809 			{
810 				Set(result, limbs, t1);
811 				return Equal32(t3, limbs, 1);
812 			}
813 
814 			Divide(t3, limbs, u3, limbs, q, v3);
815 			SimpleMultiplyLowHalf(limbs, v1, q, u1);
816 			Add(v1, limbs, t1, limbs);
817 
818 			if (!LimbDegree(v3, limbs))
819 			{
820 				Negate(limbs, result, u1);
821 				return Equal32(u3, limbs, 1);
822 			}
823 
824 			Divide(u3, limbs, v3, limbs, q, t3);
825 			SimpleMultiplyLowHalf(limbs, t1, q, v1);
826 			Add(t1, limbs, u1, limbs);
827 		}
828 	}
829 
830 	// {q, r} = u / v
831 	// q is not u or v
832 	// Return false on divide by zero
Divide(const uint32_t * u,int u_limbs,const uint32_t * v,int v_limbs,uint32_t * q,uint32_t * r)833 	bool Divide(
834 		const uint32_t *u,	// numerator, size = u_limbs
835 		int u_limbs,
836 		const uint32_t *v,	// denominator, size = v_limbs
837 		int v_limbs,
838 		uint32_t *q,			// quotient, size = u_limbs
839 		uint32_t *r)			// remainder, size = v_limbs
840 	{
841 		// calculate v_used and u_used
842 		int v_used = LimbDegree(v, v_limbs);
843 		if (!v_used) return false;
844 
845 		int u_used = LimbDegree(u, u_limbs);
846 
847 		// if u < v, avoid division
848 		if (u_used <= v_used && Less(u, u_used, v, v_used))
849 		{
850 			// r = u, q = 0
851 			Set(r, v_limbs, u, u_used);
852 			Set32(q, u_limbs, 0);
853 			return true;
854 		}
855 
856 		// if v is 32 bits, use faster Divide32 code
857 		if (v_used == 1)
858 		{
859 			// {q, r} = u / v[0]
860 			Set(q, u_limbs, u);
861 			Set32(r, v_limbs, Divide32(u_limbs, q, v[0]));
862 			return true;
863 		}
864 
865 		// calculate high zero bits in v's high used limb
866 		int shift = 32 - Degree(v[v_used - 1]);
867 		int uu_used = u_used;
868 		if (shift > 0) uu_used++;
869 
870 		uint32_t *uu = (uint32_t*)alloca(uu_used*4);
871 		uint32_t *vv = (uint32_t*)alloca(v_used*4);
872 
873 		// shift left to fill high MSB of divisor
874 		if (shift > 0)
875 		{
876 			ShiftLeft(v_used, vv, v, shift);
877 			uu[u_used] = ShiftLeft(u_used, uu, u, shift);
878 		}
879 		else
880 		{
881 			Set(uu, u_used, u);
882 			Set(vv, v_used, v);
883 		}
884 
885 		int q_high_index = uu_used - v_used;
886 
887 		if (GreaterOrEqual(uu + q_high_index, v_used, vv, v_used))
888 		{
889 			Subtract(uu + q_high_index, v_used, vv, v_used);
890 			Set32(q + q_high_index, u_used - q_high_index, 1);
891 		}
892 		else
893 		{
894 			Set32(q + q_high_index, u_used - q_high_index, 0);
895 		}
896 
897 		uint32_t *vq_product = (uint32_t*)alloca((v_used+1)*4);
898 
899 		// for each limb,
900 		for (int ii = q_high_index - 1; ii >= 0; --ii)
901 		{
902 			uint64_t q_full = *(uint64_t*)(uu + ii + v_used - 1) / vv[v_used - 1];
903 			uint32_t q_low = (uint32_t)q_full;
904 			uint32_t q_high = (uint32_t)(q_full >> 32);
905 
906 			vq_product[v_used] = Multiply32(v_used, vq_product, vv, q_low);
907 
908 			if (q_high) // it must be '1'
909 				Add(vq_product + 1, v_used, vv, v_used);
910 
911 			if (Subtract(uu + ii, v_used + 1, vq_product, v_used + 1))
912 			{
913 				--q_low;
914 				if (Add(uu + ii, v_used + 1, vv, v_used) == 0)
915 				{
916 					--q_low;
917 					Add(uu + ii, v_used + 1, vv, v_used);
918 				}
919 			}
920 
921 			q[ii] = q_low;
922 		}
923 
924 		memset(r + v_used, 0, (v_limbs - v_used)*4);
925 		ShiftRight(v_used, r, uu, shift);
926 
927 		return true;
928 	}
929 
930 	// r = u % v
931 	// Return false on divide by zero
Modulus(const uint32_t * u,int u_limbs,const uint32_t * v,int v_limbs,uint32_t * r)932 	bool Modulus(
933 		const uint32_t *u,	// numerator, size = u_limbs
934 		int u_limbs,
935 		const uint32_t *v,	// denominator, size = v_limbs
936 		int v_limbs,
937 		uint32_t *r)			// remainder, size = v_limbs
938 	{
939 		// calculate v_used and u_used
940 		int v_used = LimbDegree(v, v_limbs);
941 		if (!v_used) return false;
942 
943 		int u_used = LimbDegree(u, u_limbs);
944 
945 		// if u < v, avoid division
946 		if (u_used <= v_used && Less(u, u_used, v, v_used))
947 		{
948 			// r = u, q = 0
949 			Set(r, v_limbs, u, u_used);
950 			//Set32(q, u_limbs, 0);
951 			return true;
952 		}
953 
954 		// if v is 32 bits, use faster Divide32 code
955 		if (v_used == 1)
956 		{
957 			// {q, r} = u / v[0]
958 			//Set(q, u_limbs, u);
959 			Set32(r, v_limbs, Modulus32(u_limbs, u, v[0]));
960 			return true;
961 		}
962 
963 		// calculate high zero bits in v's high used limb
964 		int shift = 32 - Degree(v[v_used - 1]);
965 		int uu_used = u_used;
966 		if (shift > 0) uu_used++;
967 
968 		uint32_t *uu = (uint32_t*)alloca(uu_used*4);
969 		uint32_t *vv = (uint32_t*)alloca(v_used*4);
970 
971 		// shift left to fill high MSB of divisor
972 		if (shift > 0)
973 		{
974 			ShiftLeft(v_used, vv, v, shift);
975 			uu[u_used] = ShiftLeft(u_used, uu, u, shift);
976 		}
977 		else
978 		{
979 			Set(uu, u_used, u);
980 			Set(vv, v_used, v);
981 		}
982 
983 		int q_high_index = uu_used - v_used;
984 
985 		if (GreaterOrEqual(uu + q_high_index, v_used, vv, v_used))
986 		{
987 			Subtract(uu + q_high_index, v_used, vv, v_used);
988 			//Set32(q + q_high_index, u_used - q_high_index, 1);
989 		}
990 		else
991 		{
992 			//Set32(q + q_high_index, u_used - q_high_index, 0);
993 		}
994 
995 		uint32_t *vq_product = (uint32_t*)alloca((v_used+1)*4);
996 
997 		// for each limb,
998 		for (int ii = q_high_index - 1; ii >= 0; --ii)
999 		{
1000 			uint64_t q_full = *(uint64_t*)(uu + ii + v_used - 1) / vv[v_used - 1];
1001 			uint32_t q_low = (uint32_t)q_full;
1002 			uint32_t q_high = (uint32_t)(q_full >> 32);
1003 
1004 			vq_product[v_used] = Multiply32(v_used, vq_product, vv, q_low);
1005 
1006 			if (q_high) // it must be '1'
1007 				Add(vq_product + 1, v_used, vv, v_used);
1008 
1009 			if (Subtract(uu + ii, v_used + 1, vq_product, v_used + 1))
1010 			{
1011 				//--q_low;
1012 				if (Add(uu + ii, v_used + 1, vv, v_used) == 0)
1013 				{
1014 					//--q_low;
1015 					Add(uu + ii, v_used + 1, vv, v_used);
1016 				}
1017 			}
1018 
1019 			//q[ii] = q_low;
1020 		}
1021 
1022 		memset(r + v_used, 0, (v_limbs - v_used)*4);
1023 		ShiftRight(v_used, r, uu, shift);
1024 
1025 		return true;
1026 	}
1027 
1028 	// m_inv ~= 2^(2k)/m
1029 	// Generates m_inv parameter of BarrettModulus()
1030 	// It is limbs in size, chopping off the 2^k bit
1031 	// Only works for m with the high bit set
BarrettModulusPrecomp(int limbs,const uint32_t * m,uint32_t * m_inv)1032 	void BarrettModulusPrecomp(
1033 		int limbs,		// Number of limbs in m and m_inv
1034 		const uint32_t *m,	// Modulus, size = limbs
1035 		uint32_t *m_inv)		// Large number result, size = limbs
1036 	{
1037 		uint32_t *q = (uint32_t*)alloca((limbs*2+1)*4);
1038 
1039 		// q = 2^(2k)
1040 		big::Set32(q, limbs*2, 0);
1041 		q[limbs*2] = 1;
1042 
1043 		// q /= m
1044 		big::Divide(q, limbs*2+1, m, limbs, q, m_inv);
1045 
1046 		// m_inv = q
1047 		Set(m_inv, limbs, q);
1048 	}
1049 
1050 	// r = x mod m
1051 	// Using Barrett's method with precomputed m_inv
BarrettModulus(int limbs,const uint32_t * x,const uint32_t * m,const uint32_t * m_inv,uint32_t * result)1052 	void BarrettModulus(
1053 		int limbs,			// Number of limbs in m and m_inv
1054 		const uint32_t *x,		// Number to reduce, size = limbs*2
1055 		const uint32_t *m,		// Modulus, size = limbs
1056 		const uint32_t *m_inv,	// R/Modulus, precomputed, size = limbs
1057 		uint32_t *result)		// Large number result
1058 	{
1059 		// q2 = x * m_inv
1060 		// Skips the low limbs+1 words and some high limbs too
1061 		// Needs to partially calculate the next 2 words below for carries
1062 		uint32_t *q2 = (uint32_t*)alloca((limbs+3)*4);
1063 		int ii, jj = limbs - 1;
1064 
1065 		// derived from the fact that m_inv[limbs] was always 1, so m_inv is the same length as modulus now
1066 		*(uint64_t*)q2 = (uint64_t)m_inv[jj] * x[jj];
1067 		*(uint64_t*)(q2 + 1) = (uint64_t)q2[1] + x[jj];
1068 
1069 		for (ii = 1; ii < limbs; ++ii)
1070 			*(uint64_t*)(q2 + ii + 1) = ((uint64_t)q2[ii + 1] + x[jj + ii]) + AddMultiply32(ii + 1, q2, m_inv + jj - ii, x[jj + ii]);
1071 
1072 		*(uint64_t*)(q2 + ii + 1) = ((uint64_t)q2[ii + 1] + x[jj + ii]) + AddMultiply32(ii, q2 + 1, m_inv, x[jj + ii]);
1073 
1074 		q2 += 2;
1075 
1076 		// r2 = (q3 * m2) mod b^(k+1)
1077 		uint32_t *r2 = (uint32_t*)alloca((limbs+1)*4);
1078 
1079 		// Skip high words in product, also input limbs are different by 1
1080 		Multiply32(limbs + 1, r2, q2, m[0]);
1081 		for (int ii = 1; ii < limbs; ++ii)
1082 			AddMultiply32(limbs + 1 - ii, r2 + ii, q2, m[ii]);
1083 
1084 		// Correct the error of up to two modulii
1085 		uint32_t *r = (uint32_t*)alloca((limbs+1)*4);
1086 		if (Subtract(r, x, limbs+1, r2, limbs+1))
1087 		{
1088 			while (!Subtract(r, limbs+1, m, limbs));
1089 		}
1090 		else
1091 		{
1092 			while (GreaterOrEqual(r, limbs+1, m, limbs))
1093 				Subtract(r, limbs+1, m, limbs);
1094 		}
1095 
1096 		Set(result, limbs, r);
1097 	}
1098 
1099 	// result = (x * y) (Mod modulus)
MulMod(int limbs,const uint32_t * x,const uint32_t * y,const uint32_t * modulus,uint32_t * result)1100 	bool MulMod(
1101 		int limbs,			// Number of limbs in x,y,modulus
1102 		const uint32_t *x,		// Large number x
1103 		const uint32_t *y,		// Large number y
1104 		const uint32_t *modulus,	// Large number modulus
1105 		uint32_t *result)		// Large number result
1106 	{
1107 		uint32_t *product = (uint32_t*)alloca(limbs*2*4);
1108 
1109 		Multiply(limbs, product, x, y);
1110 
1111 		return Modulus(product, limbs * 2, modulus, limbs, result);
1112 	}
1113 
1114 	// Convert bigint to string
1115 	/*
1116 	std::string ToStr(const uint32_t *n, int limbs, int base)
1117 	{
1118 		limbs = LimbDegree(n, limbs);
1119 		if (!limbs) return "0";
1120 
1121 		std::string out;
1122 		char ch;
1123 
1124 		uint32_t *m = (uint32_t*)alloca(limbs*4);
1125 		Set(m, limbs, n, limbs);
1126 
1127 		while (limbs)
1128 		{
1129 			uint32_t mod = Divide32(limbs, m, base);
1130 			if (mod <= 9) ch = '0' + mod;
1131 			else ch = 'A' + mod - 10;
1132 			out = ch + out;
1133 			limbs = LimbDegree(m, limbs);
1134 		}
1135 
1136 		return out;
1137 	}
1138 	*/
1139 
1140 	// Convert string to bigint
1141 	// Return 0 if string contains non-digit characters, else number of limbs used
ToInt(uint32_t * lhs,int max_limbs,const char * rhs,uint32_t base)1142 	int ToInt(uint32_t *lhs, int max_limbs, const char *rhs, uint32_t base)
1143 	{
1144 		if (max_limbs < 2) return 0;
1145 
1146 		lhs[0] = 0;
1147 		int used = 1;
1148 
1149 		char ch;
1150 		while ((ch = *rhs++))
1151 		{
1152 			uint32_t mod;
1153 			if (ch >= '0' && ch <= '9') mod = ch - '0';
1154 			else mod = toupper(ch) - 'A' + 10;
1155 			if (mod >= base) return 0;
1156 
1157 			// lhs *= base
1158 			uint32_t carry = MultiplyAdd32(used, lhs, base, mod);
1159 
1160 			// react to running out of room
1161 			if (carry)
1162 			{
1163 				if (used >= max_limbs)
1164 					return 0;
1165 
1166 				lhs[used++] = carry;
1167 			}
1168 		}
1169 
1170 		if (used < max_limbs)
1171 			Set32(lhs+used, max_limbs-used, 0);
1172 
1173 		return used;
1174 	}
1175 
1176 	/*
1177 	 * Computes: result = GCD(a, b)  (greatest common divisor)
1178 	 *
1179 	 * Length of result is the length of the smallest argument
1180 	 */
GCD(const uint32_t * a,int a_limbs,const uint32_t * b,int b_limbs,uint32_t * result)1181 	void GCD(
1182 		const uint32_t *a,	//	Large number, buffer size = a_limbs
1183 		int a_limbs,	//	Size of a
1184 		const uint32_t *b,	//	Large number, buffer size = b_limbs
1185 		int b_limbs,	//	Size of b
1186 		uint32_t *result)	//	Large number, buffer size = min(a, b)
1187 	{
1188 		int limbs = (a_limbs <= b_limbs) ? a_limbs : b_limbs;
1189 
1190 		uint32_t *g = (uint32_t*)alloca(limbs*4);
1191 		uint32_t *g1 = (uint32_t*)alloca(limbs*4);
1192 
1193 		if (a_limbs <= b_limbs)
1194 		{
1195 			// g = a, g1 = b (mod a)
1196 			Set(g, limbs, a, a_limbs);
1197 			Modulus(b, b_limbs, a, a_limbs, g1);
1198 		}
1199 		else
1200 		{
1201 			// g = b, g1 = a (mod b)
1202 			Set(g, limbs, b, b_limbs);
1203 			Modulus(a, a_limbs, b, b_limbs, g1);
1204 		}
1205 
1206 		for (;;) {
1207 			// g = (g mod g1)
1208 			Modulus(g, limbs, g1, limbs, g);
1209 
1210 			if (!LimbDegree(g, limbs)) {
1211 				Set(result, limbs, g1, limbs);
1212 				return;
1213 			}
1214 
1215 			// g1 = (g1 mod g)
1216 			Modulus(g1, limbs, g, limbs, g1);
1217 
1218 			if (!LimbDegree(g1, limbs)) {
1219 				Set(result, limbs, g, limbs);
1220 				return;
1221 			}
1222 		}
1223 	}
1224 
1225 	/*
1226 	 * Computes: result = (1/u) (Mod v)
1227 	 * Such that: result * u (Mod v) = 1
1228 	 * Using Extended Euclid's Algorithm (GCDe)
1229 	 *
1230 	 * This is not always possible, so it will return false iff not possible.
1231 	 */
InvMod(const uint32_t * u,int u_limbs,const uint32_t * v,int limbs,uint32_t * result)1232 	bool InvMod(
1233 		const uint32_t *u,	// Large number, buffer size = u_limbs
1234 		int u_limbs,	// Limbs in u
1235 		const uint32_t *v,	// Large number, buffer size = limbs
1236 		int limbs,		// Limbs in modulus(v) and result
1237 		uint32_t *result)	// Large number, buffer size = limbs
1238 	{
1239 		uint32_t *u1 = (uint32_t*)alloca(limbs*4);
1240 		uint32_t *u3 = (uint32_t*)alloca(limbs*4);
1241 		uint32_t *v1 = (uint32_t*)alloca(limbs*4);
1242 		uint32_t *v3 = (uint32_t*)alloca(limbs*4);
1243 		uint32_t *t1 = (uint32_t*)alloca(limbs*4);
1244 		uint32_t *t3 = (uint32_t*)alloca(limbs*4);
1245 		uint32_t *q = (uint32_t*)alloca((limbs + u_limbs)*4);
1246 
1247 		// Unrolled first iteration
1248 		{
1249 			Set32(u1, limbs, 0);
1250 			Set32(v1, limbs, 1);
1251 			Set(u3, limbs, v);
1252 
1253 			// v3 = u % v
1254 			Modulus(u, u_limbs, v, limbs, v3);
1255 		}
1256 
1257 		for (;;)
1258 		{
1259 			if (!LimbDegree(v3, limbs))
1260 			{
1261 				Subtract(result, v, limbs, u1, limbs);
1262 				return Equal32(u3, limbs, 1);
1263 			}
1264 
1265 			Divide(u3, limbs, v3, limbs, q, t3);
1266 			SimpleMultiplyLowHalf(limbs, t1, q, v1);
1267 			Add(t1, limbs, u1, limbs);
1268 
1269 			if (!LimbDegree(t3, limbs))
1270 			{
1271 				Set(result, limbs, v1);
1272 				return Equal32(v3, limbs, 1);
1273 			}
1274 
1275 			Divide(v3, limbs, t3, limbs, q, u3);
1276 			SimpleMultiplyLowHalf(limbs, u1, q, t1);
1277 			Add(u1, limbs, v1, limbs);
1278 
1279 			if (!LimbDegree(u3, limbs))
1280 			{
1281 				Subtract(result, v, limbs, t1, limbs);
1282 				return Equal32(t3, limbs, 1);
1283 			}
1284 
1285 			Divide(t3, limbs, u3, limbs, q, v3);
1286 			SimpleMultiplyLowHalf(limbs, v1, q, u1);
1287 			Add(v1, limbs, t1, limbs);
1288 
1289 			if (!LimbDegree(v3, limbs))
1290 			{
1291 				Set(result, limbs, u1);
1292 				return Equal32(u3, limbs, 1);
1293 			}
1294 
1295 			Divide(u3, limbs, v3, limbs, q, t3);
1296 			SimpleMultiplyLowHalf(limbs, t1, q, v1);
1297 			Add(t1, limbs, u1, limbs);
1298 
1299 			if (!LimbDegree(t3, limbs))
1300 			{
1301 				Subtract(result, v, limbs, v1, limbs);
1302 				return Equal32(v3, limbs, 1);
1303 			}
1304 
1305 			Divide(v3, limbs, t3, limbs, q, u3);
1306 			SimpleMultiplyLowHalf(limbs, u1, q, t1);
1307 			Add(u1, limbs, v1, limbs);
1308 
1309 			if (!LimbDegree(u3, limbs))
1310 			{
1311 				Set(result, limbs, t1);
1312 				return Equal32(t3, limbs, 1);
1313 			}
1314 
1315 			Divide(t3, limbs, u3, limbs, q, v3);
1316 			SimpleMultiplyLowHalf(limbs, v1, q, u1);
1317 			Add(v1, limbs, t1, limbs);
1318 		}
1319 	}
1320 
1321 	// root = sqrt(square)
1322 	// Based on Newton-Raphson iteration: root_n+1 = (root_n + square/root_n) / 2
1323 	// Doubles number of correct bits each iteration
1324 	// Precondition: The high limb of square is non-zero
1325 	// Returns false if it was unable to determine the root
SquareRoot(int limbs,const uint32_t * square,uint32_t * root)1326 	bool SquareRoot(
1327 		int limbs,			// Number of limbs in root
1328 		const uint32_t *square,	// Square to root, size = limbs * 2
1329 		uint32_t *root)			// Output root, size = limbs
1330 	{
1331 		uint32_t *q = (uint32_t*)alloca(limbs*2*4);
1332 		uint32_t *r = (uint32_t*)alloca((limbs+1)*4);
1333 
1334 		// Take high limbs of square as the initial root guess
1335 		Set(root, limbs, square + limbs);
1336 
1337 		int ctr = 64;
1338 		while (ctr--)
1339 		{
1340 			// {q, r} = square / root
1341 			Divide(square, limbs*2, root, limbs, q, r);
1342 
1343 			// root = (root + q) / 2, assuming high limbs of q = 0
1344 			Add(q, limbs+1, root, limbs);
1345 
1346 			// Round division up to the nearest bit
1347 			// Fixes a problem where root is off by 1
1348 			if (q[0] & 1) Add32(q, limbs+1, 2);
1349 
1350 			ShiftRight(limbs+1, q, q, 1);
1351 
1352 			// Return success if there was no change
1353 			if (Equal(limbs, q, root))
1354 				return true;
1355 
1356 			// Else update root and continue
1357 			Set(root, limbs, q);
1358 		}
1359 
1360 		// In practice only takes about 9 iterations, as many as 31
1361 		// Varies slightly as number of limbs increases but not by much
1362 		return false;
1363 	}
1364 
1365 	// Calculates mod_inv from low limb of modulus for Mon*()
MonReducePrecomp(uint32_t modulus0)1366 	uint32_t MonReducePrecomp(uint32_t modulus0)
1367 	{
1368 		// mod_inv = -M ^ -1 (Mod 2^32)
1369 		return MulInverse32(-(int32_t)modulus0);
1370 	}
1371 
1372 	// Compute n_residue for Montgomery reduction
MonInputResidue(const uint32_t * n,int n_limbs,const uint32_t * modulus,int m_limbs,uint32_t * n_residue)1373 	void MonInputResidue(
1374 		const uint32_t *n,		//	Large number, buffer size = n_limbs
1375 		int n_limbs,		//	Number of limbs in n
1376 		const uint32_t *modulus,	//	Large number, buffer size = m_limbs
1377 		int m_limbs,		//	Number of limbs in modulus
1378 		uint32_t *n_residue)		//	Result, buffer size = m_limbs
1379 	{
1380 		// p = n * 2^(k*m)
1381 		uint32_t *p = (uint32_t*)alloca((n_limbs+m_limbs)*4);
1382 		Set(p+m_limbs, n_limbs, n, n_limbs);
1383 		Set32(p, m_limbs, 0);
1384 
1385 		// n_residue = p (Mod modulus)
1386 		Modulus(p, n_limbs+m_limbs, modulus, m_limbs, n_residue);
1387 	}
1388 
1389 	// result = a * b * r^-1 (Mod modulus) in Montgomery domain
MonPro(int limbs,const uint32_t * a_residue,const uint32_t * b_residue,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1390 	void MonPro(
1391 		int limbs,				// Number of limbs in each parameter
1392 		const uint32_t *a_residue,	// Large number, buffer size = limbs
1393 		const uint32_t *b_residue,	// Large number, buffer size = limbs
1394 		const uint32_t *modulus,		// Large number, buffer size = limbs
1395 		uint32_t mod_inv,			// MonReducePrecomp() return
1396 		uint32_t *result)			// Large number, buffer size = limbs
1397 	{
1398 		uint32_t *t = (uint32_t*)alloca(limbs*2*4);
1399 
1400 		Multiply(limbs, t, a_residue, b_residue);
1401 		MonReduce(limbs, t, modulus, mod_inv, result);
1402 	}
1403 
1404 	// result = a^-1 (Mod modulus) in Montgomery domain
MonInverse(int limbs,const uint32_t * a_residue,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1405 	void MonInverse(
1406 		int limbs,				// Number of limbs in each parameter
1407 		const uint32_t *a_residue,	// Large number, buffer size = limbs
1408 		const uint32_t *modulus,		// Large number, buffer size = limbs
1409 		uint32_t mod_inv,			// MonReducePrecomp() return
1410 		uint32_t *result)			// Large number, buffer size = limbs
1411 	{
1412 		Set(result, limbs, a_residue);
1413 		MonFinish(limbs, result, modulus, mod_inv);
1414 		InvMod(result, limbs, modulus, limbs, result);
1415 		MonInputResidue(result, limbs, modulus, limbs, result);
1416 	}
1417 
1418 	// result = a * r^-1 (Mod modulus) in Montgomery domain
1419 	// The result may be greater than the modulus, but this is okay since
1420 	// the result is still in the RNS.  MonFinish() corrects this at the end.
MonReduce(int limbs,uint32_t * s,const uint32_t * modulus,uint32_t mod_inv,uint32_t * result)1421 	void MonReduce(
1422 		int limbs,			// Number of limbs in modulus
1423 		uint32_t *s,				// Large number, buffer size = limbs*2, gets clobbered
1424 		const uint32_t *modulus,	// Large number, buffer size = limbs
1425 		uint32_t mod_inv,		// MonReducePrecomp() return
1426 		uint32_t *result)		// Large number, buffer size = limbs
1427 	{
1428 		// This function is roughly 60% of the cost of exponentiation
1429 		for (int ii = 0; ii < limbs; ++ii)
1430 		{
1431 			uint32_t q = s[0] * mod_inv;
1432 			s[0] = AddMultiply32(limbs, s, modulus, q);
1433 			++s;
1434 		}
1435 
1436 		// Add the saved carries
1437 		if (Add(result, s, limbs, s - limbs, limbs))
1438 		{
1439 			// Reduce the result only when needed
1440 			Subtract(result, limbs, modulus, limbs);
1441 		}
1442 	}
1443 
1444 	// result = a * r^-1 (Mod modulus) in Montgomery domain
MonFinish(int limbs,uint32_t * n,const uint32_t * modulus,uint32_t mod_inv)1445 	void MonFinish(
1446 		int limbs,			// Number of limbs in each parameter
1447 		uint32_t *n,				// Large number, buffer size = limbs
1448 		const uint32_t *modulus,	// Large number, buffer size = limbs
1449 		uint32_t mod_inv)		// MonReducePrecomp() return
1450 	{
1451 		uint32_t *t = (uint32_t*)alloca(limbs*2*4);
1452 		memcpy(t, n, limbs*4);
1453 		memset(t + limbs, 0, limbs*4);
1454 
1455 		// Reduce the number
1456 		MonReduce(limbs, t, modulus, mod_inv, n);
1457 
1458 		// Fix MonReduce() results greater than the modulus
1459 		if (!Less(limbs, n, modulus))
1460 			Subtract(n, limbs, modulus, limbs);
1461 	}
1462 
1463 	// Simple internal version without windowing for small exponents
SimpleMonExpMod(const uint32_t * base,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1464 	static void SimpleMonExpMod(
1465 		const uint32_t *base,	//	Base for exponentiation, buffer size = mod_limbs
1466 		const uint32_t *exponent,//	Exponent, buffer size = exponent_limbs
1467 		int exponent_limbs,	//	Number of limbs in exponent
1468 		const uint32_t *modulus,	//	Modulus, buffer size = mod_limbs
1469 		int mod_limbs,		//	Number of limbs in modulus
1470 		uint32_t mod_inv,		//	MonReducePrecomp() return
1471 		uint32_t *result)		//	Result, buffer size = mod_limbs
1472 	{
1473 		bool set = false;
1474 
1475 		uint32_t *temp = (uint32_t*)alloca((mod_limbs*2)*4);
1476 
1477 		// Run down exponent bits and use the squaring method
1478 		for (int ii = exponent_limbs - 1; ii >= 0; --ii)
1479 		{
1480 			uint32_t e_i = exponent[ii];
1481 
1482 			for (uint32_t mask = 0x80000000; mask; mask >>= 1)
1483 			{
1484 				if (set)
1485 				{
1486 					// result = result^2
1487 					Square(mod_limbs, temp, result);
1488 					MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1489 
1490 					if (e_i & mask)
1491 					{
1492 						// result *= base
1493 						Multiply(mod_limbs, temp, result, base);
1494 						MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1495 					}
1496 				}
1497 				else
1498 				{
1499 					if (e_i & mask)
1500 					{
1501 						// result = base
1502 						Set(result, mod_limbs, base, mod_limbs);
1503 						set = true;
1504 					}
1505 				}
1506 			}
1507 		}
1508 	}
1509 
1510 	// Precompute a window for ExpMod() and MonExpMod()
	// Builds 2^(window_bits-1) entries: one squaring plus 2^(window_bits-1) - 1 multiplies
PrecomputeWindow(const uint32_t * base,const uint32_t * modulus,int limbs,uint32_t mod_inv,int window_bits)1512 	uint32_t *PrecomputeWindow(const uint32_t *base, const uint32_t *modulus, int limbs, uint32_t mod_inv, int window_bits)
1513 	{
1514 		uint32_t *temp = (uint32_t*)alloca(limbs*2*4);
1515 
1516 		uint32_t *base_squared = (uint32_t*)alloca(limbs*4);
1517 		Square(limbs, temp, base);
1518 		MonReduce(limbs, temp, modulus, mod_inv, base_squared);
1519 
1520 		// precomputed window starts with 000001, 000011, 000101, 000111, ...
1521 		uint32_t k = (1 << (window_bits - 1));
1522 
1523 		uint32_t *window = RakNet::OP_NEW_ARRAY<uint32_t>(limbs * k, __FILE__, __LINE__ );
1524 
1525 		uint32_t *cw = window;
1526 		Set(window, limbs, base);
1527 
1528 		while (--k)
1529 		{
1530 			// cw+1 = cw * base^2
1531 			Multiply(limbs, temp, cw, base_squared);
1532 			MonReduce(limbs, temp, modulus, mod_inv, cw + limbs);
1533 			cw += limbs;
1534 		}
1535 
1536 		return window;
1537 	};
1538 
1539 	// Computes: result = base ^ exponent (Mod modulus)
1540 	// Using Montgomery multiplication with simple squaring method
1541 	// Base parameter must be a Montgomery Residue created with MonInputResidue()
MonExpMod(const uint32_t * base,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1542 	void MonExpMod(
1543 		const uint32_t *base,	//	Base for exponentiation, buffer size = mod_limbs
1544 		const uint32_t *exponent,//	Exponent, buffer size = exponent_limbs
1545 		int exponent_limbs,	//	Number of limbs in exponent
1546 		const uint32_t *modulus,	//	Modulus, buffer size = mod_limbs
1547 		int mod_limbs,		//	Number of limbs in modulus
1548 		uint32_t mod_inv,		//	MonReducePrecomp() return
1549 		uint32_t *result)		//	Result, buffer size = mod_limbs
1550 	{
1551 		// Calculate the number of window bits to use (decent approximation..)
1552 		int window_bits = Degree(exponent_limbs);
1553 
1554 		// If the window bits are too small, might as well just use left-to-right S&M method
1555 		if (window_bits < 4)
1556 		{
1557 			SimpleMonExpMod(base, exponent, exponent_limbs, modulus, mod_limbs, mod_inv, result);
1558 			return;
1559 		}
1560 
1561 		// Precompute a window of the size determined above
1562 		uint32_t *window = PrecomputeWindow(base, modulus, mod_limbs, mod_inv, window_bits);
1563 
1564 		bool seen_bits = false;
1565 		uint32_t e_bits=0, trailing_zeroes=0, used_bits = 0;
1566 
1567 		uint32_t *temp = (uint32_t*)alloca((mod_limbs*2)*4);
1568 
1569 		for (int ii = exponent_limbs - 1; ii >= 0; --ii)
1570 		{
1571 			uint32_t e_i = exponent[ii];
1572 
1573 			int wordbits = 32;
1574 			while (wordbits--)
1575 			{
1576 				// If we have been accumulating bits,
1577 				if (used_bits)
1578 				{
1579 					// If this new bit is set,
1580 					if (e_i >> 31)
1581 					{
1582 						e_bits <<= 1;
1583 						e_bits |= 1;
1584 
1585 						trailing_zeroes = 0;
1586 					}
1587 					else // the new bit is unset
1588 					{
1589 						e_bits <<= 1;
1590 
1591 						++trailing_zeroes;
1592 					}
1593 
1594 					++used_bits;
1595 
1596 					// If we have used up the window bits,
1597 					if (used_bits == (uint32_t) window_bits)
1598 					{
1599 						// Select window index 1011 from "101110"
1600 						uint32_t window_index = e_bits >> (trailing_zeroes + 1);
1601 
1602 						if (seen_bits)
1603 						{
1604 							uint32_t ctr = used_bits - trailing_zeroes;
1605 							while (ctr--)
1606 							{
1607 								// result = result^2
1608 								Square(mod_limbs, temp, result);
1609 								MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1610 							}
1611 
1612 							// result = result * window[index]
1613 							Multiply(mod_limbs, temp, result, &window[window_index * mod_limbs]);
1614 							MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1615 						}
1616 						else
1617 						{
1618 							// result = window[index]
1619 							Set(result, mod_limbs, &window[window_index * mod_limbs]);
1620 							seen_bits = true;
1621 						}
1622 
1623 						while (trailing_zeroes--)
1624 						{
1625 							// result = result^2
1626 							Square(mod_limbs, temp, result);
1627 							MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1628 						}
1629 
1630 						used_bits = 0;
1631 					}
1632 				}
1633 				else
1634 				{
1635 					// If this new bit is set,
1636 					if (e_i >> 31)
1637 					{
1638 						used_bits = 1;
1639 						e_bits = 1;
1640 						trailing_zeroes = 0;
1641 					}
1642 					else // the new bit is unset
1643 					{
1644 						// If we have processed any bits yet,
1645 						if (seen_bits)
1646 						{
1647 							// result = result^2
1648 							Square(mod_limbs, temp, result);
1649 							MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1650 						}
1651 					}
1652 				}
1653 
1654 				e_i <<= 1;
1655 			}
1656 		}
1657 
1658 		if (used_bits)
1659 		{
1660 			// Select window index 1011 from "101110"
1661 			uint32_t window_index = e_bits >> (trailing_zeroes + 1);
1662 
1663 			if (seen_bits)
1664 			{
1665 				uint32_t ctr = used_bits - trailing_zeroes;
1666 				while (ctr--)
1667 				{
1668 					// result = result^2
1669 					Square(mod_limbs, temp, result);
1670 					MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1671 				}
1672 
1673 				// result = result * window[index]
1674 				Multiply(mod_limbs, temp, result, &window[window_index * mod_limbs]);
1675 				MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1676 			}
1677 			else
1678 			{
1679 				// result = window[index]
1680 				Set(result, mod_limbs, &window[window_index * mod_limbs]);
1681 				//seen_bits = true;
1682 			}
1683 
1684 			while (trailing_zeroes--)
1685 			{
1686 				// result = result^2
1687 				Square(mod_limbs, temp, result);
1688 				MonReduce(mod_limbs, temp, modulus, mod_inv, result);
1689 			}
1690 
1691 			//e_bits = 0;
1692 		}
1693 
1694 		RakNet::OP_DELETE_ARRAY(window, __FILE__, __LINE__);
1695 	}
1696 
1697 	// Computes: result = base ^ exponent (Mod modulus)
1698 	// Using Montgomery multiplication with simple squaring method
ExpMod(const uint32_t * base,int base_limbs,const uint32_t * exponent,int exponent_limbs,const uint32_t * modulus,int mod_limbs,uint32_t mod_inv,uint32_t * result)1699 	void ExpMod(
1700 		const uint32_t *base,	//	Base for exponentiation, buffer size = base_limbs
1701 		int base_limbs,		//	Number of limbs in base
1702 		const uint32_t *exponent,//	Exponent, buffer size = exponent_limbs
1703 		int exponent_limbs,	//	Number of limbs in exponent
1704 		const uint32_t *modulus,	//	Modulus, buffer size = mod_limbs
1705 		int mod_limbs,		//	Number of limbs in modulus
1706 		uint32_t mod_inv,		//	MonReducePrecomp() return
1707 		uint32_t *result)		//	Result, buffer size = mod_limbs
1708 	{
1709 		uint32_t *mon_base = (uint32_t*)alloca(mod_limbs*4);
1710 		MonInputResidue(base, base_limbs, modulus, mod_limbs, mon_base);
1711 
1712 		MonExpMod(mon_base, exponent, exponent_limbs, modulus, mod_limbs, mod_inv, result);
1713 
1714 		MonFinish(mod_limbs, result, modulus, mod_inv);
1715 	}
1716 
1717 	// returns b ^ e (Mod m)
ExpMod(uint32_t b,uint32_t e,uint32_t m)1718 	uint32_t ExpMod(uint32_t b, uint32_t e, uint32_t m)
1719 	{
1720 		// validate arguments
1721 		if (b == 0 || m <= 1) return 0;
1722 		if (e == 0) return 1;
1723 
1724 		// find high bit of exponent
1725 		uint32_t mask = 0x80000000;
1726 		while ((e & mask) == 0) mask >>= 1;
1727 
1728 		// seen 1 set bit, so result = base so far
1729 		uint32_t r = b;
1730 
1731 		while (mask >>= 1)
1732 		{
1733 			// VS.NET does a poor job recognizing that the division
1734 			// is just an IDIV with a 32-bit dividend (not 64-bit) :-(
1735 
1736 			// r = r^2 (mod m)
1737 			r = (uint32_t)(((uint64_t)r * r) % m);
1738 
1739 			// if exponent bit is set, r = r*b (mod m)
1740 			if (e & mask) r = (uint32_t)(((uint64_t)r * b) % m);
1741 		}
1742 
1743 		return r;
1744 	}
1745 
1746 	// Rabin-Miller method for finding a strong pseudo-prime
1747 	// Preconditions: High bit and low bit of n = 1
RabinMillerPrimeTest(const uint32_t * n,int limbs,uint32_t k)1748 	bool RabinMillerPrimeTest(
1749 		const uint32_t *n,	// Number to check for primality
1750 		int limbs,		// Number of limbs in n
1751 		uint32_t k)			// Confidence level (40 is pretty good)
1752 	{
1753 		// n1 = n - 1
1754 		uint32_t *n1 = (uint32_t *)alloca(limbs*4);
1755 		Set(n1, limbs, n);
1756 		Subtract32(n1, limbs, 1);
1757 
1758 		// d = n1
1759 		uint32_t *d = (uint32_t *)alloca(limbs*4);
1760 		Set(d, limbs, n1);
1761 
1762 		// remove factors of two from d
1763 		while (!(d[0] & 1))
1764 			ShiftRight(limbs, d, d, 1);
1765 
1766 		uint32_t *a = (uint32_t *)alloca(limbs*4);
1767 		uint32_t *t = (uint32_t *)alloca(limbs*4);
1768 		uint32_t *p = (uint32_t *)alloca((limbs*2)*4);
1769 		uint32_t n_inv = MonReducePrecomp(n[0]);
1770 
1771 		// iterate k times
1772 		while (k--)
1773 		{
1774 			//do Random::ref()->generate(a, limbs*4);
1775 			do fillBufferMT(a,limbs*4);
1776 			while (GreaterOrEqual(a, limbs, n, limbs));
1777 
1778 			// a = a ^ d (Mod n)
1779 			ExpMod(a, limbs, d, limbs, n, limbs, n_inv, a);
1780 
1781 			Set(t, limbs, d);
1782 			while (!Equal(limbs, t, n1) &&
1783 				   !Equal32(a, limbs, 1) &&
1784 				   !Equal(limbs, a, n1))
1785 			{
1786 				// a = a^2 (Mod n), non-critical path
1787 				Square(limbs, p, a);
1788 				Modulus(p, limbs*2, n, limbs, a);
1789 
1790 				// t <<= 1
1791 				ShiftLeft(limbs, t, t, 1);
1792 			}
1793 
1794 			if (!Equal(limbs, a, n1) && !(t[0] & 1)) return false;
1795 		}
1796 
1797 		return true;
1798 	}
1799 
1800 	// Generate a strong pseudo-prime using the Rabin-Miller primality test
GenerateStrongPseudoPrime(uint32_t * n,int limbs)1801 	void GenerateStrongPseudoPrime(
1802 		uint32_t *n,			// Output prime
1803 		int limbs)		// Number of limbs in n
1804 	{
1805 		do {
1806 			fillBufferMT(n,limbs*4);
1807 			n[limbs-1] |= 0x80000000;
1808 			n[0] |= 1;
1809 		} while (!RabinMillerPrimeTest(n, limbs, 40)); // 40 iterations
1810 	}
1811 }
1812 
1813 
1814