xref: /openbsd/regress/lib/libcrypto/bn/bn_mod_exp.c (revision 24882d0a)
1 /*	$OpenBSD: bn_mod_exp.c,v 1.40 2023/10/19 13:38:12 tb Exp $ */
2 
3 /*
4  * Copyright (c) 2022,2023 Theo Buehler <tb@openbsd.org>
5  *
6  * Permission to use, copy, modify, and distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 #include <err.h>
20 #include <stdio.h>
21 #include <string.h>
22 
23 #include <openssl/bn.h>
24 #include <openssl/err.h>
25 
26 #include "bn_local.h"
27 
28 #define N_MOD_EXP_TESTS		100
29 #define N_MOD_EXP2_TESTS	50
30 
31 #define INIT_MOD_EXP_FN(f) { .name = #f, .mod_exp_fn = (f), }
32 #define INIT_MOD_EXP_MONT_FN(f) { .name = #f, .mod_exp_mont_fn = (f), }
33 
34 static int
bn_mod_exp2_mont_first(BIGNUM * r,const BIGNUM * a,const BIGNUM * p,const BIGNUM * m,BN_CTX * ctx,BN_MONT_CTX * mctx)35 bn_mod_exp2_mont_first(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
36     const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *mctx)
37 {
38 	const BIGNUM *one = BN_value_one();
39 
40 	return BN_mod_exp2_mont(r, a, p, one, one, m, ctx, mctx);
41 }
42 
43 static int
bn_mod_exp2_mont_second(BIGNUM * r,const BIGNUM * a,const BIGNUM * p,const BIGNUM * m,BN_CTX * ctx,BN_MONT_CTX * mctx)44 bn_mod_exp2_mont_second(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
45     const BIGNUM *m, BN_CTX *ctx, BN_MONT_CTX *mctx)
46 {
47 	const BIGNUM *one = BN_value_one();
48 
49 	return BN_mod_exp2_mont(r, one, one, a, p, m, ctx, mctx);
50 }
51 
52 static const struct mod_exp_test {
53 	const char *name;
54 	int (*mod_exp_fn)(BIGNUM *, const BIGNUM *, const BIGNUM *,
55 	    const BIGNUM *, BN_CTX *);
56 	int (*mod_exp_mont_fn)(BIGNUM *, const BIGNUM *, const BIGNUM *,
57 	    const BIGNUM *, BN_CTX *, BN_MONT_CTX *);
58 } mod_exp_fn[] = {
59 	INIT_MOD_EXP_FN(BN_mod_exp),
60 	INIT_MOD_EXP_FN(BN_mod_exp_ct),
61 	INIT_MOD_EXP_FN(BN_mod_exp_nonct),
62 	INIT_MOD_EXP_FN(BN_mod_exp_recp),
63 	INIT_MOD_EXP_FN(BN_mod_exp_simple),
64 	INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont),
65 	INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_ct),
66 	INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_consttime),
67 	INIT_MOD_EXP_MONT_FN(BN_mod_exp_mont_nonct),
68 	INIT_MOD_EXP_MONT_FN(bn_mod_exp2_mont_first),
69 	INIT_MOD_EXP_MONT_FN(bn_mod_exp2_mont_second),
70 };
71 
72 #define N_MOD_EXP_FN (sizeof(mod_exp_fn) / sizeof(mod_exp_fn[0]))
73 
74 static void
bn_print(const char * name,const BIGNUM * bn)75 bn_print(const char *name, const BIGNUM *bn)
76 {
77 	size_t len;
78 	int pad = 0;
79 
80 	if ((len = strlen(name)) < 7)
81 		pad = 6 - len;
82 
83 	fprintf(stderr, "%s: %*s", name, pad, "");
84 	BN_print_fp(stderr, bn);
85 	fprintf(stderr, "\n");
86 }
87 
88 static void
print_zero_test_failure(const BIGNUM * got,const BIGNUM * a,const BIGNUM * m,const char * name)89 print_zero_test_failure(const BIGNUM *got, const BIGNUM *a, const BIGNUM *m,
90     const char *name)
91 {
92 	fprintf(stderr, "%s() zero test failed:\n", name);
93 
94 	bn_print("a", a);
95 	bn_print("m", m);
96 	bn_print("got", got);
97 }
98 
99 static int
bn_mod_exp_zero_test(const struct mod_exp_test * test,BN_CTX * ctx,int neg_modulus,int random_base)100 bn_mod_exp_zero_test(const struct mod_exp_test *test, BN_CTX *ctx,
101     int neg_modulus, int random_base)
102 {
103 	BIGNUM *a, *m, *p, *got;
104 	int mod_exp_ret;
105 	int failed = 1;
106 
107 	BN_CTX_start(ctx);
108 
109 	if ((a = BN_CTX_get(ctx)) == NULL)
110 		errx(1, "BN_CTX_get");
111 	if ((m = BN_CTX_get(ctx)) == NULL)
112 		errx(1, "BN_CTX_get");
113 	if ((p = BN_CTX_get(ctx)) == NULL)
114 		errx(1, "BN_CTX_get");
115 	if ((got = BN_CTX_get(ctx)) == NULL)
116 		errx(1, "BN_CTX_get");
117 
118 	if (!BN_one(m))
119 		errx(1, "BN_one");
120 	if (neg_modulus)
121 		BN_set_negative(m, 1);
122 	BN_zero(a);
123 	BN_zero(p);
124 
125 	if (random_base) {
126 		if (!BN_rand(a, 1024, BN_RAND_TOP_ANY, BN_RAND_BOTTOM_ANY))
127 			errx(1, "BN_rand");
128 	}
129 
130 	if (test->mod_exp_fn != NULL) {
131 		mod_exp_ret = test->mod_exp_fn(got, a, p, m, ctx);
132 	} else {
133 		mod_exp_ret = test->mod_exp_mont_fn(got, a, p, m, ctx, NULL);
134 	}
135 
136 	if (!mod_exp_ret) {
137 		fprintf(stderr, "%s failed\n", test->name);
138 		ERR_print_errors_fp(stderr);
139 		goto err;
140 	}
141 
142 	if (!BN_is_zero(got)) {
143 		print_zero_test_failure(got, a, m, test->name);
144 		goto err;
145 	}
146 
147 	failed = 0;
148 
149  err:
150 	BN_CTX_end(ctx);
151 
152 	return failed;
153 }
154 
155 static int
bn_mod_exp_zero_word_test(BN_CTX * ctx,int neg_modulus)156 bn_mod_exp_zero_word_test(BN_CTX *ctx, int neg_modulus)
157 {
158 	const char *name = "BN_mod_exp_mont_word";
159 	BIGNUM *m, *p, *got;
160 	int failed = 1;
161 
162 	BN_CTX_start(ctx);
163 
164 	if ((m = BN_CTX_get(ctx)) == NULL)
165 		errx(1, "BN_CTX_get");
166 	if ((p = BN_CTX_get(ctx)) == NULL)
167 		errx(1, "BN_CTX_get");
168 	if ((got = BN_CTX_get(ctx)) == NULL)
169 		errx(1, "BN_CTX_get");
170 
171 	if (!BN_one(m))
172 		errx(1, "BN_one");
173 	if (neg_modulus)
174 		BN_set_negative(m, neg_modulus);
175 	BN_zero(p);
176 
177 	if (!BN_mod_exp_mont_word(got, 1, p, m, ctx, NULL)) {
178 		fprintf(stderr, "%s failed\n", name);
179 		ERR_print_errors_fp(stderr);
180 		goto err;
181 	}
182 
183 	if (!BN_is_zero(got)) {
184 		print_zero_test_failure(got, p, m, name);
185 		goto err;
186 	}
187 
188 	failed = 0;
189 
190  err:
191 	BN_CTX_end(ctx);
192 
193 	return failed;
194 }
195 
196 static int
test_bn_mod_exp_zero(void)197 test_bn_mod_exp_zero(void)
198 {
199 	BN_CTX *ctx;
200 	size_t i, j;
201 	int failed = 0;
202 
203 	if ((ctx = BN_CTX_new()) == NULL)
204 		errx(1, "BN_CTX_new");
205 
206 	for (i = 0; i < N_MOD_EXP_FN; i++) {
207 		for (j = 0; j < 4; j++) {
208 			int neg_modulus = (j >> 0) & 1;
209 			int random_base = (j >> 1) & 1;
210 
211 			failed |= bn_mod_exp_zero_test(&mod_exp_fn[i], ctx,
212 			    neg_modulus, random_base);
213 		}
214 	}
215 
216 	failed |= bn_mod_exp_zero_word_test(ctx, 0);
217 	failed |= bn_mod_exp_zero_word_test(ctx, 1);
218 
219 	BN_CTX_free(ctx);
220 
221 	return failed;
222 }
223 
224 static int
generate_bn(BIGNUM * bn,int avg_bits,int deviate,int force_odd)225 generate_bn(BIGNUM *bn, int avg_bits, int deviate, int force_odd)
226 {
227 	int bits;
228 
229 	if (bn == NULL)
230 		return 1;
231 
232 	if (avg_bits <= 0 || deviate <= 0 || deviate >= avg_bits)
233 		return 0;
234 
235 	bits = avg_bits + arc4random_uniform(deviate) - deviate;
236 
237 	return BN_rand(bn, bits, 0, force_odd);
238 }
239 
240 static int
generate_test_quintuple(int reduce,BIGNUM * a,BIGNUM * p,BIGNUM * b,BIGNUM * q,BIGNUM * m,BN_CTX * ctx)241 generate_test_quintuple(int reduce, BIGNUM *a, BIGNUM *p, BIGNUM *b, BIGNUM *q,
242     BIGNUM *m, BN_CTX *ctx)
243 {
244 	BIGNUM *mmodified;
245 	BN_ULONG multiple;
246 	int avg = 2 * BN_BITS, deviate = BN_BITS / 2;
247 	int ret = 0;
248 
249 	if (!generate_bn(a, avg, deviate, 0))
250 		return 0;
251 
252 	if (!generate_bn(p, avg, deviate, 0))
253 		return 0;
254 
255 	if (!generate_bn(b, avg, deviate, 0))
256 		return 0;
257 
258 	if (!generate_bn(q, avg, deviate, 0))
259 		return 0;
260 
261 	if (!generate_bn(m, avg, deviate, 1))
262 		return 0;
263 
264 	if (reduce) {
265 		if (!BN_mod(a, a, m, ctx))
266 			return 0;
267 
268 		if (b == NULL)
269 			return 1;
270 
271 		return BN_mod(b, b, m, ctx);
272 	}
273 
274 	/*
275 	 * Add a random multiple of m to a to test unreduced exponentiation.
276 	 */
277 
278 	BN_CTX_start(ctx);
279 
280 	if ((mmodified = BN_CTX_get(ctx)) == NULL)
281 		goto err;
282 
283 	if (!bn_copy(mmodified, m))
284 		goto err;
285 
286 	multiple = arc4random_uniform(16) + 2;
287 
288 	if (!BN_mul_word(mmodified, multiple))
289 		goto err;
290 
291 	if (!BN_add(a, a, mmodified))
292 		goto err;
293 
294 	if (b == NULL)
295 		goto done;
296 
297 	if (!BN_add(b, b, mmodified))
298 		goto err;
299 
300  done:
301 	ret = 1;
302 
303  err:
304 	BN_CTX_end(ctx);
305 
306 	return ret;
307 }
308 
309 static int
generate_test_triple(int reduce,BIGNUM * a,BIGNUM * p,BIGNUM * m,BN_CTX * ctx)310 generate_test_triple(int reduce, BIGNUM *a, BIGNUM *p, BIGNUM *m, BN_CTX *ctx)
311 {
312 	return generate_test_quintuple(reduce, a, p, NULL, NULL, m, ctx);
313 }
314 
315 static void
dump_results(const BIGNUM * a,const BIGNUM * p,const BIGNUM * b,const BIGNUM * q,const BIGNUM * m,const BIGNUM * want,const BIGNUM * got,const char * name)316 dump_results(const BIGNUM *a, const BIGNUM *p, const BIGNUM *b, const BIGNUM *q,
317     const BIGNUM *m, const BIGNUM *want, const BIGNUM *got, const char *name)
318 {
319 	fprintf(stderr, "BN_mod_exp_simple() and %s() disagree:\n", name);
320 
321 	bn_print("want", want);
322 	bn_print("got", got);
323 
324 	bn_print("a", a);
325 	bn_print("p", p);
326 
327 	if (b != NULL) {
328 		bn_print("b", b);
329 		bn_print("q", q);
330 	}
331 
332 	bn_print("m", m);
333 
334 	fprintf(stderr, "\n");
335 }
336 
337 static int
test_mod_exp(const BIGNUM * want,const BIGNUM * a,const BIGNUM * p,const BIGNUM * m,BN_CTX * ctx,const struct mod_exp_test * test)338 test_mod_exp(const BIGNUM *want, const BIGNUM *a, const BIGNUM *p,
339     const BIGNUM *m, BN_CTX *ctx, const struct mod_exp_test *test)
340 {
341 	BIGNUM *got;
342 	int mod_exp_ret;
343 	int ret = 0;
344 
345 	BN_CTX_start(ctx);
346 
347 	if ((got = BN_CTX_get(ctx)) == NULL)
348 		goto err;
349 
350 	if (test->mod_exp_fn != NULL)
351 		mod_exp_ret = test->mod_exp_fn(got, a, p, m, ctx);
352 	else
353 		mod_exp_ret = test->mod_exp_mont_fn(got, a, p, m, ctx, NULL);
354 
355 	if (!mod_exp_ret)
356 		errx(1, "%s() failed", test->name);
357 
358 	if (BN_cmp(want, got) != 0) {
359 		dump_results(a, p, NULL, NULL, m, want, got, test->name);
360 		goto err;
361 	}
362 
363 	ret = 1;
364 
365  err:
366 	BN_CTX_end(ctx);
367 
368 	return ret;
369 }
370 
371 static int
bn_mod_exp_test(int reduce,BIGNUM * want,BIGNUM * a,BIGNUM * p,BIGNUM * m,BN_CTX * ctx)372 bn_mod_exp_test(int reduce, BIGNUM *want, BIGNUM *a, BIGNUM *p, BIGNUM *m,
373     BN_CTX *ctx)
374 {
375 	size_t i, j;
376 	int failed = 0;
377 
378 	if (!generate_test_triple(reduce, a, p, m, ctx))
379 		errx(1, "generate_test_triple");
380 
381 	for (i = 0; i < 8 && !failed; i++) {
382 		BN_set_negative(a, (i >> 0) & 1);
383 		BN_set_negative(p, (i >> 1) & 1);
384 		BN_set_negative(m, (i >> 2) & 1);
385 
386 		if ((BN_mod_exp_simple(want, a, p, m, ctx)) <= 0)
387 			errx(1, "BN_mod_exp_simple");
388 
389 		for (j = 0; j < N_MOD_EXP_FN; j++) {
390 			const struct mod_exp_test *test = &mod_exp_fn[j];
391 
392 			if (!test_mod_exp(want, a, p, m, ctx, test))
393 				failed |= 1;
394 		}
395 	}
396 
397 	return failed;
398 }
399 
400 static int
test_bn_mod_exp(void)401 test_bn_mod_exp(void)
402 {
403 	BIGNUM *a, *p, *m, *want;
404 	BN_CTX *ctx;
405 	int i;
406 	int reduce;
407 	int failed = 0;
408 
409 	if ((ctx = BN_CTX_new()) == NULL)
410 		errx(1, "BN_CTX_new");
411 
412 	BN_CTX_start(ctx);
413 
414 	if ((a = BN_CTX_get(ctx)) == NULL)
415 		errx(1, "a = BN_CTX_get()");
416 	if ((p = BN_CTX_get(ctx)) == NULL)
417 		errx(1, "p = BN_CTX_get()");
418 	if ((m = BN_CTX_get(ctx)) == NULL)
419 		errx(1, "m = BN_CTX_get()");
420 	if ((want = BN_CTX_get(ctx)) == NULL)
421 		errx(1, "want = BN_CTX_get()");
422 
423 	reduce = 0;
424 	for (i = 0; i < N_MOD_EXP_TESTS && !failed; i++)
425 		failed |= bn_mod_exp_test(reduce, want, a, p, m, ctx);
426 
427 	reduce = 1;
428 	for (i = 0; i < N_MOD_EXP_TESTS && !failed; i++)
429 		failed |= bn_mod_exp_test(reduce, want, a, p, m, ctx);
430 
431 	BN_CTX_end(ctx);
432 	BN_CTX_free(ctx);
433 
434 	return failed;
435 }
436 
437 static int
bn_mod_exp2_simple(BIGNUM * out,const BIGNUM * a,const BIGNUM * p,const BIGNUM * b,const BIGNUM * q,const BIGNUM * m,BN_CTX * ctx)438 bn_mod_exp2_simple(BIGNUM *out, const BIGNUM *a, const BIGNUM *p,
439     const BIGNUM *b, const BIGNUM *q, const BIGNUM *m, BN_CTX *ctx)
440 {
441 	BIGNUM *fact1, *fact2;
442 	int ret = 0;
443 
444 	BN_CTX_start(ctx);
445 
446 	if ((fact1 = BN_CTX_get(ctx)) == NULL)
447 		goto err;
448 	if ((fact2 = BN_CTX_get(ctx)) == NULL)
449 		goto err;
450 
451 	if (!BN_mod_exp_simple(fact1, a, p, m, ctx))
452 		goto err;
453 	if (!BN_mod_exp_simple(fact2, b, q, m, ctx))
454 		goto err;
455 	if (!BN_mod_mul(out, fact1, fact2, m, ctx))
456 		goto err;
457 
458 	ret = 1;
459  err:
460 	BN_CTX_end(ctx);
461 
462 	return ret;
463 }
464 
465 static int
bn_mod_exp2_test(int reduce,BIGNUM * want,BIGNUM * got,BIGNUM * a,BIGNUM * p,BIGNUM * b,BIGNUM * q,BIGNUM * m,BN_CTX * ctx)466 bn_mod_exp2_test(int reduce, BIGNUM *want, BIGNUM *got, BIGNUM *a, BIGNUM *p,
467     BIGNUM *b, BIGNUM *q, BIGNUM *m, BN_CTX *ctx)
468 {
469 	size_t i;
470 	int failed = 0;
471 
472 	if (!generate_test_quintuple(reduce, a, p, b, q, m, ctx))
473 		errx(1, "generate_test_quintuple");
474 
475 	for (i = 0; i < 32 && !failed; i++) {
476 		BN_set_negative(a, (i >> 0) & 1);
477 		BN_set_negative(p, (i >> 1) & 1);
478 		BN_set_negative(b, (i >> 2) & 1);
479 		BN_set_negative(q, (i >> 3) & 1);
480 		BN_set_negative(m, (i >> 4) & 1);
481 
482 		if (!bn_mod_exp2_simple(want, a, p, b, q, m, ctx))
483 			errx(1, "BN_mod_exp_simple");
484 
485 		if (!BN_mod_exp2_mont(got, a, p, b, q, m, ctx, NULL))
486 			errx(1, "BN_mod_exp2_mont");
487 
488 		if (BN_cmp(want, got) != 0) {
489 			dump_results(a, p, b, q, m, want, got, "BN_mod_exp2_mont");
490 			failed |= 1;
491 		}
492 	}
493 
494 	return failed;
495 }
496 
497 static int
test_bn_mod_exp2(void)498 test_bn_mod_exp2(void)
499 {
500 	BIGNUM *a, *p, *b, *q, *m, *want, *got;
501 	BN_CTX *ctx;
502 	int i;
503 	int reduce;
504 	int failed = 0;
505 
506 	if ((ctx = BN_CTX_new()) == NULL)
507 		errx(1, "BN_CTX_new");
508 
509 	BN_CTX_start(ctx);
510 
511 	if ((a = BN_CTX_get(ctx)) == NULL)
512 		errx(1, "a = BN_CTX_get()");
513 	if ((p = BN_CTX_get(ctx)) == NULL)
514 		errx(1, "p = BN_CTX_get()");
515 	if ((b = BN_CTX_get(ctx)) == NULL)
516 		errx(1, "b = BN_CTX_get()");
517 	if ((q = BN_CTX_get(ctx)) == NULL)
518 		errx(1, "q = BN_CTX_get()");
519 	if ((m = BN_CTX_get(ctx)) == NULL)
520 		errx(1, "m = BN_CTX_get()");
521 	if ((want = BN_CTX_get(ctx)) == NULL)
522 		errx(1, "want = BN_CTX_get()");
523 	if ((got = BN_CTX_get(ctx)) == NULL)
524 		errx(1, "got = BN_CTX_get()");
525 
526 	reduce = 0;
527 	for (i = 0; i < N_MOD_EXP_TESTS && !failed; i++)
528 		failed |= bn_mod_exp2_test(reduce, want, got, a, p, b, q, m, ctx);
529 
530 	reduce = 1;
531 	for (i = 0; i < N_MOD_EXP_TESTS && !failed; i++)
532 		failed |= bn_mod_exp2_test(reduce, want, got, a, p, b, q, m, ctx);
533 
534 	BN_CTX_end(ctx);
535 	BN_CTX_free(ctx);
536 
537 	return failed;
538 }
539 
540 /*
541  * Small test for a crash reported by Guido Vranken, fixed in bn_exp2.c r1.13.
542  * https://github.com/openssl/openssl/issues/17648
543  */
544 
545 static int
test_bn_mod_exp2_mont_crash(void)546 test_bn_mod_exp2_mont_crash(void)
547 {
548 	BIGNUM *m;
549 	int failed = 0;
550 
551 	if ((m = BN_new()) == NULL)
552 		errx(1, "BN_new");
553 
554 	if (BN_mod_exp2_mont(NULL, NULL, NULL, NULL, NULL, m, NULL, NULL)) {
555 		fprintf(stderr, "BN_mod_exp2_mont succeeded\n");
556 		failed |= 1;
557 	}
558 
559 	BN_free(m);
560 
561 	return failed;
562 }
563 
564 const struct aliasing_test_case {
565 	BN_ULONG a;
566 	BN_ULONG p;
567 	BN_ULONG m;
568 } aliasing_test_cases[] = {
569 	{
570 		.a = 1031,
571 		.p = 1033,
572 		.m = 1039,
573 	},
574 	{
575 		.a = 3,
576 		.p = 4,
577 		.m = 5,
578 	},
579 	{
580 		.a = 97,
581 		.p = 17,
582 		.m = 11,
583 	},
584 	{
585 		.a = 999961,
586 		.p = 999979,
587 		.m = 999983,
588 	},
589 };
590 
591 #define N_ALIASING_TEST_CASES \
592 	(sizeof(aliasing_test_cases) / sizeof(aliasing_test_cases[0]))
593 
594 static void
test_bn_mod_exp_aliasing_setup(BIGNUM * want,BIGNUM * a,BIGNUM * p,BIGNUM * m,BN_CTX * ctx,const struct aliasing_test_case * tc)595 test_bn_mod_exp_aliasing_setup(BIGNUM *want, BIGNUM *a, BIGNUM *p, BIGNUM *m,
596     BN_CTX *ctx, const struct aliasing_test_case *tc)
597 {
598 	if (!BN_set_word(a, tc->a))
599 		errx(1, "BN_set_word");
600 	if (!BN_set_word(p, tc->p))
601 		errx(1, "BN_set_word");
602 	if (!BN_set_word(m, tc->m))
603 		errx(1, "BN_set_word");
604 
605 	if (!BN_mod_exp_simple(want, a, p, m, ctx))
606 		errx(1, "BN_mod_exp");
607 }
608 
609 static int
test_mod_exp_aliased(const char * alias,int want_ret,BIGNUM * got,const BIGNUM * want,const BIGNUM * a,const BIGNUM * p,const BIGNUM * m,BN_CTX * ctx,const struct mod_exp_test * test)610 test_mod_exp_aliased(const char *alias, int want_ret, BIGNUM *got,
611     const BIGNUM *want, const BIGNUM *a, const BIGNUM *p, const BIGNUM *m,
612     BN_CTX *ctx, const struct mod_exp_test *test)
613 {
614 	int mod_exp_ret;
615 	int ret = 0;
616 
617 	BN_CTX_start(ctx);
618 
619 	if (test->mod_exp_fn != NULL)
620 		mod_exp_ret = test->mod_exp_fn(got, a, p, m, ctx);
621 	else
622 		mod_exp_ret = test->mod_exp_mont_fn(got, a, p, m, ctx, NULL);
623 
624 	if (mod_exp_ret != want_ret) {
625 		warnx("%s() %s aliased with result failed", test->name, alias);
626 		goto err;
627 	}
628 
629 	if (!mod_exp_ret)
630 		goto done;
631 
632 	if (BN_cmp(want, got) != 0) {
633 		dump_results(a, p, NULL, NULL, m, want, got, test->name);
634 		goto err;
635 	}
636 
637  done:
638 	ret = 1;
639 
640  err:
641 	BN_CTX_end(ctx);
642 
643 	return ret;
644 }
645 
646 static int
test_bn_mod_exp_aliasing_test(const struct mod_exp_test * test,BIGNUM * a,BIGNUM * p,BIGNUM * m,BIGNUM * want,BIGNUM * got,BN_CTX * ctx)647 test_bn_mod_exp_aliasing_test(const struct mod_exp_test *test,
648     BIGNUM *a, BIGNUM *p, BIGNUM *m, BIGNUM *want, BIGNUM *got, BN_CTX *ctx)
649 {
650 	int modulus_alias_works = test->mod_exp_fn != BN_mod_exp_simple;
651 	size_t i;
652 	int failed = 0;
653 
654 	for (i = 0; i < N_ALIASING_TEST_CASES; i++) {
655 		const struct aliasing_test_case *tc = &aliasing_test_cases[i];
656 
657 		test_bn_mod_exp_aliasing_setup(want, a, p, m, ctx, tc);
658 		if (!test_mod_exp_aliased("nothing", 1, got, want, a, p, m, ctx,
659 		    test))
660 			failed |= 1;
661 		test_bn_mod_exp_aliasing_setup(want, a, p, m, ctx, tc);
662 		if (!test_mod_exp_aliased("a", 1, a, want, a, p, m, ctx, test))
663 			failed |= 1;
664 		test_bn_mod_exp_aliasing_setup(want, a, p, m, ctx, tc);
665 		if (!test_mod_exp_aliased("p", 1, p, want, a, p, m, ctx, test))
666 			failed |= 1;
667 		test_bn_mod_exp_aliasing_setup(want, a, p, m, ctx, tc);
668 		if (!test_mod_exp_aliased("m", modulus_alias_works, m, want,
669 		    a, p, m, ctx, test))
670 			failed |= 1;
671 	}
672 
673 	return failed;
674 }
675 
676 static int
test_bn_mod_exp_aliasing(void)677 test_bn_mod_exp_aliasing(void)
678 {
679 	BN_CTX *ctx;
680 	BIGNUM *a, *p, *m, *want, *got;
681 	size_t i;
682 	int failed = 0;
683 
684 	if ((ctx = BN_CTX_new()) == NULL)
685 		errx(1, "BN_CTX_new");
686 
687 	BN_CTX_start(ctx);
688 
689 	if ((a = BN_CTX_get(ctx)) == NULL)
690 		errx(1, "a = BN_CTX_get()");
691 	if ((p = BN_CTX_get(ctx)) == NULL)
692 		errx(1, "p = BN_CTX_get()");
693 	if ((m = BN_CTX_get(ctx)) == NULL)
694 		errx(1, "m = BN_CTX_get()");
695 	if ((want = BN_CTX_get(ctx)) == NULL)
696 		errx(1, "want = BN_CTX_get()");
697 	if ((got = BN_CTX_get(ctx)) == NULL)
698 		errx(1, "got = BN_CTX_get()");
699 
700 	for (i = 0; i < N_MOD_EXP_FN; i++) {
701 		const struct mod_exp_test *test = &mod_exp_fn[i];
702 		failed |= test_bn_mod_exp_aliasing_test(test, a, p, m,
703 		    want, got, ctx);
704 	}
705 
706 	BN_CTX_end(ctx);
707 	BN_CTX_free(ctx);
708 
709 	return failed;
710 }
711 
712 int
main(void)713 main(void)
714 {
715 	int failed = 0;
716 
717 	failed |= test_bn_mod_exp_zero();
718 	failed |= test_bn_mod_exp();
719 	failed |= test_bn_mod_exp2();
720 	failed |= test_bn_mod_exp2_mont_crash();
721 	failed |= test_bn_mod_exp_aliasing();
722 
723 	return failed;
724 }
725