1*a29d9d67Stb /* $OpenBSD: bn_isqrt.c,v 1.10 2023/06/04 17:28:35 tb Exp $ */
2e9fb6635Stb /*
3e9fb6635Stb * Copyright (c) 2022 Theo Buehler <tb@openbsd.org>
4e9fb6635Stb *
5e9fb6635Stb * Permission to use, copy, modify, and distribute this software for any
6e9fb6635Stb * purpose with or without fee is hereby granted, provided that the above
7e9fb6635Stb * copyright notice and this permission notice appear in all copies.
8e9fb6635Stb *
9e9fb6635Stb * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10e9fb6635Stb * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11e9fb6635Stb * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12e9fb6635Stb * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13e9fb6635Stb * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14e9fb6635Stb * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15e9fb6635Stb * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16e9fb6635Stb */
17e9fb6635Stb
18e9fb6635Stb #include <stddef.h>
19e9fb6635Stb #include <stdint.h>
20e9fb6635Stb
21e9fb6635Stb #include <openssl/bn.h>
22e9fb6635Stb #include <openssl/err.h>
23e9fb6635Stb
24c9675a23Stb #include "bn_local.h"
25*a29d9d67Stb #include "crypto_internal.h"
26e9fb6635Stb
27e9fb6635Stb /*
28e9fb6635Stb * Calculate integer square root of |n| using a variant of Newton's method.
29e9fb6635Stb *
30e9fb6635Stb * Returns the integer square root of |n| in the caller-provided |out_sqrt|;
31e9fb6635Stb * |*out_perfect| is set to 1 if and only if |n| is a perfect square.
32e9fb6635Stb * One of |out_sqrt| and |out_perfect| can be NULL; |in_ctx| can be NULL.
33e9fb6635Stb *
34e9fb6635Stb * Returns 0 on error, 1 on success.
35e9fb6635Stb *
36e9fb6635Stb * Adapted from pure Python describing cpython's math.isqrt(), without bothering
37e9fb6635Stb * with any of the optimizations in the C code. A correctness proof is here:
38e9fb6635Stb * https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean
39e9fb6635Stb * The comments in the Python code also give a rather detailed proof.
40e9fb6635Stb */
41e9fb6635Stb
42e9fb6635Stb int
bn_isqrt(BIGNUM * out_sqrt,int * out_perfect,const BIGNUM * n,BN_CTX * in_ctx)43e9fb6635Stb bn_isqrt(BIGNUM *out_sqrt, int *out_perfect, const BIGNUM *n, BN_CTX *in_ctx)
44e9fb6635Stb {
45e9fb6635Stb BN_CTX *ctx = NULL;
46e9fb6635Stb BIGNUM *a, *b;
47e9fb6635Stb int c, d, e, s;
48e9fb6635Stb int cmp, perfect;
49e9fb6635Stb int ret = 0;
50e9fb6635Stb
51e9fb6635Stb if (out_perfect == NULL && out_sqrt == NULL) {
52e9fb6635Stb BNerror(ERR_R_PASSED_NULL_PARAMETER);
53e9fb6635Stb goto err;
54e9fb6635Stb }
55e9fb6635Stb
56e9fb6635Stb if (BN_is_negative(n)) {
57e9fb6635Stb BNerror(BN_R_INVALID_RANGE);
58e9fb6635Stb goto err;
59e9fb6635Stb }
60e9fb6635Stb
61e9fb6635Stb if ((ctx = in_ctx) == NULL)
62e9fb6635Stb ctx = BN_CTX_new();
63e9fb6635Stb if (ctx == NULL)
64e9fb6635Stb goto err;
65e9fb6635Stb
66e9fb6635Stb BN_CTX_start(ctx);
67e9fb6635Stb
68e9fb6635Stb if ((a = BN_CTX_get(ctx)) == NULL)
69e9fb6635Stb goto err;
70e9fb6635Stb if ((b = BN_CTX_get(ctx)) == NULL)
71e9fb6635Stb goto err;
72e9fb6635Stb
73e9fb6635Stb if (BN_is_zero(n)) {
74e9fb6635Stb perfect = 1;
75ce6fd31bSjsing BN_zero(a);
76e9fb6635Stb goto done;
77e9fb6635Stb }
78e9fb6635Stb
79e9fb6635Stb if (!BN_one(a))
80e9fb6635Stb goto err;
81e9fb6635Stb
82e9fb6635Stb c = (BN_num_bits(n) - 1) / 2;
83e9fb6635Stb d = 0;
84e9fb6635Stb
85e9fb6635Stb /* Calculate s = floor(log(c)). */
86e9fb6635Stb if (!BN_set_word(b, c))
87e9fb6635Stb goto err;
88e9fb6635Stb s = BN_num_bits(b) - 1;
89e9fb6635Stb
90e9fb6635Stb /*
91e9fb6635Stb * By definition, the loop below is run <= floor(log(log(n))) times.
92e9fb6635Stb * Comments in the cpython code establish the loop invariant that
93e9fb6635Stb *
94e9fb6635Stb * (a - 1)^2 < n / 4^(c - d) < (a + 1)^2
95e9fb6635Stb *
96e9fb6635Stb * holds true in every iteration. Once this is proved via induction,
97e9fb6635Stb * correctness of the algorithm is easy.
98e9fb6635Stb *
99e9fb6635Stb * Roughly speaking, A = (a << (d - e)) is used for one Newton step
100e9fb6635Stb * "a = (A >> 1) + (m >> 1) / A" approximating m = (n >> 2 * (c - d)).
101e9fb6635Stb */
102e9fb6635Stb
103e9fb6635Stb for (; s >= 0; s--) {
104e9fb6635Stb e = d;
105e9fb6635Stb d = c >> s;
106e9fb6635Stb
107e9fb6635Stb if (!BN_rshift(b, n, 2 * c - d - e + 1))
108e9fb6635Stb goto err;
109e9fb6635Stb
110e9fb6635Stb if (!BN_div_ct(b, NULL, b, a, ctx))
111e9fb6635Stb goto err;
112e9fb6635Stb
113e9fb6635Stb if (!BN_lshift(a, a, d - e - 1))
114e9fb6635Stb goto err;
115e9fb6635Stb
116e9fb6635Stb if (!BN_add(a, a, b))
117e9fb6635Stb goto err;
118e9fb6635Stb }
119e9fb6635Stb
120e9fb6635Stb /*
121e9fb6635Stb * The loop invariant implies that either a or a - 1 is isqrt(n).
122e9fb6635Stb * Figure out which one it is. The invariant also implies that for
123e9fb6635Stb * a perfect square n, a must be the square root.
124e9fb6635Stb */
125e9fb6635Stb
126e9fb6635Stb if (!BN_sqr(b, a, ctx))
127e9fb6635Stb goto err;
128e9fb6635Stb
129e9fb6635Stb /* If a^2 > n, we must have isqrt(n) == a - 1. */
130e9fb6635Stb if ((cmp = BN_cmp(b, n)) > 0) {
131e9fb6635Stb if (!BN_sub_word(a, 1))
132e9fb6635Stb goto err;
133e9fb6635Stb }
134e9fb6635Stb
135e9fb6635Stb perfect = cmp == 0;
136e9fb6635Stb
137e9fb6635Stb done:
138e9fb6635Stb if (out_perfect != NULL)
139e9fb6635Stb *out_perfect = perfect;
140e9fb6635Stb
141e9fb6635Stb if (out_sqrt != NULL) {
142cef8f41dStb if (!bn_copy(out_sqrt, a))
143e9fb6635Stb goto err;
144e9fb6635Stb }
145e9fb6635Stb
146e9fb6635Stb ret = 1;
147e9fb6635Stb
148e9fb6635Stb err:
149e9fb6635Stb BN_CTX_end(ctx);
150e9fb6635Stb
151e9fb6635Stb if (ctx != in_ctx)
152e9fb6635Stb BN_CTX_free(ctx);
153e9fb6635Stb
154e9fb6635Stb return ret;
155e9fb6635Stb }
156e9fb6635Stb
/*
 * is_square_mod_N[r % N] is 1 if and only if r is a square modulo N.
 * The tables are generated in regress/lib/libcrypto/bn/bn_isqrt.c.
 */
161e9fb6635Stb
1625299dc9fStb const uint8_t is_square_mod_11[] = {
163e9fb6635Stb 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
164e9fb6635Stb };
165e9fb6635Stb CTASSERT(sizeof(is_square_mod_11) == 11);
166e9fb6635Stb
1675299dc9fStb const uint8_t is_square_mod_63[] = {
168e9fb6635Stb 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0,
169e9fb6635Stb 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,
170e9fb6635Stb 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
171e9fb6635Stb 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
172e9fb6635Stb };
173e9fb6635Stb CTASSERT(sizeof(is_square_mod_63) == 63);
174e9fb6635Stb
1755299dc9fStb const uint8_t is_square_mod_64[] = {
176e9fb6635Stb 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
177e9fb6635Stb 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
178e9fb6635Stb 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
179e9fb6635Stb 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
180e9fb6635Stb };
181e9fb6635Stb CTASSERT(sizeof(is_square_mod_64) == 64);
182e9fb6635Stb
1835299dc9fStb const uint8_t is_square_mod_65[] = {
184e9fb6635Stb 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0,
185e9fb6635Stb 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0,
186e9fb6635Stb 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
187e9fb6635Stb 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
188e9fb6635Stb 1,
189e9fb6635Stb };
190e9fb6635Stb CTASSERT(sizeof(is_square_mod_65) == 65);
191e9fb6635Stb
192e9fb6635Stb /*
193e9fb6635Stb * Determine whether n is a perfect square or not.
194e9fb6635Stb *
195e9fb6635Stb * Returns 1 on success and 0 on error. In case of success, |*out_perfect| is
196e9fb6635Stb * set to 1 if and only if |n| is a perfect square.
197e9fb6635Stb */
198e9fb6635Stb
199e9fb6635Stb int
bn_is_perfect_square(int * out_perfect,const BIGNUM * n,BN_CTX * ctx)200e9fb6635Stb bn_is_perfect_square(int *out_perfect, const BIGNUM *n, BN_CTX *ctx)
201e9fb6635Stb {
202e9fb6635Stb BN_ULONG r;
203e9fb6635Stb
204e9fb6635Stb *out_perfect = 0;
205e9fb6635Stb
206e9fb6635Stb if (BN_is_negative(n))
207e9fb6635Stb return 1;
208e9fb6635Stb
209e9fb6635Stb /*
210e9fb6635Stb * Before performing an expensive bn_isqrt() operation, weed out many
211e9fb6635Stb * obvious non-squares. See H. Cohen, "A course in computational
212e9fb6635Stb * algebraic number theory", Algorithm 1.7.3.
213e9fb6635Stb *
214e9fb6635Stb * The idea is that a square remains a square when reduced modulo any
215e9fb6635Stb * number. The moduli are chosen in such a way that a non-square has
216e9fb6635Stb * probability < 1% of passing the four table lookups.
217e9fb6635Stb */
218e9fb6635Stb
219e9fb6635Stb /* n % 64 */
220e9fb6635Stb r = BN_lsw(n) & 0x3f;
221e9fb6635Stb
222e9fb6635Stb if (!is_square_mod_64[r % 64])
223e9fb6635Stb return 1;
224e9fb6635Stb
225e9fb6635Stb if ((r = BN_mod_word(n, 11 * 63 * 65)) == (BN_ULONG)-1)
226e9fb6635Stb return 0;
227e9fb6635Stb
228e9fb6635Stb if (!is_square_mod_63[r % 63] ||
229e9fb6635Stb !is_square_mod_65[r % 65] ||
230e9fb6635Stb !is_square_mod_11[r % 11])
231e9fb6635Stb return 1;
232e9fb6635Stb
233e9fb6635Stb return bn_isqrt(NULL, out_perfect, n, ctx);
234e9fb6635Stb }
235