xref: /freebsd/lib/libmp/mpasbn.c (revision bdd1243d)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2001 Dima Dorfman.
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28 
29 /*
30  * This is the traditional Berkeley MP library implemented in terms of
31  * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32  * is meant to be as compatible with the latter as feasible.
33  *
34  * There seems to be a lack of documentation for the Berkeley MP
35  * interface.  All I could find was libgmp documentation (which didn't
36  * talk about the semantics of the functions) and an old SunOS 4.1
37  * manual page from 1989.  The latter wasn't very detailed, either,
38  * but at least described what the function's arguments were.  In
39  * general the interface seems to be archaic, somewhat poorly
40  * designed, and poorly, if at all, documented.  It is considered
41  * harmful.
42  *
43  * Miscellaneous notes on this implementation:
44  *
45  *  - The SunOS manual page mentioned above indicates that if an error
46  *  occurs, the library should "produce messages and core images."
47  *  Given that most of the functions don't have return values (and
48  *  thus no sane way of alerting the caller to an error), this seems
49  *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50  *  respectively, then abort().
51  *
52  *  - All the functions which take an argument to be "filled in"
53  *  assume that the argument has been initialized by one of the *tom()
54  *  routines before being passed to it.  I never saw this documented
55  *  anywhere, but this seems to be consistent with the way this
56  *  library is used.
57  *
58  *  - msqrt() is the only routine which had to be implemented which
59  *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60  *  It was implemented by hand using Newton's recursive formula.
61  *  Doing it this way, although more error-prone, has the positive
62  *  side effect of testing a lot of other functions; if msqrt()
63  *  produces the correct results, most of the other routines will as
64  *  well.
65  *
66  *  - Internal-use-only routines (i.e., those defined here statically
67  *  and not in mp.h) have an underscore prepended to their name (this
68  *  is more for aesthetic than technical reasons).  All such
69  *  routines take an extra argument, 'msg', that denotes what they
70  *  should call themselves in an error message.  This is so a user
71  *  doesn't get an error message from a function they didn't call.
72  */
73 
74 #include <sys/cdefs.h>
75 __FBSDID("$FreeBSD$");
76 
77 #include <ctype.h>
78 #include <err.h>
79 #include <errno.h>
80 #include <stdio.h>
81 #include <stdlib.h>
82 #include <string.h>
83 
84 #include <openssl/crypto.h>
85 #include <openssl/err.h>
86 
87 #include "mp.h"
88 
/*
 * Fatal-error helpers.  The argument 's' is a complete, parenthesized
 * argument list for warn(3)/warnx(3), e.g. MPERR(("%s", msg)).  After
 * the message is printed the process is terminated with abort(3).
 * MPERR appends the errno text; MPERRX does not.
 */
#define	MPERR(s)	do {						\
	warn s;								\
	abort();							\
} while (0)
#define	MPERRX(s)	do {						\
	warnx s;							\
	abort();							\
} while (0)

/*
 * Evaluate a BN_* call once and hand off to _bnerr() when it reports
 * failure by returning zero (or a NULL pointer).
 */
#define	BN_ERRCHECK(msg, expr)	do {					\
	if ((expr) == 0)						\
		_bnerr(msg);						\
} while (0)
94 
95 static void _bnerr(const char *);
96 static MINT *_dtom(const char *, const char *);
97 static MINT *_itom(const char *, short);
98 static void _madd(const char *, const MINT *, const MINT *, MINT *);
99 static int _mcmpa(const char *, const MINT *, const MINT *);
100 static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
101 		BN_CTX *);
102 static void _mfree(const char *, MINT *);
103 static void _moveb(const char *, const BIGNUM *, MINT *);
104 static void _movem(const char *, const MINT *, MINT *);
105 static void _msub(const char *, const MINT *, const MINT *, MINT *);
106 static char *_mtod(const char *, const MINT *);
107 static char *_mtox(const char *, const MINT *);
108 static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
109 static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
110 static MINT *_xtom(const char *, const char *);
111 
/*
 * Report a failure from one of the BN_* functions using MPERRX (which
 * aborts).  The reason text is taken from the OpenSSL error queue.
 * Some BN_* routines can fail without queueing an error (for example
 * BN_dec2bn()/BN_hex2bn() on input with no leading digits).  In that
 * case ERR_reason_error_string() yields NULL, and NULL must not reach
 * a "%s" conversion, so substitute a fixed string.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
122 
123 /*
124  * Convert a decimal string to an MINT.
125  */
126 static MINT *
127 _dtom(const char *msg, const char *s)
128 {
129 	MINT *mp;
130 
131 	mp = malloc(sizeof(*mp));
132 	if (mp == NULL)
133 		MPERR(("%s", msg));
134 	mp->bn = BN_new();
135 	if (mp->bn == NULL)
136 		_bnerr(msg);
137 	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
138 	return (mp);
139 }
140 
141 /*
142  * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
143  */
144 void
145 mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
146 {
147 	BIGNUM *b;
148 	BN_CTX *c;
149 
150 	b = NULL;
151 	c = BN_CTX_new();
152 	if (c != NULL)
153 		b = BN_new();
154 	if (c == NULL || b == NULL)
155 		_bnerr("gcd");
156 	BN_ERRCHECK("gcd", BN_gcd(b, mp1->bn, mp2->bn, c));
157 	_moveb("gcd", b, rmp);
158 	BN_free(b);
159 	BN_CTX_free(c);
160 }
161 
162 /*
163  * Make an MINT out of a short integer.  Return value must be mfree()'d.
164  */
165 static MINT *
166 _itom(const char *msg, short n)
167 {
168 	MINT *mp;
169 	char *s;
170 
171 	asprintf(&s, "%x", n);
172 	if (s == NULL)
173 		MPERR(("%s", msg));
174 	mp = _xtom(msg, s);
175 	free(s);
176 	return (mp);
177 }
178 
179 MINT *
180 mp_itom(short n)
181 {
182 
183 	return (_itom("itom", n));
184 }
185 
186 /*
187  * Compute rmp=mp1+mp2.
188  */
189 static void
190 _madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
191 {
192 	BIGNUM *b;
193 
194 	b = BN_new();
195 	if (b == NULL)
196 		_bnerr(msg);
197 	BN_ERRCHECK(msg, BN_add(b, mp1->bn, mp2->bn));
198 	_moveb(msg, b, rmp);
199 	BN_free(b);
200 }
201 
202 void
203 mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
204 {
205 
206 	_madd("madd", mp1, mp2, rmp);
207 }
208 
209 /*
210  * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
211  */
212 int
213 mp_mcmp(const MINT *mp1, const MINT *mp2)
214 {
215 
216 	return (BN_cmp(mp1->bn, mp2->bn));
217 }
218 
219 /*
220  * Same as mcmp but compares absolute values.
221  */
222 static int
223 _mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
224 {
225 
226 	return (BN_ucmp(mp1->bn, mp2->bn));
227 }
228 
229 /*
230  * Compute qmp=nmp/dmp and rmp=nmp%dmp.
231  */
232 static void
233 _mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
234     BN_CTX *c)
235 {
236 	BIGNUM *q, *r;
237 
238 	q = NULL;
239 	r = BN_new();
240 	if (r != NULL)
241 		q = BN_new();
242 	if (r == NULL || q == NULL)
243 		_bnerr(msg);
244 	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
245 	_moveb(msg, q, qmp);
246 	_moveb(msg, r, rmp);
247 	BN_free(q);
248 	BN_free(r);
249 }
250 
251 void
252 mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
253 {
254 	BN_CTX *c;
255 
256 	c = BN_CTX_new();
257 	if (c == NULL)
258 		_bnerr("mdiv");
259 	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
260 	BN_CTX_free(c);
261 }
262 
263 /*
264  * Free memory associated with an MINT.
265  */
266 static void
267 _mfree(const char *msg __unused, MINT *mp)
268 {
269 
270 	BN_clear(mp->bn);
271 	BN_free(mp->bn);
272 	free(mp);
273 }
274 
275 void
276 mp_mfree(MINT *mp)
277 {
278 
279 	_mfree("mfree", mp);
280 }
281 
282 /*
283  * Read an integer from standard input and stick the result in mp.
284  * The input is treated to be in base 10.  This must be the silliest
285  * API in existence; why can't the program read in a string and call
286  * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
287  * exported.)
288  */
289 void
290 mp_min(MINT *mp)
291 {
292 	MINT *rmp;
293 	char *line, *nline;
294 	size_t linelen;
295 
296 	line = fgetln(stdin, &linelen);
297 	if (line == NULL)
298 		MPERR(("min"));
299 	nline = malloc(linelen + 1);
300 	if (nline == NULL)
301 		MPERR(("min"));
302 	memcpy(nline, line, linelen);
303 	nline[linelen] = '\0';
304 	rmp = _dtom("min", nline);
305 	_movem("min", rmp, mp);
306 	_mfree("min", rmp);
307 	free(nline);
308 }
309 
310 /*
311  * Print the value of mp to standard output in base 10.  See blurb
312  * above min() for why this is so useless.
313  */
314 void
315 mp_mout(const MINT *mp)
316 {
317 	char *s;
318 
319 	s = _mtod("mout", mp);
320 	printf("%s", s);
321 	free(s);
322 }
323 
324 /*
325  * Set the value of tmp to the value of smp (i.e., tmp=smp).
326  */
327 void
328 mp_move(const MINT *smp, MINT *tmp)
329 {
330 
331 	_movem("move", smp, tmp);
332 }
333 
334 
335 /*
336  * Internal routine to set the value of tmp to that of sbp.
337  */
338 static void
339 _moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
340 {
341 
342 	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
343 }
344 
345 /*
346  * Internal routine to set the value of tmp to that of smp.
347  */
348 static void
349 _movem(const char *msg, const MINT *smp, MINT *tmp)
350 {
351 
352 	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
353 }
354 
355 /*
356  * Compute the square root of nmp and put the result in xmp.  The
357  * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
358  *
359  * Note that the OpenSSL BIGNUM library does not have a square root
360  * function, so this had to be implemented by hand using Newton's
361  * recursive formula:
362  *
363  *		x = (x + (n / x)) / 2
364  *
365  * where x is the square root of the positive number n.  In the
366  * beginning, x should be a reasonable guess, but the value 1,
367  * although suboptimal, works, too; this is that is used below.
368  */
369 void
370 mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
371 {
372 	BN_CTX *c;
373 	MINT *tolerance;
374 	MINT *ox, *x;
375 	MINT *z1, *z2, *z3;
376 	short i;
377 
378 	c = BN_CTX_new();
379 	if (c == NULL)
380 		_bnerr("msqrt");
381 	tolerance = _itom("msqrt", 1);
382 	x = _itom("msqrt", 1);
383 	ox = _itom("msqrt", 0);
384 	z1 = _itom("msqrt", 0);
385 	z2 = _itom("msqrt", 0);
386 	z3 = _itom("msqrt", 0);
387 	do {
388 		_movem("msqrt", x, ox);
389 		_mdiv("msqrt", nmp, x, z1, z2, c);
390 		_madd("msqrt", x, z1, z2);
391 		_sdiv("msqrt", z2, 2, x, &i, c);
392 		_msub("msqrt", ox, x, z3);
393 	} while (_mcmpa("msqrt", z3, tolerance) == 1);
394 	_movem("msqrt", x, xmp);
395 	_mult("msqrt", x, x, z1, c);
396 	_msub("msqrt", nmp, z1, z2);
397 	_movem("msqrt", z2, rmp);
398 	_mfree("msqrt", tolerance);
399 	_mfree("msqrt", ox);
400 	_mfree("msqrt", x);
401 	_mfree("msqrt", z1);
402 	_mfree("msqrt", z2);
403 	_mfree("msqrt", z3);
404 	BN_CTX_free(c);
405 }
406 
407 /*
408  * Compute rmp=mp1-mp2.
409  */
410 static void
411 _msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
412 {
413 	BIGNUM *b;
414 
415 	b = BN_new();
416 	if (b == NULL)
417 		_bnerr(msg);
418 	BN_ERRCHECK(msg, BN_sub(b, mp1->bn, mp2->bn));
419 	_moveb(msg, b, rmp);
420 	BN_free(b);
421 }
422 
423 void
424 mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
425 {
426 
427 	_msub("msub", mp1, mp2, rmp);
428 }
429 
430 /*
431  * Return a decimal representation of mp.  Return value must be
432  * free()'d.
433  */
434 static char *
435 _mtod(const char *msg, const MINT *mp)
436 {
437 	char *s, *s2;
438 
439 	s = BN_bn2dec(mp->bn);
440 	if (s == NULL)
441 		_bnerr(msg);
442 	asprintf(&s2, "%s", s);
443 	if (s2 == NULL)
444 		MPERR(("%s", msg));
445 	OPENSSL_free(s);
446 	return (s2);
447 }
448 
449 /*
450  * Return a hexadecimal representation of mp.  Return value must be
451  * free()'d.
452  */
453 static char *
454 _mtox(const char *msg, const MINT *mp)
455 {
456 	char *p, *s, *s2;
457 	int len;
458 
459 	s = BN_bn2hex(mp->bn);
460 	if (s == NULL)
461 		_bnerr(msg);
462 	asprintf(&s2, "%s", s);
463 	if (s2 == NULL)
464 		MPERR(("%s", msg));
465 	OPENSSL_free(s);
466 
467 	/*
468 	 * This is a kludge for libgmp compatibility.  The latter's
469 	 * implementation of this function returns lower-case letters,
470 	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
471 	 * newkey(1)) are sensitive to this.  Although it's probably
472 	 * their fault, it's nice to be compatible.
473 	 */
474 	len = strlen(s2);
475 	for (p = s2; p < s2 + len; p++)
476 		*p = tolower(*p);
477 
478 	return (s2);
479 }
480 
481 char *
482 mp_mtox(const MINT *mp)
483 {
484 
485 	return (_mtox("mtox", mp));
486 }
487 
488 /*
489  * Compute rmp=mp1*mp2.
490  */
491 static void
492 _mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
493 {
494 	BIGNUM *b;
495 
496 	b = BN_new();
497 	if (b == NULL)
498 		_bnerr(msg);
499 	BN_ERRCHECK(msg, BN_mul(b, mp1->bn, mp2->bn, c));
500 	_moveb(msg, b, rmp);
501 	BN_free(b);
502 }
503 
504 void
505 mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
506 {
507 	BN_CTX *c;
508 
509 	c = BN_CTX_new();
510 	if (c == NULL)
511 		_bnerr("mult");
512 	_mult("mult", mp1, mp2, rmp, c);
513 	BN_CTX_free(c);
514 }
515 
516 /*
517  * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
518  * means 'raise to power', not 'bitwise XOR'.)
519  */
520 void
521 mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
522 {
523 	BIGNUM *b;
524 	BN_CTX *c;
525 
526 	b = NULL;
527 	c = BN_CTX_new();
528 	if (c != NULL)
529 		b = BN_new();
530 	if (c == NULL || b == NULL)
531 		_bnerr("pow");
532 	BN_ERRCHECK("pow", BN_mod_exp(b, bmp->bn, emp->bn, mmp->bn, c));
533 	_moveb("pow", b, rmp);
534 	BN_free(b);
535 	BN_CTX_free(c);
536 }
537 
538 /*
539  * Compute rmp=bmp^e.  (See note above pow().)
540  */
541 void
542 mp_rpow(const MINT *bmp, short e, MINT *rmp)
543 {
544 	MINT *emp;
545 	BIGNUM *b;
546 	BN_CTX *c;
547 
548 	b = NULL;
549 	c = BN_CTX_new();
550 	if (c != NULL)
551 		b = BN_new();
552 	if (c == NULL || b == NULL)
553 		_bnerr("rpow");
554 	emp = _itom("rpow", e);
555 	BN_ERRCHECK("rpow", BN_exp(b, bmp->bn, emp->bn, c));
556 	_moveb("rpow", b, rmp);
557 	_mfree("rpow", emp);
558 	BN_free(b);
559 	BN_CTX_free(c);
560 }
561 
562 /*
563  * Compute qmp=nmp/d and ro=nmp%d.
564  */
565 static void
566 _sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
567     BN_CTX *c)
568 {
569 	MINT *dmp, *rmp;
570 	BIGNUM *q, *r;
571 	char *s;
572 
573 	r = NULL;
574 	q = BN_new();
575 	if (q != NULL)
576 		r = BN_new();
577 	if (q == NULL || r == NULL)
578 		_bnerr(msg);
579 	dmp = _itom(msg, d);
580 	rmp = _itom(msg, 0);
581 	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
582 	_moveb(msg, q, qmp);
583 	_moveb(msg, r, rmp);
584 	s = _mtox(msg, rmp);
585 	errno = 0;
586 	*ro = strtol(s, NULL, 16);
587 	if (errno != 0)
588 		MPERR(("%s underflow or overflow", msg));
589 	free(s);
590 	_mfree(msg, dmp);
591 	_mfree(msg, rmp);
592 	BN_free(r);
593 	BN_free(q);
594 }
595 
596 void
597 mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
598 {
599 	BN_CTX *c;
600 
601 	c = BN_CTX_new();
602 	if (c == NULL)
603 		_bnerr("sdiv");
604 	_sdiv("sdiv", nmp, d, qmp, ro, c);
605 	BN_CTX_free(c);
606 }
607 
608 /*
609  * Convert a hexadecimal string to an MINT.
610  */
611 static MINT *
612 _xtom(const char *msg, const char *s)
613 {
614 	MINT *mp;
615 
616 	mp = malloc(sizeof(*mp));
617 	if (mp == NULL)
618 		MPERR(("%s", msg));
619 	mp->bn = BN_new();
620 	if (mp->bn == NULL)
621 		_bnerr(msg);
622 	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
623 	return (mp);
624 }
625 
626 MINT *
627 mp_xtom(const char *s)
628 {
629 
630 	return (_xtom("xtom", s));
631 }
632