1 /* $OpenBSD: bn_isqrt.c,v 1.9 2023/05/19 00:54:28 deraadt Exp $ */ 2 /* 3 * Copyright (c) 2022 Theo Buehler <tb@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <stddef.h> 19 #include <stdint.h> 20 21 #include <openssl/bn.h> 22 #include <openssl/err.h> 23 24 #include "bn_local.h" 25 26 #define CTASSERT(x) extern char _ctassert[(x) ? 1 : -1 ] \ 27 __attribute__((__unused__)) 28 29 /* 30 * Calculate integer square root of |n| using a variant of Newton's method. 31 * 32 * Returns the integer square root of |n| in the caller-provided |out_sqrt|; 33 * |*out_perfect| is set to 1 if and only if |n| is a perfect square. 34 * One of |out_sqrt| and |out_perfect| can be NULL; |in_ctx| can be NULL. 35 * 36 * Returns 0 on error, 1 on success. 37 * 38 * Adapted from pure Python describing cpython's math.isqrt(), without bothering 39 * with any of the optimizations in the C code. A correctness proof is here: 40 * https://github.com/mdickinson/snippets/blob/master/proofs/isqrt/src/isqrt.lean 41 * The comments in the Python code also give a rather detailed proof. 42 */ 43 44 int 45 bn_isqrt(BIGNUM *out_sqrt, int *out_perfect, const BIGNUM *n, BN_CTX *in_ctx) 46 { 47 BN_CTX *ctx = NULL; 48 BIGNUM *a, *b; 49 int c, d, e, s; 50 int cmp, perfect; 51 int ret = 0; 52 53 if (out_perfect == NULL && out_sqrt == NULL) { 54 BNerror(ERR_R_PASSED_NULL_PARAMETER); 55 goto err; 56 } 57 58 if (BN_is_negative(n)) { 59 BNerror(BN_R_INVALID_RANGE); 60 goto err; 61 } 62 63 if ((ctx = in_ctx) == NULL) 64 ctx = BN_CTX_new(); 65 if (ctx == NULL) 66 goto err; 67 68 BN_CTX_start(ctx); 69 70 if ((a = BN_CTX_get(ctx)) == NULL) 71 goto err; 72 if ((b = BN_CTX_get(ctx)) == NULL) 73 goto err; 74 75 if (BN_is_zero(n)) { 76 perfect = 1; 77 BN_zero(a); 78 goto done; 79 } 80 81 if (!BN_one(a)) 82 goto err; 83 84 c = (BN_num_bits(n) - 1) / 2; 85 d = 0; 86 87 /* Calculate s = floor(log(c)). */ 88 if (!BN_set_word(b, c)) 89 goto err; 90 s = BN_num_bits(b) - 1; 91 92 /* 93 * By definition, the loop below is run <= floor(log(log(n))) times. 94 * Comments in the cpython code establish the loop invariant that 95 * 96 * (a - 1)^2 < n / 4^(c - d) < (a + 1)^2 97 * 98 * holds true in every iteration. Once this is proved via induction, 99 * correctness of the algorithm is easy. 100 * 101 * Roughly speaking, A = (a << (d - e)) is used for one Newton step 102 * "a = (A >> 1) + (m >> 1) / A" approximating m = (n >> 2 * (c - d)). 103 */ 104 105 for (; s >= 0; s--) { 106 e = d; 107 d = c >> s; 108 109 if (!BN_rshift(b, n, 2 * c - d - e + 1)) 110 goto err; 111 112 if (!BN_div_ct(b, NULL, b, a, ctx)) 113 goto err; 114 115 if (!BN_lshift(a, a, d - e - 1)) 116 goto err; 117 118 if (!BN_add(a, a, b)) 119 goto err; 120 } 121 122 /* 123 * The loop invariant implies that either a or a - 1 is isqrt(n). 124 * Figure out which one it is. The invariant also implies that for 125 * a perfect square n, a must be the square root. 126 */ 127 128 if (!BN_sqr(b, a, ctx)) 129 goto err; 130 131 /* If a^2 > n, we must have isqrt(n) == a - 1. */ 132 if ((cmp = BN_cmp(b, n)) > 0) { 133 if (!BN_sub_word(a, 1)) 134 goto err; 135 } 136 137 perfect = cmp == 0; 138 139 done: 140 if (out_perfect != NULL) 141 *out_perfect = perfect; 142 143 if (out_sqrt != NULL) { 144 if (!bn_copy(out_sqrt, a)) 145 goto err; 146 } 147 148 ret = 1; 149 150 err: 151 BN_CTX_end(ctx); 152 153 if (ctx != in_ctx) 154 BN_CTX_free(ctx); 155 156 return ret; 157 } 158 159 /* 160 * is_square_mod_N[r % N] indicates whether r % N has a square root modulo N. 161 * The tables are generated in regress/lib/libcrypto/bn/bn_isqrt.c. 162 */ 163 164 const uint8_t is_square_mod_11[] = { 165 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 166 }; 167 CTASSERT(sizeof(is_square_mod_11) == 11); 168 169 const uint8_t is_square_mod_63[] = { 170 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 171 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 172 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 173 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 174 }; 175 CTASSERT(sizeof(is_square_mod_63) == 63); 176 177 const uint8_t is_square_mod_64[] = { 178 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 179 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 180 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 181 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 182 }; 183 CTASSERT(sizeof(is_square_mod_64) == 64); 184 185 const uint8_t is_square_mod_65[] = { 186 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 187 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 188 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 189 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 190 1, 191 }; 192 CTASSERT(sizeof(is_square_mod_65) == 65); 193 194 /* 195 * Determine whether n is a perfect square or not. 196 * 197 * Returns 1 on success and 0 on error. In case of success, |*out_perfect| is 198 * set to 1 if and only if |n| is a perfect square. 199 */ 200 201 int 202 bn_is_perfect_square(int *out_perfect, const BIGNUM *n, BN_CTX *ctx) 203 { 204 BN_ULONG r; 205 206 *out_perfect = 0; 207 208 if (BN_is_negative(n)) 209 return 1; 210 211 /* 212 * Before performing an expensive bn_isqrt() operation, weed out many 213 * obvious non-squares. See H. Cohen, "A course in computational 214 * algebraic number theory", Algorithm 1.7.3. 215 * 216 * The idea is that a square remains a square when reduced modulo any 217 * number. The moduli are chosen in such a way that a non-square has 218 * probability < 1% of passing the four table lookups. 219 */ 220 221 /* n % 64 */ 222 r = BN_lsw(n) & 0x3f; 223 224 if (!is_square_mod_64[r % 64]) 225 return 1; 226 227 if ((r = BN_mod_word(n, 11 * 63 * 65)) == (BN_ULONG)-1) 228 return 0; 229 230 if (!is_square_mod_63[r % 63] || 231 !is_square_mod_65[r % 65] || 232 !is_square_mod_11[r % 11]) 233 return 1; 234 235 return bn_isqrt(NULL, out_perfect, n, ctx); 236 } 237