xref: /freebsd/lib/libmp/mpasbn.c (revision 1d386b48)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause
3  *
4  * Copyright (c) 2001 Dima Dorfman.
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions
9  * are met:
10  * 1. Redistributions of source code must retain the above copyright
11  *    notice, this list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright
13  *    notice, this list of conditions and the following disclaimer in the
14  *    documentation and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
20  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
22  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
25  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
26  * SUCH DAMAGE.
27  */
28 
29 /*
30  * This is the traditional Berkeley MP library implemented in terms of
31  * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
32  * is meant to be as compatible with the latter as feasible.
33  *
34  * There seems to be a lack of documentation for the Berkeley MP
35  * interface.  All I could find was libgmp documentation (which didn't
36  * talk about the semantics of the functions) and an old SunOS 4.1
37  * manual page from 1989.  The latter wasn't very detailed, either,
 * but at least described what the functions' arguments were.  In
39  * general the interface seems to be archaic, somewhat poorly
40  * designed, and poorly, if at all, documented.  It is considered
41  * harmful.
42  *
43  * Miscellaneous notes on this implementation:
44  *
45  *  - The SunOS manual page mentioned above indicates that if an error
46  *  occurs, the library should "produce messages and core images."
47  *  Given that most of the functions don't have return values (and
48  *  thus no sane way of alerting the caller to an error), this seems
49  *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
50  *  respectively, then abort().
51  *
52  *  - All the functions which take an argument to be "filled in"
53  *  assume that the argument has been initialized by one of the *tom()
54  *  routines before being passed to it.  I never saw this documented
55  *  anywhere, but this seems to be consistent with the way this
56  *  library is used.
57  *
58  *  - msqrt() is the only routine which had to be implemented which
59  *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
60  *  It was implemented by hand using Newton's recursive formula.
61  *  Doing it this way, although more error-prone, has the positive
 *  side effect of testing a lot of other functions; if msqrt()
63  *  produces the correct results, most of the other routines will as
64  *  well.
65  *
66  *  - Internal-use-only routines (i.e., those defined here statically
67  *  and not in mp.h) have an underscore prepended to their name (this
 *  is more for aesthetic reasons than technical ones).  All such
69  *  routines take an extra argument, 'msg', that denotes what they
70  *  should call themselves in an error message.  This is so a user
71  *  doesn't get an error message from a function they didn't call.
72  */
73 
74 #include <sys/cdefs.h>
75 #include <ctype.h>
76 #include <err.h>
77 #include <errno.h>
78 #include <stdio.h>
79 #include <stdlib.h>
80 #include <string.h>
81 
82 #include <openssl/crypto.h>
83 #include <openssl/err.h>
84 
85 #include "mp.h"
86 
/*
 * Fatal-error helpers.  The argument to MPERR()/MPERRX() is a complete,
 * parenthesized warn(3)-style argument list, e.g. MPERR(("%s", msg)).
 * MPERR() reports via warn(), which appends the errno message; MPERRX()
 * uses warnx(), which does not.  Both then call abort().
 */
#define	MPERR(s) do {							\
	warn s;								\
	abort();							\
} while (0)
#define	MPERRX(s) do {							\
	warnx s;							\
	abort();							\
} while (0)

/*
 * If the BN_* expression expr evaluates to zero or NULL (the failure
 * value for every call this file wraps with it), report the failure
 * through _bnerr(msg).
 */
#define	BN_ERRCHECK(msg, expr) do {					\
	if ((expr) == 0)						\
		_bnerr(msg);						\
} while (0)
92 
93 static void _bnerr(const char *);
94 static MINT *_dtom(const char *, const char *);
95 static MINT *_itom(const char *, short);
96 static void _madd(const char *, const MINT *, const MINT *, MINT *);
97 static int _mcmpa(const char *, const MINT *, const MINT *);
98 static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
99 		BN_CTX *);
100 static void _mfree(const char *, MINT *);
101 static void _moveb(const char *, const BIGNUM *, MINT *);
102 static void _movem(const char *, const MINT *, MINT *);
103 static void _msub(const char *, const MINT *, const MINT *, MINT *);
104 static char *_mtod(const char *, const MINT *);
105 static char *_mtox(const char *, const MINT *);
106 static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
107 static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
108 static MINT *_xtom(const char *, const char *);
109 
/*
 * Report an error from one of the BN_* functions using MPERRX, which
 * calls abort(); this function does not return.
 *
 * Not every failing BN_* call necessarily leaves an entry in the OpenSSL
 * error queue.  ERR_reason_error_string() returns NULL when it has no
 * string for the code it is given (including 0 from an empty queue).
 * Passing NULL to a "%s" conversion is undefined behavior, so substitute
 * a fixed message in that case.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
120 
121 /*
122  * Convert a decimal string to an MINT.
123  */
124 static MINT *
_dtom(const char * msg,const char * s)125 _dtom(const char *msg, const char *s)
126 {
127 	MINT *mp;
128 
129 	mp = malloc(sizeof(*mp));
130 	if (mp == NULL)
131 		MPERR(("%s", msg));
132 	mp->bn = BN_new();
133 	if (mp->bn == NULL)
134 		_bnerr(msg);
135 	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
136 	return (mp);
137 }
138 
139 /*
140  * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
141  */
142 void
mp_gcd(const MINT * mp1,const MINT * mp2,MINT * rmp)143 mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
144 {
145 	BIGNUM *b;
146 	BN_CTX *c;
147 
148 	b = NULL;
149 	c = BN_CTX_new();
150 	if (c != NULL)
151 		b = BN_new();
152 	if (c == NULL || b == NULL)
153 		_bnerr("gcd");
154 	BN_ERRCHECK("gcd", BN_gcd(b, mp1->bn, mp2->bn, c));
155 	_moveb("gcd", b, rmp);
156 	BN_free(b);
157 	BN_CTX_free(c);
158 }
159 
160 /*
161  * Make an MINT out of a short integer.  Return value must be mfree()'d.
162  */
163 static MINT *
_itom(const char * msg,short n)164 _itom(const char *msg, short n)
165 {
166 	MINT *mp;
167 	char *s;
168 
169 	asprintf(&s, "%x", n);
170 	if (s == NULL)
171 		MPERR(("%s", msg));
172 	mp = _xtom(msg, s);
173 	free(s);
174 	return (mp);
175 }
176 
177 MINT *
mp_itom(short n)178 mp_itom(short n)
179 {
180 
181 	return (_itom("itom", n));
182 }
183 
184 /*
185  * Compute rmp=mp1+mp2.
186  */
187 static void
_madd(const char * msg,const MINT * mp1,const MINT * mp2,MINT * rmp)188 _madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
189 {
190 	BIGNUM *b;
191 
192 	b = BN_new();
193 	if (b == NULL)
194 		_bnerr(msg);
195 	BN_ERRCHECK(msg, BN_add(b, mp1->bn, mp2->bn));
196 	_moveb(msg, b, rmp);
197 	BN_free(b);
198 }
199 
200 void
mp_madd(const MINT * mp1,const MINT * mp2,MINT * rmp)201 mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
202 {
203 
204 	_madd("madd", mp1, mp2, rmp);
205 }
206 
207 /*
208  * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
209  */
210 int
mp_mcmp(const MINT * mp1,const MINT * mp2)211 mp_mcmp(const MINT *mp1, const MINT *mp2)
212 {
213 
214 	return (BN_cmp(mp1->bn, mp2->bn));
215 }
216 
217 /*
218  * Same as mcmp but compares absolute values.
219  */
220 static int
_mcmpa(const char * msg __unused,const MINT * mp1,const MINT * mp2)221 _mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
222 {
223 
224 	return (BN_ucmp(mp1->bn, mp2->bn));
225 }
226 
227 /*
228  * Compute qmp=nmp/dmp and rmp=nmp%dmp.
229  */
230 static void
_mdiv(const char * msg,const MINT * nmp,const MINT * dmp,MINT * qmp,MINT * rmp,BN_CTX * c)231 _mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
232     BN_CTX *c)
233 {
234 	BIGNUM *q, *r;
235 
236 	q = NULL;
237 	r = BN_new();
238 	if (r != NULL)
239 		q = BN_new();
240 	if (r == NULL || q == NULL)
241 		_bnerr(msg);
242 	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
243 	_moveb(msg, q, qmp);
244 	_moveb(msg, r, rmp);
245 	BN_free(q);
246 	BN_free(r);
247 }
248 
249 void
mp_mdiv(const MINT * nmp,const MINT * dmp,MINT * qmp,MINT * rmp)250 mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
251 {
252 	BN_CTX *c;
253 
254 	c = BN_CTX_new();
255 	if (c == NULL)
256 		_bnerr("mdiv");
257 	_mdiv("mdiv", nmp, dmp, qmp, rmp, c);
258 	BN_CTX_free(c);
259 }
260 
261 /*
262  * Free memory associated with an MINT.
263  */
264 static void
_mfree(const char * msg __unused,MINT * mp)265 _mfree(const char *msg __unused, MINT *mp)
266 {
267 
268 	BN_clear(mp->bn);
269 	BN_free(mp->bn);
270 	free(mp);
271 }
272 
273 void
mp_mfree(MINT * mp)274 mp_mfree(MINT *mp)
275 {
276 
277 	_mfree("mfree", mp);
278 }
279 
280 /*
281  * Read an integer from standard input and stick the result in mp.
282  * The input is treated to be in base 10.  This must be the silliest
283  * API in existence; why can't the program read in a string and call
284  * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
285  * exported.)
286  */
287 void
mp_min(MINT * mp)288 mp_min(MINT *mp)
289 {
290 	MINT *rmp;
291 	char *line, *nline;
292 	size_t linelen;
293 
294 	line = fgetln(stdin, &linelen);
295 	if (line == NULL)
296 		MPERR(("min"));
297 	nline = malloc(linelen + 1);
298 	if (nline == NULL)
299 		MPERR(("min"));
300 	memcpy(nline, line, linelen);
301 	nline[linelen] = '\0';
302 	rmp = _dtom("min", nline);
303 	_movem("min", rmp, mp);
304 	_mfree("min", rmp);
305 	free(nline);
306 }
307 
308 /*
309  * Print the value of mp to standard output in base 10.  See blurb
310  * above min() for why this is so useless.
311  */
312 void
mp_mout(const MINT * mp)313 mp_mout(const MINT *mp)
314 {
315 	char *s;
316 
317 	s = _mtod("mout", mp);
318 	printf("%s", s);
319 	free(s);
320 }
321 
322 /*
323  * Set the value of tmp to the value of smp (i.e., tmp=smp).
324  */
325 void
mp_move(const MINT * smp,MINT * tmp)326 mp_move(const MINT *smp, MINT *tmp)
327 {
328 
329 	_movem("move", smp, tmp);
330 }
331 
332 
333 /*
334  * Internal routine to set the value of tmp to that of sbp.
335  */
336 static void
_moveb(const char * msg,const BIGNUM * sbp,MINT * tmp)337 _moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
338 {
339 
340 	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
341 }
342 
343 /*
344  * Internal routine to set the value of tmp to that of smp.
345  */
346 static void
_movem(const char * msg,const MINT * smp,MINT * tmp)347 _movem(const char *msg, const MINT *smp, MINT *tmp)
348 {
349 
350 	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
351 }
352 
353 /*
354  * Compute the square root of nmp and put the result in xmp.  The
355  * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
356  *
357  * Note that the OpenSSL BIGNUM library does not have a square root
358  * function, so this had to be implemented by hand using Newton's
359  * recursive formula:
360  *
361  *		x = (x + (n / x)) / 2
362  *
363  * where x is the square root of the positive number n.  In the
364  * beginning, x should be a reasonable guess, but the value 1,
365  * although suboptimal, works, too; this is that is used below.
366  */
367 void
mp_msqrt(const MINT * nmp,MINT * xmp,MINT * rmp)368 mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
369 {
370 	BN_CTX *c;
371 	MINT *tolerance;
372 	MINT *ox, *x;
373 	MINT *z1, *z2, *z3;
374 	short i;
375 
376 	c = BN_CTX_new();
377 	if (c == NULL)
378 		_bnerr("msqrt");
379 	tolerance = _itom("msqrt", 1);
380 	x = _itom("msqrt", 1);
381 	ox = _itom("msqrt", 0);
382 	z1 = _itom("msqrt", 0);
383 	z2 = _itom("msqrt", 0);
384 	z3 = _itom("msqrt", 0);
385 	do {
386 		_movem("msqrt", x, ox);
387 		_mdiv("msqrt", nmp, x, z1, z2, c);
388 		_madd("msqrt", x, z1, z2);
389 		_sdiv("msqrt", z2, 2, x, &i, c);
390 		_msub("msqrt", ox, x, z3);
391 	} while (_mcmpa("msqrt", z3, tolerance) == 1);
392 	_movem("msqrt", x, xmp);
393 	_mult("msqrt", x, x, z1, c);
394 	_msub("msqrt", nmp, z1, z2);
395 	_movem("msqrt", z2, rmp);
396 	_mfree("msqrt", tolerance);
397 	_mfree("msqrt", ox);
398 	_mfree("msqrt", x);
399 	_mfree("msqrt", z1);
400 	_mfree("msqrt", z2);
401 	_mfree("msqrt", z3);
402 	BN_CTX_free(c);
403 }
404 
405 /*
406  * Compute rmp=mp1-mp2.
407  */
408 static void
_msub(const char * msg,const MINT * mp1,const MINT * mp2,MINT * rmp)409 _msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
410 {
411 	BIGNUM *b;
412 
413 	b = BN_new();
414 	if (b == NULL)
415 		_bnerr(msg);
416 	BN_ERRCHECK(msg, BN_sub(b, mp1->bn, mp2->bn));
417 	_moveb(msg, b, rmp);
418 	BN_free(b);
419 }
420 
421 void
mp_msub(const MINT * mp1,const MINT * mp2,MINT * rmp)422 mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
423 {
424 
425 	_msub("msub", mp1, mp2, rmp);
426 }
427 
428 /*
429  * Return a decimal representation of mp.  Return value must be
430  * free()'d.
431  */
432 static char *
_mtod(const char * msg,const MINT * mp)433 _mtod(const char *msg, const MINT *mp)
434 {
435 	char *s, *s2;
436 
437 	s = BN_bn2dec(mp->bn);
438 	if (s == NULL)
439 		_bnerr(msg);
440 	asprintf(&s2, "%s", s);
441 	if (s2 == NULL)
442 		MPERR(("%s", msg));
443 	OPENSSL_free(s);
444 	return (s2);
445 }
446 
447 /*
448  * Return a hexadecimal representation of mp.  Return value must be
449  * free()'d.
450  */
451 static char *
_mtox(const char * msg,const MINT * mp)452 _mtox(const char *msg, const MINT *mp)
453 {
454 	char *p, *s, *s2;
455 	int len;
456 
457 	s = BN_bn2hex(mp->bn);
458 	if (s == NULL)
459 		_bnerr(msg);
460 	asprintf(&s2, "%s", s);
461 	if (s2 == NULL)
462 		MPERR(("%s", msg));
463 	OPENSSL_free(s);
464 
465 	/*
466 	 * This is a kludge for libgmp compatibility.  The latter's
467 	 * implementation of this function returns lower-case letters,
468 	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
469 	 * newkey(1)) are sensitive to this.  Although it's probably
470 	 * their fault, it's nice to be compatible.
471 	 */
472 	len = strlen(s2);
473 	for (p = s2; p < s2 + len; p++)
474 		*p = tolower(*p);
475 
476 	return (s2);
477 }
478 
479 char *
mp_mtox(const MINT * mp)480 mp_mtox(const MINT *mp)
481 {
482 
483 	return (_mtox("mtox", mp));
484 }
485 
486 /*
487  * Compute rmp=mp1*mp2.
488  */
489 static void
_mult(const char * msg,const MINT * mp1,const MINT * mp2,MINT * rmp,BN_CTX * c)490 _mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
491 {
492 	BIGNUM *b;
493 
494 	b = BN_new();
495 	if (b == NULL)
496 		_bnerr(msg);
497 	BN_ERRCHECK(msg, BN_mul(b, mp1->bn, mp2->bn, c));
498 	_moveb(msg, b, rmp);
499 	BN_free(b);
500 }
501 
502 void
mp_mult(const MINT * mp1,const MINT * mp2,MINT * rmp)503 mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
504 {
505 	BN_CTX *c;
506 
507 	c = BN_CTX_new();
508 	if (c == NULL)
509 		_bnerr("mult");
510 	_mult("mult", mp1, mp2, rmp, c);
511 	BN_CTX_free(c);
512 }
513 
514 /*
515  * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
516  * means 'raise to power', not 'bitwise XOR'.)
517  */
518 void
mp_pow(const MINT * bmp,const MINT * emp,const MINT * mmp,MINT * rmp)519 mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
520 {
521 	BIGNUM *b;
522 	BN_CTX *c;
523 
524 	b = NULL;
525 	c = BN_CTX_new();
526 	if (c != NULL)
527 		b = BN_new();
528 	if (c == NULL || b == NULL)
529 		_bnerr("pow");
530 	BN_ERRCHECK("pow", BN_mod_exp(b, bmp->bn, emp->bn, mmp->bn, c));
531 	_moveb("pow", b, rmp);
532 	BN_free(b);
533 	BN_CTX_free(c);
534 }
535 
536 /*
537  * Compute rmp=bmp^e.  (See note above pow().)
538  */
539 void
mp_rpow(const MINT * bmp,short e,MINT * rmp)540 mp_rpow(const MINT *bmp, short e, MINT *rmp)
541 {
542 	MINT *emp;
543 	BIGNUM *b;
544 	BN_CTX *c;
545 
546 	b = NULL;
547 	c = BN_CTX_new();
548 	if (c != NULL)
549 		b = BN_new();
550 	if (c == NULL || b == NULL)
551 		_bnerr("rpow");
552 	emp = _itom("rpow", e);
553 	BN_ERRCHECK("rpow", BN_exp(b, bmp->bn, emp->bn, c));
554 	_moveb("rpow", b, rmp);
555 	_mfree("rpow", emp);
556 	BN_free(b);
557 	BN_CTX_free(c);
558 }
559 
560 /*
561  * Compute qmp=nmp/d and ro=nmp%d.
562  */
563 static void
_sdiv(const char * msg,const MINT * nmp,short d,MINT * qmp,short * ro,BN_CTX * c)564 _sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
565     BN_CTX *c)
566 {
567 	MINT *dmp, *rmp;
568 	BIGNUM *q, *r;
569 	char *s;
570 
571 	r = NULL;
572 	q = BN_new();
573 	if (q != NULL)
574 		r = BN_new();
575 	if (q == NULL || r == NULL)
576 		_bnerr(msg);
577 	dmp = _itom(msg, d);
578 	rmp = _itom(msg, 0);
579 	BN_ERRCHECK(msg, BN_div(q, r, nmp->bn, dmp->bn, c));
580 	_moveb(msg, q, qmp);
581 	_moveb(msg, r, rmp);
582 	s = _mtox(msg, rmp);
583 	errno = 0;
584 	*ro = strtol(s, NULL, 16);
585 	if (errno != 0)
586 		MPERR(("%s underflow or overflow", msg));
587 	free(s);
588 	_mfree(msg, dmp);
589 	_mfree(msg, rmp);
590 	BN_free(r);
591 	BN_free(q);
592 }
593 
594 void
mp_sdiv(const MINT * nmp,short d,MINT * qmp,short * ro)595 mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
596 {
597 	BN_CTX *c;
598 
599 	c = BN_CTX_new();
600 	if (c == NULL)
601 		_bnerr("sdiv");
602 	_sdiv("sdiv", nmp, d, qmp, ro, c);
603 	BN_CTX_free(c);
604 }
605 
606 /*
607  * Convert a hexadecimal string to an MINT.
608  */
609 static MINT *
_xtom(const char * msg,const char * s)610 _xtom(const char *msg, const char *s)
611 {
612 	MINT *mp;
613 
614 	mp = malloc(sizeof(*mp));
615 	if (mp == NULL)
616 		MPERR(("%s", msg));
617 	mp->bn = BN_new();
618 	if (mp->bn == NULL)
619 		_bnerr(msg);
620 	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
621 	return (mp);
622 }
623 
624 MINT *
mp_xtom(const char * s)625 mp_xtom(const char *s)
626 {
627 
628 	return (_xtom("xtom", s));
629 }
630