xref: /openbsd/lib/libcrypto/bn/bn_isqrt.c (revision a29d9d67)
1*a29d9d67Stb /*	$OpenBSD: bn_isqrt.c,v 1.10 2023/06/04 17:28:35 tb Exp $ */
2e9fb6635Stb /*
3e9fb6635Stb  * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4e9fb6635Stb  *
5e9fb6635Stb  * Permission to use, copy, modify, and distribute this software for any
6e9fb6635Stb  * purpose with or without fee is hereby granted, provided that the above
7e9fb6635Stb  * copyright notice and this permission notice appear in all copies.
8e9fb6635Stb  *
9e9fb6635Stb  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10e9fb6635Stb  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11e9fb6635Stb  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12e9fb6635Stb  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13e9fb6635Stb  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14e9fb6635Stb  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15e9fb6635Stb  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16e9fb6635Stb  */
17e9fb6635Stb 
18e9fb6635Stb #include <stddef.h>
19e9fb6635Stb #include <stdint.h>
20e9fb6635Stb 
21e9fb6635Stb #include <openssl/bn.h>
22e9fb6635Stb #include <openssl/err.h>
23e9fb6635Stb 
24c9675a23Stb #include "bn_local.h"
25*a29d9d67Stb #include "crypto_internal.h"
26e9fb6635Stb 
27e9fb6635Stb /*
28e9fb6635Stb  * Calculate integer square root of |n| using a variant of Newton's method.
29e9fb6635Stb  *
30e9fb6635Stb  * Returns the integer square root of |n| in the caller-provided |out_sqrt|;
31e9fb6635Stb  * |*out_perfect| is set to 1 if and only if |n| is a perfect square.
 * At most one of |out_sqrt| and |out_perfect| may be NULL; |in_ctx| may be NULL.
33e9fb6635Stb  *
34e9fb6635Stb  * Returns 0 on error, 1 on success.
35e9fb6635Stb  *
36e9fb6635Stb  * Adapted from pure Python describing cpython's math.isqrt(), without bothering
37e9fb6635Stb  * with any of the optimizations in the C code. A correctness proof is here:
38e9fb6635Stb  * https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
39e9fb6635Stb  * The comments in the Python code also give a rather detailed proof.
40e9fb6635Stb  */
41e9fb6635Stb 
42e9fb6635Stb int
bn_isqrt(BIGNUM * out_sqrt,int * out_perfect,const BIGNUM * n,BN_CTX * in_ctx)43e9fb6635Stb bn_isqrt(BIGNUM *out_sqrt, int *out_perfect, const BIGNUM *n, BN_CTX *in_ctx)
44e9fb6635Stb {
45e9fb6635Stb 	BN_CTX *ctx = NULL;
46e9fb6635Stb 	BIGNUM *a, *b;
47e9fb6635Stb 	int c, d, e, s;
48e9fb6635Stb 	int cmp, perfect;
49e9fb6635Stb 	int ret = 0;
50e9fb6635Stb 
51e9fb6635Stb 	if (out_perfect == NULL && out_sqrt == NULL) {
52e9fb6635Stb 		BNerror(ERR_R_PASSED_NULL_PARAMETER);
53e9fb6635Stb 		goto err;
54e9fb6635Stb 	}
55e9fb6635Stb 
56e9fb6635Stb 	if (BN_is_negative(n)) {
57e9fb6635Stb 		BNerror(BN_R_INVALID_RANGE);
58e9fb6635Stb 		goto err;
59e9fb6635Stb 	}
60e9fb6635Stb 
61e9fb6635Stb 	if ((ctx = in_ctx) == NULL)
62e9fb6635Stb 		ctx = BN_CTX_new();
63e9fb6635Stb 	if (ctx == NULL)
64e9fb6635Stb 		goto err;
65e9fb6635Stb 
66e9fb6635Stb 	BN_CTX_start(ctx);
67e9fb6635Stb 
68e9fb6635Stb 	if ((a = BN_CTX_get(ctx)) == NULL)
69e9fb6635Stb 		goto err;
70e9fb6635Stb 	if ((b = BN_CTX_get(ctx)) == NULL)
71e9fb6635Stb 		goto err;
72e9fb6635Stb 
73e9fb6635Stb 	if (BN_is_zero(n)) {
74e9fb6635Stb 		perfect = 1;
75ce6fd31bSjsing 		BN_zero(a);
76e9fb6635Stb 		goto done;
77e9fb6635Stb 	}
78e9fb6635Stb 
79e9fb6635Stb 	if (!BN_one(a))
80e9fb6635Stb 		goto err;
81e9fb6635Stb 
82e9fb6635Stb 	c = (BN_num_bits(n) - 1) / 2;
83e9fb6635Stb 	d = 0;
84e9fb6635Stb 
85e9fb6635Stb 	/* Calculate s = floor(log(c)). */
86e9fb6635Stb 	if (!BN_set_word(b, c))
87e9fb6635Stb 		goto err;
88e9fb6635Stb 	s = BN_num_bits(b) - 1;
89e9fb6635Stb 
90e9fb6635Stb 	/*
91e9fb6635Stb 	 * By definition, the loop below is run <= floor(log(log(n))) times.
92e9fb6635Stb 	 * Comments in the cpython code establish the loop invariant that
93e9fb6635Stb 	 *
94e9fb6635Stb 	 *	(a - 1)^2 < n / 4^(c - d) < (a + 1)^2
95e9fb6635Stb 	 *
96e9fb6635Stb 	 * holds true in every iteration. Once this is proved via induction,
97e9fb6635Stb 	 * correctness of the algorithm is easy.
98e9fb6635Stb 	 *
99e9fb6635Stb 	 * Roughly speaking, A = (a << (d - e)) is used for one Newton step
100e9fb6635Stb 	 * "a = (A >> 1) + (m >> 1) / A" approximating m = (n >> 2 * (c - d)).
101e9fb6635Stb 	 */
102e9fb6635Stb 
103e9fb6635Stb 	for (; s >= 0; s--) {
104e9fb6635Stb 		e = d;
105e9fb6635Stb 		d = c >> s;
106e9fb6635Stb 
107e9fb6635Stb 		if (!BN_rshift(b, n, 2 * c - d - e + 1))
108e9fb6635Stb 			goto err;
109e9fb6635Stb 
110e9fb6635Stb 		if (!BN_div_ct(b, NULL, b, a, ctx))
111e9fb6635Stb 			goto err;
112e9fb6635Stb 
113e9fb6635Stb 		if (!BN_lshift(a, a, d - e - 1))
114e9fb6635Stb 			goto err;
115e9fb6635Stb 
116e9fb6635Stb 		if (!BN_add(a, a, b))
117e9fb6635Stb 			goto err;
118e9fb6635Stb 	}
119e9fb6635Stb 
120e9fb6635Stb 	/*
121e9fb6635Stb 	 * The loop invariant implies that either a or a - 1 is isqrt(n).
122e9fb6635Stb 	 * Figure out which one it is. The invariant also implies that for
123e9fb6635Stb 	 * a perfect square n, a must be the square root.
124e9fb6635Stb 	 */
125e9fb6635Stb 
126e9fb6635Stb 	if (!BN_sqr(b, a, ctx))
127e9fb6635Stb 		goto err;
128e9fb6635Stb 
129e9fb6635Stb 	/* If a^2 > n, we must have isqrt(n) == a - 1. */
130e9fb6635Stb 	if ((cmp = BN_cmp(b, n)) > 0) {
131e9fb6635Stb 		if (!BN_sub_word(a, 1))
132e9fb6635Stb 			goto err;
133e9fb6635Stb 	}
134e9fb6635Stb 
135e9fb6635Stb 	perfect = cmp == 0;
136e9fb6635Stb 
137e9fb6635Stb  done:
138e9fb6635Stb 	if (out_perfect != NULL)
139e9fb6635Stb 		*out_perfect = perfect;
140e9fb6635Stb 
141e9fb6635Stb 	if (out_sqrt != NULL) {
142cef8f41dStb 		if (!bn_copy(out_sqrt, a))
143e9fb6635Stb 			goto err;
144e9fb6635Stb 	}
145e9fb6635Stb 
146e9fb6635Stb 	ret = 1;
147e9fb6635Stb 
148e9fb6635Stb  err:
149e9fb6635Stb 	BN_CTX_end(ctx);
150e9fb6635Stb 
151e9fb6635Stb 	if (ctx != in_ctx)
152e9fb6635Stb 		BN_CTX_free(ctx);
153e9fb6635Stb 
154e9fb6635Stb 	return ret;
155e9fb6635Stb }
156e9fb6635Stb 
157e9fb6635Stb /*
158e9fb6635Stb  * is_square_mod_N[r % N] indicates whether r % N has a square root modulo N.
1593422461aStb  * The tables are generated in regress/lib/libcrypto/bn/bn_isqrt.c.
160e9fb6635Stb  */
161e9fb6635Stb 
1625299dc9fStb const uint8_t is_square_mod_11[] = {
163e9fb6635Stb 	1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
164e9fb6635Stb };
165e9fb6635Stb CTASSERT(sizeof(is_square_mod_11) == 11);
166e9fb6635Stb 
1675299dc9fStb const uint8_t is_square_mod_63[] = {
168e9fb6635Stb 	1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
169e9fb6635Stb 	1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,
170e9fb6635Stb 	0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
171e9fb6635Stb 	0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
172e9fb6635Stb };
173e9fb6635Stb CTASSERT(sizeof(is_square_mod_63) == 63);
174e9fb6635Stb 
1755299dc9fStb const uint8_t is_square_mod_64[] = {
176e9fb6635Stb 	1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
177e9fb6635Stb 	1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
178e9fb6635Stb 	0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
179e9fb6635Stb 	0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
180e9fb6635Stb };
181e9fb6635Stb CTASSERT(sizeof(is_square_mod_64) == 64);
182e9fb6635Stb 
1835299dc9fStb const uint8_t is_square_mod_65[] = {
184e9fb6635Stb 	1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
185e9fb6635Stb 	1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0,
186e9fb6635Stb 	0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
187e9fb6635Stb 	0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
188e9fb6635Stb 	1,
189e9fb6635Stb };
190e9fb6635Stb CTASSERT(sizeof(is_square_mod_65) == 65);
191e9fb6635Stb 
192e9fb6635Stb /*
193e9fb6635Stb  * Determine whether n is a perfect square or not.
194e9fb6635Stb  *
195e9fb6635Stb  * Returns 1 on success and 0 on error. In case of success, |*out_perfect| is
196e9fb6635Stb  * set to 1 if and only if |n| is a perfect square.
197e9fb6635Stb  */
198e9fb6635Stb 
199e9fb6635Stb int
bn_is_perfect_square(int * out_perfect,const BIGNUM * n,BN_CTX * ctx)200e9fb6635Stb bn_is_perfect_square(int *out_perfect, const BIGNUM *n, BN_CTX *ctx)
201e9fb6635Stb {
202e9fb6635Stb 	BN_ULONG r;
203e9fb6635Stb 
204e9fb6635Stb 	*out_perfect = 0;
205e9fb6635Stb 
206e9fb6635Stb 	if (BN_is_negative(n))
207e9fb6635Stb 		return 1;
208e9fb6635Stb 
209e9fb6635Stb 	/*
210e9fb6635Stb 	 * Before performing an expensive bn_isqrt() operation, weed out many
211e9fb6635Stb 	 * obvious non-squares. See H. Cohen, "A course in computational
212e9fb6635Stb 	 * algebraic number theory", Algorithm 1.7.3.
213e9fb6635Stb 	 *
214e9fb6635Stb 	 * The idea is that a square remains a square when reduced modulo any
215e9fb6635Stb 	 * number. The moduli are chosen in such a way that a non-square has
216e9fb6635Stb 	 * probability < 1% of passing the four table lookups.
217e9fb6635Stb 	 */
218e9fb6635Stb 
219e9fb6635Stb 	/* n % 64 */
220e9fb6635Stb 	r = BN_lsw(n) & 0x3f;
221e9fb6635Stb 
222e9fb6635Stb 	if (!is_square_mod_64[r % 64])
223e9fb6635Stb 		return 1;
224e9fb6635Stb 
225e9fb6635Stb 	if ((r = BN_mod_word(n, 11 * 63 * 65)) == (BN_ULONG)-1)
226e9fb6635Stb 		return 0;
227e9fb6635Stb 
228e9fb6635Stb 	if (!is_square_mod_63[r % 63] ||
229e9fb6635Stb 	    !is_square_mod_65[r % 65] ||
230e9fb6635Stb 	    !is_square_mod_11[r % 11])
231e9fb6635Stb 		return 1;
232e9fb6635Stb 
233e9fb6635Stb 	return bn_isqrt(NULL, out_perfect, n, ctx);
234e9fb6635Stb }
235