xref: /openbsd/lib/libcrypto/bn/bn_isqrt.c (revision 3c6df8cf)
1 /*	$OpenBSD: bn_isqrt.c,v 1.9 2023/05/19 00:54:28 deraadt Exp $ */
2 /*
3  * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <stddef.h>
19 #include <stdint.h>
20 
21 #include <openssl/bn.h>
22 #include <openssl/err.h>
23 
24 #include "bn_local.h"
25 
/*
 * Compile-time assertion. Expands to a declaration of an extern char array
 * whose size is 1 when |x| is true and -1 when it is false, so a false |x|
 * makes the declaration ill-formed and the build fails. |x| must be an
 * integer constant expression. The array is only declared, never used.
 */
#define CTASSERT(x)	extern char _ctassert[!(x) ? -1 : 1]	\
			    __attribute__((__unused__))
28 
29 /*
30  * Calculate integer square root of |n| using a variant of Newton's method.
31  *
32  * Returns the integer square root of |n| in the caller-provided |out_sqrt|;
33  * |*out_perfect| is set to 1 if and only if |n| is a perfect square.
34  * One of |out_sqrt| and |out_perfect| can be NULL; |in_ctx| can be NULL.
35  *
36  * Returns 0 on error, 1 on success.
37  *
38  * Adapted from pure Python describing cpython's math.isqrt(), without bothering
39  * with any of the optimizations in the C code. A correctness proof is here:
40  * https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
41  * The comments in the Python code also give a rather detailed proof.
42  */
43 
44 int
45 bn_isqrt(BIGNUM *out_sqrt, int *out_perfect, const BIGNUM *n, BN_CTX *in_ctx)
46 {
47 	BN_CTX *ctx = NULL;
48 	BIGNUM *a, *b;
49 	int c, d, e, s;
50 	int cmp, perfect;
51 	int ret = 0;
52 
53 	if (out_perfect == NULL && out_sqrt == NULL) {
54 		BNerror(ERR_R_PASSED_NULL_PARAMETER);
55 		goto err;
56 	}
57 
58 	if (BN_is_negative(n)) {
59 		BNerror(BN_R_INVALID_RANGE);
60 		goto err;
61 	}
62 
63 	if ((ctx = in_ctx) == NULL)
64 		ctx = BN_CTX_new();
65 	if (ctx == NULL)
66 		goto err;
67 
68 	BN_CTX_start(ctx);
69 
70 	if ((a = BN_CTX_get(ctx)) == NULL)
71 		goto err;
72 	if ((b = BN_CTX_get(ctx)) == NULL)
73 		goto err;
74 
75 	if (BN_is_zero(n)) {
76 		perfect = 1;
77 		BN_zero(a);
78 		goto done;
79 	}
80 
81 	if (!BN_one(a))
82 		goto err;
83 
84 	c = (BN_num_bits(n) - 1) / 2;
85 	d = 0;
86 
87 	/* Calculate s = floor(log(c)). */
88 	if (!BN_set_word(b, c))
89 		goto err;
90 	s = BN_num_bits(b) - 1;
91 
92 	/*
93 	 * By definition, the loop below is run <= floor(log(log(n))) times.
94 	 * Comments in the cpython code establish the loop invariant that
95 	 *
96 	 *	(a - 1)^2 < n / 4^(c - d) < (a + 1)^2
97 	 *
98 	 * holds true in every iteration. Once this is proved via induction,
99 	 * correctness of the algorithm is easy.
100 	 *
101 	 * Roughly speaking, A = (a << (d - e)) is used for one Newton step
102 	 * "a = (A >> 1) + (m >> 1) / A" approximating m = (n >> 2 * (c - d)).
103 	 */
104 
105 	for (; s >= 0; s--) {
106 		e = d;
107 		d = c >> s;
108 
109 		if (!BN_rshift(b, n, 2 * c - d - e + 1))
110 			goto err;
111 
112 		if (!BN_div_ct(b, NULL, b, a, ctx))
113 			goto err;
114 
115 		if (!BN_lshift(a, a, d - e - 1))
116 			goto err;
117 
118 		if (!BN_add(a, a, b))
119 			goto err;
120 	}
121 
122 	/*
123 	 * The loop invariant implies that either a or a - 1 is isqrt(n).
124 	 * Figure out which one it is. The invariant also implies that for
125 	 * a perfect square n, a must be the square root.
126 	 */
127 
128 	if (!BN_sqr(b, a, ctx))
129 		goto err;
130 
131 	/* If a^2 > n, we must have isqrt(n) == a - 1. */
132 	if ((cmp = BN_cmp(b, n)) > 0) {
133 		if (!BN_sub_word(a, 1))
134 			goto err;
135 	}
136 
137 	perfect = cmp == 0;
138 
139  done:
140 	if (out_perfect != NULL)
141 		*out_perfect = perfect;
142 
143 	if (out_sqrt != NULL) {
144 		if (!bn_copy(out_sqrt, a))
145 			goto err;
146 	}
147 
148 	ret = 1;
149 
150  err:
151 	BN_CTX_end(ctx);
152 
153 	if (ctx != in_ctx)
154 		BN_CTX_free(ctx);
155 
156 	return ret;
157 }
158 
159 /*
160  * is_square_mod_N[r % N] indicates whether r % N has a square root modulo N.
161  * The tables are generated in regress/lib/libcrypto/bn/bn_isqrt.c.
162  */
163 
164 const uint8_t is_square_mod_11[] = {
165 	1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
166 };
167 CTASSERT(sizeof(is_square_mod_11) == 11);
168 
169 const uint8_t is_square_mod_63[] = {
170 	1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
171 	1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,
172 	0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
173 	0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
174 };
175 CTASSERT(sizeof(is_square_mod_63) == 63);
176 
177 const uint8_t is_square_mod_64[] = {
178 	1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
179 	1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
180 	0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
181 	0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
182 };
183 CTASSERT(sizeof(is_square_mod_64) == 64);
184 
185 const uint8_t is_square_mod_65[] = {
186 	1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
187 	1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0,
188 	0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
189 	0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
190 	1,
191 };
192 CTASSERT(sizeof(is_square_mod_65) == 65);
193 
194 /*
195  * Determine whether n is a perfect square or not.
196  *
197  * Returns 1 on success and 0 on error. In case of success, |*out_perfect| is
198  * set to 1 if and only if |n| is a perfect square.
199  */
200 
201 int
202 bn_is_perfect_square(int *out_perfect, const BIGNUM *n, BN_CTX *ctx)
203 {
204 	BN_ULONG r;
205 
206 	*out_perfect = 0;
207 
208 	if (BN_is_negative(n))
209 		return 1;
210 
211 	/*
212 	 * Before performing an expensive bn_isqrt() operation, weed out many
213 	 * obvious non-squares. See H. Cohen, "A course in computational
214 	 * algebraic number theory", Algorithm 1.7.3.
215 	 *
216 	 * The idea is that a square remains a square when reduced modulo any
217 	 * number. The moduli are chosen in such a way that a non-square has
218 	 * probability < 1% of passing the four table lookups.
219 	 */
220 
221 	/* n % 64 */
222 	r = BN_lsw(n) & 0x3f;
223 
224 	if (!is_square_mod_64[r % 64])
225 		return 1;
226 
227 	if ((r = BN_mod_word(n, 11 * 63 * 65)) == (BN_ULONG)-1)
228 		return 0;
229 
230 	if (!is_square_mod_63[r % 63] ||
231 	    !is_square_mod_65[r % 65] ||
232 	    !is_square_mod_11[r % 11])
233 		return 1;
234 
235 	return bn_isqrt(NULL, out_perfect, n, ctx);
236 }
237