xref: /freebsd/lib/libmp/mpasbn.c (revision 76f29359)
1 /*
2  * Copyright (c) 2001 Dima Dorfman.
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions
7  * are met:
8  * 1. Redistributions of source code must retain the above copyright
9  *    notice, this list of conditions and the following disclaimer.
10  * 2. Redistributions in binary form must reproduce the above copyright
11  *    notice, this list of conditions and the following disclaimer in the
12  *    documentation and/or other materials provided with the distribution.
13  *
14  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24  * SUCH DAMAGE.
25  */
26 
27 /*
28  * This is the traditional Berkeley MP library implemented in terms of
29  * the OpenSSL BIGNUM library.  It was written to replace libgmp, and
30  * is meant to be as compatible with the latter as feasible.
31  *
32  * There seems to be a lack of documentation for the Berkeley MP
33  * interface.  All I could find was libgmp documentation (which didn't
34  * talk about the semantics of the functions) and an old SunOS 4.1
35  * manual page from 1989.  The latter wasn't very detailed, either,
36  * but at least described what the function's arguments were.  In
37  * general the interface seems to be archaic, somewhat poorly
38  * designed, and poorly, if at all, documented.  It is considered
39  * harmful.
40  *
41  * Miscellaneous notes on this implementation:
42  *
43  *  - The SunOS manual page mentioned above indicates that if an error
44  *  occurs, the library should "produce messages and core images."
45  *  Given that most of the functions don't have return values (and
46  *  thus no sane way of alerting the caller to an error), this seems
47  *  reasonable.  The MPERR and MPERRX macros call warn and warnx,
48  *  respectively, then abort().
49  *
50  *  - All the functions which take an argument to be "filled in"
51  *  assume that the argument has been initialized by one of the *tom()
52  *  routines before being passed to it.  I never saw this documented
53  *  anywhere, but this seems to be consistent with the way this
54  *  library is used.
55  *
56  *  - msqrt() is the only routine which had to be implemented which
57  *  doesn't have a close counterpart in the OpenSSL BIGNUM library.
58  *  It was implemented by hand using Newton's recursive formula.
59  *  Doing it this way, although more error-prone, has the positive
60  *  side effect of testing a lot of other functions; if msqrt()
61  *  produces the correct results, most of the other routines will as
62  *  well.
63  *
64  *  - Internal-use-only routines (i.e., those defined here statically
65  *  and not in mp.h) have an underscore prepended to their name (this
66  *  is more for aesthetic reasons than technical).  All such
67  *  routines take an extra argument, 'msg', that denotes what they
68  *  should call themselves in an error message.  This is so a user
69  *  doesn't get an error message from a function they didn't call.
70  */
71 
72 #include <sys/cdefs.h>
73 __FBSDID("$FreeBSD$");
74 
75 #include <ctype.h>
76 #include <err.h>
77 #include <errno.h>
78 #include <stdio.h>
79 #include <stdlib.h>
80 #include <string.h>
81 
82 #include <openssl/crypto.h>
83 #include <openssl/err.h>
84 
85 #include "mp.h"
86 
/*
 * MPERR(s) and MPERRX(s) print a message with warn() and warnx(),
 * respectively, and then abort().  The argument s is a complete,
 * parenthesized argument list, e.g. MPERR(("%s", msg)).
 */
#define MPERR(s)	do { warn s; abort(); } while (0)
#define MPERRX(s)	do { warnx s; abort(); } while (0)
/*
 * BN_ERRCHECK(msg, expr) evaluates expr once and calls _bnerr(msg) when
 * the result is zero or NULL.
 */
#define BN_ERRCHECK(msg, expr) do {		\
	if (!(expr)) _bnerr(msg);		\
} while (0)
92 
/*
 * Forward declarations of the internal routines.  Apart from _bnerr(),
 * whose only argument is that name, each takes as its first argument the
 * name of the routine that error messages should be attributed to.
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *);
static MINT *_xtom(const char *, const char *);
108 
/*
 * Report an error from one of the BN_* functions using MPERRX, which
 * prints the message and then abort()s.
 *
 * msg is the name of the routine the message is attributed to.  The
 * reason text is looked up for the error returned by ERR_get_error().
 * When no reason string is available (for instance because nothing was
 * queued), a generic description is printed instead of handing a NULL
 * pointer to the "%s" conversion.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
119 
120 /*
121  * Convert a decimal string to an MINT.
122  */
123 static MINT *
124 _dtom(const char *msg, const char *s)
125 {
126 	MINT *mp;
127 
128 	mp = malloc(sizeof(*mp));
129 	if (mp == NULL)
130 		MPERR(("%s", msg));
131 	mp->bn = BN_new();
132 	if (mp->bn == NULL)
133 		_bnerr(msg);
134 	BN_ERRCHECK(msg, BN_dec2bn(&mp->bn, s));
135 	return (mp);
136 }
137 
138 /*
139  * Compute the greatest common divisor of mp1 and mp2; result goes in rmp.
140  */
141 void
142 gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
143 {
144 	BIGNUM b;
145 	BN_CTX *c;
146 
147 	c = BN_CTX_new();
148 	if (c == NULL)
149 		_bnerr("gcd");
150 	BN_init(&b);
151 	BN_ERRCHECK("gcd", BN_gcd(&b, mp1->bn, mp2->bn, c));
152 	_moveb("gcd", &b, rmp);
153 	BN_free(&b);
154 	BN_CTX_free(c);
155 }
156 
157 /*
158  * Make an MINT out of a short integer.  Return value must be mfree()'d.
159  */
160 static MINT *
161 _itom(const char *msg, short n)
162 {
163 	MINT *mp;
164 	char *s;
165 
166 	asprintf(&s, "%x", n);
167 	if (s == NULL)
168 		MPERR(("%s", msg));
169 	mp = _xtom(msg, s);
170 	free(s);
171 	return (mp);
172 }
173 
/*
 * Public entry point: make an MINT out of the short integer n by calling
 * _itom(), with errors attributed to "itom".
 */
MINT *
itom(short n)
{

	return (_itom("itom", n));
}
180 
181 /*
182  * Compute rmp=mp1+mp2.
183  */
184 static void
185 _madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
186 {
187 	BIGNUM b;
188 
189 	BN_init(&b);
190 	BN_ERRCHECK(msg, BN_add(&b, mp1->bn, mp2->bn));
191 	_moveb(msg, &b, rmp);
192 	BN_free(&b);
193 }
194 
/*
 * Public entry point: rmp=mp1+mp2, computed by _madd() with errors
 * attributed to "madd".
 */
void
madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
{

	_madd("madd", mp1, mp2, rmp);
}
201 
202 /*
203  * Return -1, 0, or 1 if mp1<mp2, mp1==mp2, or mp1>mp2, respectivley.
204  */
205 int
206 mcmp(const MINT *mp1, const MINT *mp2)
207 {
208 
209 	return (BN_cmp(mp1->bn, mp2->bn));
210 }
211 
212 /*
213  * Same as mcmp but compares absolute values.
214  */
215 static int
216 _mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
217 {
218 
219 	return (BN_ucmp(mp1->bn, mp2->bn));
220 }
221 
222 /*
223  * Compute qmp=nmp/dmp and rmp=nmp%dmp.
224  */
225 static void
226 _mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
227 {
228 	BIGNUM q, r;
229 	BN_CTX *c;
230 
231 	c = BN_CTX_new();
232 	if (c == NULL)
233 		_bnerr(msg);
234 	BN_init(&r);
235 	BN_init(&q);
236 	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
237 	_moveb(msg, &q, qmp);
238 	_moveb(msg, &r, rmp);
239 	BN_free(&q);
240 	BN_free(&r);
241 	BN_CTX_free(c);
242 }
243 
/*
 * Public entry point: qmp=nmp/dmp and rmp=nmp%dmp, computed by _mdiv()
 * with errors attributed to "mdiv".
 */
void
mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
{

	_mdiv("mdiv", nmp, dmp, qmp, rmp);
}
250 
251 /*
252  * Free memory associated with an MINT.
253  */
254 static void
255 _mfree(const char *msg __unused, MINT *mp)
256 {
257 
258 	BN_clear(mp->bn);
259 	BN_free(mp->bn);
260 	free(mp);
261 }
262 
/*
 * Public entry point: free the memory associated with mp by calling
 * _mfree() under the name "mfree".
 */
void
mfree(MINT *mp)
{

	_mfree("mfree", mp);
}
269 
270 /*
271  * Read an integer from standard input and stick the result in mp.
272  * The input is treated to be in base 10.  This must be the silliest
273  * API in existence; why can't the program read in a string and call
274  * xtom()?  (Or if base 10 is desires, perhaps dtom() could be
275  * exported.)
276  */
277 void
278 min(MINT *mp)
279 {
280 	MINT *rmp;
281 	char *line, *nline;
282 	size_t linelen;
283 
284 	line = fgetln(stdin, &linelen);
285 	if (line == NULL)
286 		MPERR(("min"));
287 	nline = malloc(linelen);
288 	if (nline == NULL)
289 		MPERR(("min"));
290 	strncpy(nline, line, linelen);
291 	nline[linelen] = '\0';
292 	rmp = _dtom("min", nline);
293 	_movem("min", rmp, mp);
294 	_mfree("min", rmp);
295 	free(nline);
296 }
297 
298 /*
299  * Print the value of mp to standard output in base 10.  See blurb
300  * above min() for why this is so useless.
301  */
302 void
303 mout(const MINT *mp)
304 {
305 	char *s;
306 
307 	s = _mtod("mout", mp);
308 	printf("%s", s);
309 	free(s);
310 }
311 
/*
 * Set the value of tmp to the value of smp (i.e., tmp=smp).  Public
 * entry point for _movem(), with errors attributed to "move".
 */
void
move(const MINT *smp, MINT *tmp)
{

	_movem("move", smp, tmp);
}
321 
322 
323 /*
324  * Internal routine to set the value of tmp to that of sbp.
325  */
326 static void
327 _moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
328 {
329 
330 	BN_ERRCHECK(msg, BN_copy(tmp->bn, sbp));
331 }
332 
333 /*
334  * Internal routine to set the value of tmp to that of smp.
335  */
336 static void
337 _movem(const char *msg, const MINT *smp, MINT *tmp)
338 {
339 
340 	BN_ERRCHECK(msg, BN_copy(tmp->bn, smp->bn));
341 }
342 
343 /*
344  * Compute the square root of nmp and put the result in xmp.  The
345  * remainder goes in rmp.  Should satisfy: rmp=nmp-(xmp*xmp).
346  *
347  * Note that the OpenSSL BIGNUM library does not have a square root
348  * function, so this had to be implemented by hand using Newton's
349  * recursive formula:
350  *
351  *		x = (x + (n / x)) / 2
352  *
353  * where x is the square root of the positive number n.  In the
354  * beginning, x should be a reasonable guess, but the value 1,
355  * although suboptimal, works, too; this is that is used below.
356  */
357 void
358 msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
359 {
360 	MINT *tolerance;
361 	MINT *ox, *x;
362 	MINT *z1, *z2, *z3;
363 	short i;
364 
365 	tolerance = _itom("msqrt", 1);
366 	x = _itom("msqrt", 1);
367 	ox = _itom("msqrt", 0);
368 	z1 = _itom("msqrt", 0);
369 	z2 = _itom("msqrt", 0);
370 	z3 = _itom("msqrt", 0);
371 	do {
372 		_movem("msqrt", x, ox);
373 		_mdiv("msqrt", nmp, x, z1, z2);
374 		_madd("msqrt", x, z1, z2);
375 		_sdiv("msqrt", z2, 2, x, &i);
376 		_msub("msqrt", ox, x, z3);
377 	} while (_mcmpa("msqrt", z3, tolerance) == 1);
378 	_movem("msqrt", x, xmp);
379 	_mult("msqrt", x, x, z1);
380 	_msub("msqrt", nmp, z1, z2);
381 	_movem("msqrt", z2, rmp);
382 	_mfree("msqrt", tolerance);
383 	_mfree("msqrt", ox);
384 	_mfree("msqrt", x);
385 	_mfree("msqrt", z1);
386 	_mfree("msqrt", z2);
387 	_mfree("msqrt", z3);
388 }
389 
390 /*
391  * Compute rmp=mp1-mp2.
392  */
393 static void
394 _msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
395 {
396 	BIGNUM b;
397 
398 	BN_init(&b);
399 	BN_ERRCHECK(msg, BN_sub(&b, mp1->bn, mp2->bn));
400 	_moveb(msg, &b, rmp);
401 	BN_free(&b);
402 }
403 
/*
 * Public entry point: rmp=mp1-mp2, computed by _msub() with errors
 * attributed to "msub".
 */
void
msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
{

	_msub("msub", mp1, mp2, rmp);
}
410 
411 /*
412  * Return a decimal representation of mp.  Return value must be
413  * free()'d.
414  */
415 static char *
416 _mtod(const char *msg, const MINT *mp)
417 {
418 	char *s, *s2;
419 
420 	s = BN_bn2dec(mp->bn);
421 	if (s == NULL)
422 		_bnerr(msg);
423 	asprintf(&s2, "%s", s);
424 	if (s2 == NULL)
425 		MPERR(("%s", msg));
426 	OPENSSL_free(s);
427 	return (s2);
428 }
429 
430 /*
431  * Return a hexadecimal representation of mp.  Return value must be
432  * free()'d.
433  */
434 static char *
435 _mtox(const char *msg, const MINT *mp)
436 {
437 	char *p, *s, *s2;
438 	int len;
439 
440 	s = BN_bn2hex(mp->bn);
441 	if (s == NULL)
442 		_bnerr(msg);
443 	asprintf(&s2, "%s", s);
444 	if (s2 == NULL)
445 		MPERR(("%s", msg));
446 	OPENSSL_free(s);
447 
448 	/*
449 	 * This is a kludge for libgmp compatibility.  The latter's
450 	 * implementation of this function returns lower-case letters,
451 	 * but BN_bn2hex returns upper-case.  Some programs (e.g.,
452 	 * newkey(1)) are sensitive to this.  Although it's probably
453 	 * their fault, it's nice to be compatible.
454 	 */
455 	len = strlen(s2);
456 	for (p = s2; p < s2 + len; p++)
457 		*p = tolower(*p);
458 
459 	return (s2);
460 }
461 
/*
 * Public entry point: return the hexadecimal representation of mp as
 * produced by _mtox(), with errors attributed to "mtox".  The returned
 * string must be free()'d.
 */
char *
mtox(const MINT *mp)
{

	return (_mtox("mtox", mp));
}
468 
469 /*
470  * Compute rmp=mp1*mp2.
471  */
472 static void
473 _mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
474 {
475 	BIGNUM b;
476 	BN_CTX *c;
477 
478 	c = BN_CTX_new();
479 	if (c == NULL)
480 		_bnerr(msg);
481 	BN_init(&b);
482 	BN_ERRCHECK(msg, BN_mul(&b, mp1->bn, mp2->bn, c));
483 	_moveb(msg, &b, rmp);
484 	BN_free(&b);
485 	BN_CTX_free(c);
486 }
487 
/*
 * Public entry point: rmp=mp1*mp2, computed by _mult() with errors
 * attributed to "mult".
 */
void
mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
{

	_mult("mult", mp1, mp2, rmp);
}
494 
495 /*
496  * Compute rmp=(bmp^emp)mod mmp.  (Note that here and above rpow() '^'
497  * means 'raise to power', not 'bitwise XOR'.)
498  */
499 void
500 pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
501 {
502 	BIGNUM b;
503 	BN_CTX *c;
504 
505 	c = BN_CTX_new();
506 	if (c == NULL)
507 		_bnerr("pow");
508 	BN_init(&b);
509 	BN_ERRCHECK("pow", BN_mod_exp(&b, bmp->bn, emp->bn, mmp->bn, c));
510 	_moveb("pow", &b, rmp);
511 	BN_free(&b);
512 	BN_CTX_free(c);
513 }
514 
515 /*
516  * Compute rmp=bmp^e.  (See note above pow().)
517  */
518 void
519 rpow(const MINT *bmp, short e, MINT *rmp)
520 {
521 	MINT *emp;
522 	BIGNUM b;
523 	BN_CTX *c;
524 
525 	c = BN_CTX_new();
526 	if (c == NULL)
527 		_bnerr("rpow");
528 	BN_init(&b);
529 	emp = _itom("rpow", e);
530 	BN_ERRCHECK("rpow", BN_exp(&b, bmp->bn, emp->bn, c));
531 	_moveb("rpow", &b, rmp);
532 	_mfree("rpow", emp);
533 	BN_free(&b);
534 	BN_CTX_free(c);
535 }
536 
537 /*
538  * Compute qmp=nmp/d and ro=nmp%d.
539  */
540 static void
541 _sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro)
542 {
543 	MINT *dmp, *rmp;
544 	BIGNUM q, r;
545 	BN_CTX *c;
546 	char *s;
547 
548 	c = BN_CTX_new();
549 	if (c == NULL)
550 		_bnerr(msg);
551 	BN_init(&q);
552 	BN_init(&r);
553 	dmp = _itom(msg, d);
554 	rmp = _itom(msg, 0);
555 	BN_ERRCHECK(msg, BN_div(&q, &r, nmp->bn, dmp->bn, c));
556 	_moveb(msg, &q, qmp);
557 	_moveb(msg, &r, rmp);
558 	s = _mtox(msg, rmp);
559 	errno = 0;
560 	*ro = strtol(s, NULL, 16);
561 	if (errno != 0)
562 		MPERR(("%s underflow or overflow", msg));
563 	free(s);
564 	_mfree(msg, dmp);
565 	_mfree(msg, rmp);
566 	BN_free(&r);
567 	BN_free(&q);
568 	BN_CTX_free(c);
569 }
570 
/*
 * Public entry point: qmp=nmp/d and ro=nmp%d, computed by _sdiv() with
 * errors attributed to "sdiv".
 */
void
sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
{

	_sdiv("sdiv", nmp, d, qmp, ro);
}
577 
578 /*
579  * Convert a hexadecimal string to an MINT.
580  */
581 static MINT *
582 _xtom(const char *msg, const char *s)
583 {
584 	MINT *mp;
585 
586 	mp = malloc(sizeof(*mp));
587 	if (mp == NULL)
588 		MPERR(("%s", msg));
589 	mp->bn = BN_new();
590 	if (mp->bn == NULL)
591 		_bnerr(msg);
592 	BN_ERRCHECK(msg, BN_hex2bn(&mp->bn, s));
593 	return (mp);
594 }
595 
/*
 * Public entry point: convert the hexadecimal string s to an MINT by
 * calling _xtom(), with errors attributed to "xtom".
 */
MINT *
xtom(const char *s)
{

	return (_xtom("xtom", s));
}
602