/*
 * Tiny arbitrary precision floating point library
 *
 * Copyright (c) 2017-2020 Fabrice Bellard
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>
#include <math.h>
#include <string.h>
#include <assert.h>

#ifdef __AVX2__
#include <immintrin.h>
#endif

#include "cutils.h"
#include "libbf.h"

/* enable it to check the multiplication result */
//#define USE_MUL_CHECK
/* enable it to use FFT/NTT multiplication */
#define USE_FFT_MUL
/* enable decimal floating point support */
#define USE_BF_DEC

//#define inline __attribute__((always_inline))

#ifdef __AVX2__
#define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
#else
#define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
#endif

/* XXX: adjust */
#define DIVNORM_LARGE_THRESHOLD 50
#define UDIV1NORM_THRESHOLD 3

#if LIMB_BITS == 64
#define FMT_LIMB1 "%" PRIx64
#define FMT_LIMB "%016" PRIx64
#define PRId_LIMB PRId64
#define PRIu_LIMB PRIu64

#else

#define FMT_LIMB1 "%x"
#define FMT_LIMB "%08x"
#define PRId_LIMB "d"
#define PRIu_LIMB "u"

#endif

typedef intptr_t mp_size_t;

typedef int bf_op2_func_t(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                          bf_flags_t flags);

#ifdef USE_FFT_MUL

#define FFT_MUL_R_OVERLAP_A (1 << 0)
#define FFT_MUL_R_OVERLAP_B (1 << 1)
#define FFT_MUL_R_NORESIZE  (1 << 2)

static no_inline int fft_mul(bf_context_t *s,
                             bf_t *res, limb_t *a_tab, limb_t a_len,
                             limb_t *b_tab, limb_t b_len, int mul_flags);
static void fft_clear_cache(bf_context_t *s);
#endif
#ifdef USE_BF_DEC
static void mp_pow_init(void);
static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos);
#endif

/* count leading zeros */
static inline int clz(limb_t a)
{
    if (a == 0) {
        return LIMB_BITS;
    } else {
#if LIMB_BITS == 64
        return clz64(a);
#else
        return clz32(a);
#endif
    }
}

static inline int ctz(limb_t a)
{
    if (a == 0) {
        return LIMB_BITS;
    } else {
#if LIMB_BITS == 64
        return ctz64(a);
#else
        return ctz32(a);
#endif
    }
}

static inline int ceil_log2(limb_t a)
{
    if (a <= 1)
        return 0;
    else
        return LIMB_BITS - clz(a - 1);
}
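
/* Example (illustrative): ceil_log2(5) = 3 since 2^3 = 8 is the first
   power of two >= 5; ceil_log2(8) = 3 as well. */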

#if 0 //unused
/* b must be >= 1 */
static inline slimb_t ceil_div(slimb_t a, slimb_t b)
{
    if (a >= 0)
        return (a + b - 1) / b;
    else
        return a / b;
}
#endif

/* b must be >= 1 */
static inline slimb_t floor_div(slimb_t a, slimb_t b)
{
    if (a >= 0) {
        return a / b;
    } else {
        return (a - b + 1) / b;
    }
}

/* return r = a modulo b (0 <= r <= b - 1). b must be >= 1 */
static inline limb_t smod(slimb_t a, slimb_t b)
{
    a = a % (slimb_t)b;
    if (a < 0)
        a += b;
    return a;
}
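
/* Example (illustrative): floor_div(-7, 3) = -3 and smod(-7, 3) = 2,
   consistent with -7 = (-3) * 3 + 2. */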

#define malloc(s) malloc_is_forbidden(s)
#define free(p) free_is_forbidden(p)
#define realloc(p, s) realloc_is_forbidden(p, s)

void bf_context_init(bf_context_t *s, bf_realloc_func_t *realloc_func,
                     void *realloc_opaque)
{
    memset(s, 0, sizeof(*s));
    s->realloc_func = realloc_func;
    s->realloc_opaque = realloc_opaque;
#ifdef USE_BF_DEC
    mp_pow_init();
#endif
}

void bf_context_end(bf_context_t *s)
{
    bf_clear_cache(s);
}

void bf_init(bf_context_t *s, bf_t *r)
{
    r->ctx = s;
    r->sign = 0;
    r->expn = BF_EXP_ZERO;
    r->len = 0;
    r->tab = NULL;
}

/* return 0 if OK, -1 if alloc error */
int bf_resize(bf_t *r, limb_t len)
{
    limb_t *tab;

    if (len != r->len) {
        tab = bf_realloc(r->ctx, r->tab, len * sizeof(limb_t));
        if (!tab && len != 0)
            return -1;
        r->tab = tab;
        r->len = len;
    }
    return 0;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set_ui(bf_t *r, uint64_t a)
{
    r->sign = 0;
    if (a == 0) {
        r->expn = BF_EXP_ZERO;
        bf_resize(r, 0); /* cannot fail */
    }
#if LIMB_BITS == 32
    else if (a <= 0xffffffff)
#else
    else
#endif
    {
        int shift;
        if (bf_resize(r, 1))
            goto fail;
        shift = clz(a);
        r->tab[0] = a << shift;
        r->expn = LIMB_BITS - shift;
    }
#if LIMB_BITS == 32
    else {
        uint32_t a1, a0;
        int shift;
        if (bf_resize(r, 2))
            goto fail;
        a0 = a;
        a1 = a >> 32;
        shift = clz(a1);
        r->tab[0] = a0 << shift;
        if (shift == 0)
            r->tab[1] = a1; /* avoid the undefined shift a0 >> LIMB_BITS */
        else
            r->tab[1] = (a1 << shift) | (a0 >> (LIMB_BITS - shift));
        r->expn = 2 * LIMB_BITS - shift;
    }
#endif
    return 0;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set_si(bf_t *r, int64_t a)
{
    int ret;

    if (a < 0) {
        ret = bf_set_ui(r, -a);
        r->sign = 1;
    } else {
        ret = bf_set_ui(r, a);
    }
    return ret;
}

void bf_set_nan(bf_t *r)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_NAN;
    r->sign = 0;
}

void bf_set_zero(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_ZERO;
    r->sign = is_neg;
}

void bf_set_inf(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_INF;
    r->sign = is_neg;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set(bf_t *r, const bf_t *a)
{
    if (r == a)
        return 0;
    if (bf_resize(r, a->len)) {
        bf_set_nan(r);
        return BF_ST_MEM_ERROR;
    }
    r->sign = a->sign;
    r->expn = a->expn;
    memcpy(r->tab, a->tab, a->len * sizeof(limb_t));
    return 0;
}

/* equivalent to bf_set(r, a); bf_delete(a) */
void bf_move(bf_t *r, bf_t *a)
{
    bf_context_t *s = r->ctx;
    if (r == a)
        return;
    bf_free(s, r->tab);
    *r = *a;
}

static limb_t get_limbz(const bf_t *a, limb_t idx)
{
    if (idx >= a->len)
        return 0;
    else
        return a->tab[idx];
}

/* get LIMB_BITS at bit position 'pos' in tab */
static inline limb_t get_bits(const limb_t *tab, limb_t len, slimb_t pos)
{
    limb_t i, a0, a1;
    int p;

    i = pos >> LIMB_LOG2_BITS;
    p = pos & (LIMB_BITS - 1);
    if (i < len)
        a0 = tab[i];
    else
        a0 = 0;
    if (p == 0) {
        return a0;
    } else {
        i++;
        if (i < len)
            a1 = tab[i];
        else
            a1 = 0;
        return (a0 >> p) | (a1 << (LIMB_BITS - p));
    }
}

static inline limb_t get_bit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t i;
    i = pos >> LIMB_LOG2_BITS;
    if (i < 0 || i >= len)
        return 0;
    return (tab[i] >> (pos & (LIMB_BITS - 1))) & 1;
}

static inline limb_t limb_mask(int start, int last)
{
    limb_t v;
    int n;
    n = last - start + 1;
    if (n == LIMB_BITS)
        v = -1;
    else
        v = (((limb_t)1 << n) - 1) << start;
    return v;
}
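
/* Example (illustrative): limb_mask(4, 7) = 0xf0, i.e. bits 4..7 set. */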

static limb_t mp_scan_nz(const limb_t *tab, mp_size_t n)
{
    mp_size_t i;
    for(i = 0; i < n; i++) {
        if (tab[i] != 0)
            return 1;
    }
    return 0;
}

/* return != 0 if one bit between 0 and bit_pos inclusive is not zero. */
static inline limb_t scan_bit_nz(const bf_t *r, slimb_t bit_pos)
{
    slimb_t pos;
    limb_t v;

    pos = bit_pos >> LIMB_LOG2_BITS;
    if (pos < 0)
        return 0;
    v = r->tab[pos] & limb_mask(0, bit_pos & (LIMB_BITS - 1));
    if (v != 0)
        return 1;
    pos--;
    while (pos >= 0) {
        if (r->tab[pos] != 0)
            return 1;
        pos--;
    }
    return 0;
}

/* return the addend for rounding. Note that prec can be <= 0 (for
   BF_FLAG_RADPNT_PREC) */
static int bf_get_rnd_add(int *pret, const bf_t *r, limb_t l,
                          slimb_t prec, int rnd_mode)
{
    int add_one, inexact;
    limb_t bit1, bit0;

    if (rnd_mode == BF_RNDF) {
        bit0 = 1; /* faithful rounding does not honor the INEXACT flag */
    } else {
        /* starting limb for bit 'prec + 1' */
        bit0 = scan_bit_nz(r, l * LIMB_BITS - 1 - bf_max(0, prec + 1));
    }

    /* get the bit at 'prec' */
    bit1 = get_bit(r->tab, l, l * LIMB_BITS - 1 - prec);
    inexact = (bit1 | bit0) != 0;

    add_one = 0;
    switch(rnd_mode) {
    case BF_RNDZ:
        break;
    case BF_RNDN:
        if (bit1) {
            if (bit0) {
                add_one = 1;
            } else {
                /* round to even */
                add_one =
                    get_bit(r->tab, l, l * LIMB_BITS - 1 - (prec - 1));
            }
        }
        break;
    case BF_RNDD:
    case BF_RNDU:
        if (r->sign == (rnd_mode == BF_RNDD))
            add_one = inexact;
        break;
    case BF_RNDNA:
    case BF_RNDF:
        add_one = bit1;
        break;
    case BF_RNDNU:
        if (bit1) {
            if (r->sign)
                add_one = bit0;
            else
                add_one = 1;
        }
        break;
    default:
        abort();
    }

    if (inexact)
        *pret |= BF_ST_INEXACT;
    return add_one;
}
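
/* Worked example (illustrative): rounding the positive 4-bit mantissa
   0b1011 to prec = 2 bits: bit1 (the bit at 'prec') is 1 and bit0 (the
   OR of all lower bits) is 1, so the result is inexact; BF_RNDN and
   BF_RNDU round up to 0b11, while BF_RNDZ and BF_RNDD truncate to 0b10. */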

static int bf_set_overflow(bf_t *r, int sign, limb_t prec, bf_flags_t flags)
{
    slimb_t i, l, e_max;
    int rnd_mode;

    rnd_mode = flags & BF_RND_MASK;
    if (prec == BF_PREC_INF ||
        rnd_mode == BF_RNDN ||
        rnd_mode == BF_RNDNA ||
        rnd_mode == BF_RNDNU ||
        (rnd_mode == BF_RNDD && sign == 1) ||
        (rnd_mode == BF_RNDU && sign == 0)) {
        bf_set_inf(r, sign);
    } else {
        /* set to maximum finite number */
        l = (prec + LIMB_BITS - 1) / LIMB_BITS;
        if (bf_resize(r, l)) {
            bf_set_nan(r);
            return BF_ST_MEM_ERROR;
        }
        r->tab[0] = limb_mask((-prec) & (LIMB_BITS - 1),
                              LIMB_BITS - 1);
        for(i = 1; i < l; i++)
            r->tab[i] = (limb_t)-1;
        e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
        r->expn = e_max;
        r->sign = sign;
    }
    return BF_ST_OVERFLOW | BF_ST_INEXACT;
}

/* round to prec1 bits assuming 'r' is non zero and finite. 'r' is
   assumed to have length 'l' (1 <= l <= r->len). Note: 'prec1' can be
   infinite (BF_PREC_INF). Can fail with BF_ST_MEM_ERROR in case of
   overflow not returning infinity. */
static int __bf_round(bf_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
{
    limb_t v, a;
    int shift, add_one, ret, rnd_mode;
    slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;

    /* e_min and e_max are computed to match the IEEE 754 conventions */
    e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    e_min = -e_range + 3;
    e_max = e_range;

    if (flags & BF_FLAG_RADPNT_PREC) {
        /* 'prec' is the precision after the radix point */
        if (prec1 != BF_PREC_INF)
            prec = r->expn + prec1;
        else
            prec = prec1;
    } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
        /* restrict the precision in case of potentially subnormal
           result */
        assert(prec1 != BF_PREC_INF);
        prec = prec1 - (e_min - r->expn);
    } else {
        prec = prec1;
    }

    /* round to prec bits */
    rnd_mode = flags & BF_RND_MASK;
    ret = 0;
    add_one = bf_get_rnd_add(&ret, r, l, prec, rnd_mode);

    if (prec <= 0) {
        if (add_one) {
            bf_resize(r, 1); /* cannot fail */
            r->tab[0] = (limb_t)1 << (LIMB_BITS - 1);
            r->expn += 1 - prec;
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        } else {
            goto underflow;
        }
    } else if (add_one) {
        limb_t carry;

        /* add one starting at digit 'prec - 1' */
        bit_pos = l * LIMB_BITS - 1 - (prec - 1);
        pos = bit_pos >> LIMB_LOG2_BITS;
        carry = (limb_t)1 << (bit_pos & (LIMB_BITS - 1));

        for(i = pos; i < l; i++) {
            v = r->tab[i] + carry;
            carry = (v < carry);
            r->tab[i] = v;
            if (carry == 0)
                break;
        }
        if (carry) {
            /* shift right by one digit */
            v = 1;
            for(i = l - 1; i >= pos; i--) {
                a = r->tab[i];
                r->tab[i] = (a >> 1) | (v << (LIMB_BITS - 1));
                v = a;
            }
            r->expn++;
        }
    }

    /* check underflow */
    if (unlikely(r->expn < e_min)) {
        if (flags & BF_FLAG_SUBNORMAL) {
            /* if inexact, also set the underflow flag */
            if (ret & BF_ST_INEXACT)
                ret |= BF_ST_UNDERFLOW;
        } else {
        underflow:
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            bf_set_zero(r, r->sign);
            return ret;
        }
    }

    /* check overflow */
    if (unlikely(r->expn > e_max))
        return bf_set_overflow(r, r->sign, prec1, flags);

    /* keep the bits starting at 'prec - 1' */
    bit_pos = l * LIMB_BITS - 1 - (prec - 1);
    i = bit_pos >> LIMB_LOG2_BITS;
    if (i >= 0) {
        shift = bit_pos & (LIMB_BITS - 1);
        if (shift != 0)
            r->tab[i] &= limb_mask(shift, LIMB_BITS - 1);
    } else {
        i = 0;
    }
    /* remove trailing zeros */
    while (r->tab[i] == 0)
        i++;
    if (i > 0) {
        l -= i;
        memmove(r->tab, r->tab + i, l * sizeof(limb_t));
    }
    bf_resize(r, l); /* cannot fail */
    return ret;
}

/* 'r' must be a finite number. */
int bf_normalize_and_round(bf_t *r, limb_t prec1, bf_flags_t flags)
{
    limb_t l, v, a;
    int shift, ret;
    slimb_t i;

    //    bf_print_str("bf_renorm", r);
    l = r->len;
    while (l > 0 && r->tab[l - 1] == 0)
        l--;
    if (l == 0) {
        /* zero */
        r->expn = BF_EXP_ZERO;
        bf_resize(r, 0); /* cannot fail */
        ret = 0;
    } else {
        r->expn -= (r->len - l) * LIMB_BITS;
        /* shift to have the MSB set to '1' */
        v = r->tab[l - 1];
        shift = clz(v);
        if (shift != 0) {
            v = 0;
            for(i = 0; i < l; i++) {
                a = r->tab[i];
                r->tab[i] = (a << shift) | (v >> (LIMB_BITS - shift));
                v = a;
            }
            r->expn -= shift;
        }
        ret = __bf_round(r, prec1, flags, l);
    }
    //    bf_print_str("r_final", r);
    return ret;
}

/* return true if rounding can be done at precision 'prec' assuming
   the exact result r is such that |r-a| <= 2^(EXP(a)-k). */
/* XXX: check the case where the exponent would be incremented by the
   rounding */
int bf_can_round(const bf_t *a, slimb_t prec, bf_rnd_t rnd_mode, slimb_t k)
{
    BOOL is_rndn;
    slimb_t bit_pos, n;
    limb_t bit;

    if (a->expn == BF_EXP_INF || a->expn == BF_EXP_NAN)
        return FALSE;
    if (rnd_mode == BF_RNDF) {
        return (k >= (prec + 1));
    }
    if (a->expn == BF_EXP_ZERO)
        return FALSE;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
               rnd_mode == BF_RNDNU);
    if (k < (prec + 2))
        return FALSE;
    bit_pos = a->len * LIMB_BITS - 1 - prec;
    n = k - prec;
    /* bit pattern for RNDN or RNDNA: 0111.. or 1000...
       for other rounding modes: 000... or 111...
    */
    bit = get_bit(a->tab, a->len, bit_pos);
    bit_pos--;
    n--;
    bit ^= is_rndn;
    /* XXX: slow, but a few iterations on average */
    while (n != 0) {
        if (get_bit(a->tab, a->len, bit_pos) != bit)
            return TRUE;
        bit_pos--;
        n--;
    }
    return FALSE;
}

/* Cannot fail with BF_ST_MEM_ERROR. */
int bf_round(bf_t *r, limb_t prec, bf_flags_t flags)
{
    if (r->len == 0)
        return 0;
    return __bf_round(r, prec, flags, r->len);
}

/* for debugging */
static __maybe_unused void dump_limbs(const char *str, const limb_t *tab, limb_t n)
{
    limb_t i;
    printf("%s: len=%" PRId_LIMB "\n", str, n);
    for(i = 0; i < n; i++) {
        printf("%" PRId_LIMB ": " FMT_LIMB "\n",
               i, tab[i]);
    }
}

void mp_print_str(const char *str, const limb_t *tab, limb_t n)
{
    slimb_t i;
    printf("%s= 0x", str);
    for(i = n - 1; i >= 0; i--) {
        if (i != (n - 1))
            printf("_");
        printf(FMT_LIMB, tab[i]);
    }
    printf("\n");
}

static __maybe_unused void mp_print_str_h(const char *str,
                                          const limb_t *tab, limb_t n,
                                          limb_t high)
{
    slimb_t i;
    printf("%s= 0x", str);
    printf(FMT_LIMB, high);
    for(i = n - 1; i >= 0; i--) {
        printf("_");
        printf(FMT_LIMB, tab[i]);
    }
    printf("\n");
}

/* for debugging */
void bf_print_str(const char *str, const bf_t *a)
{
    slimb_t i;
    printf("%s=", str);

    if (a->expn == BF_EXP_NAN) {
        printf("NaN");
    } else {
        if (a->sign)
            putchar('-');
        if (a->expn == BF_EXP_ZERO) {
            putchar('0');
        } else if (a->expn == BF_EXP_INF) {
            printf("Inf");
        } else {
            printf("0x0.");
            for(i = a->len - 1; i >= 0; i--)
                printf(FMT_LIMB, a->tab[i]);
            printf("p%" PRId_LIMB, a->expn);
        }
    }
    printf("\n");
}

/* compare the absolute value of 'a' and 'b'. Return < 0 if a < b, 0
   if a = b and > 0 otherwise. */
int bf_cmpu(const bf_t *a, const bf_t *b)
{
    slimb_t i;
    limb_t len, v1, v2;

    if (a->expn != b->expn) {
        if (a->expn < b->expn)
            return -1;
        else
            return 1;
    }
    len = bf_max(a->len, b->len);
    for(i = len - 1; i >= 0; i--) {
        v1 = get_limbz(a, a->len - len + i);
        v2 = get_limbz(b, b->len - len + i);
        if (v1 != v2) {
            if (v1 < v2)
                return -1;
            else
                return 1;
        }
    }
    return 0;
}

/* Full order: -0 < 0, NaN == NaN and NaN is larger than all other numbers */
int bf_cmp_full(const bf_t *a, const bf_t *b)
{
    int res;

    if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
        if (a->expn == b->expn)
            res = 0;
        else if (a->expn == BF_EXP_NAN)
            res = 1;
        else
            res = -1;
    } else if (a->sign != b->sign) {
        res = 1 - 2 * a->sign;
    } else {
        res = bf_cmpu(a, b);
        if (a->sign)
            res = -res;
    }
    return res;
}

#define BF_CMP_EQ 1
#define BF_CMP_LT 2
#define BF_CMP_LE 3

static int bf_cmp(const bf_t *a, const bf_t *b, int op)
{
    BOOL is_both_zero;
    int res;

    if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN)
        return 0;
    if (a->sign != b->sign) {
        is_both_zero = (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_ZERO);
        if (is_both_zero) {
            return op & BF_CMP_EQ;
        } else if (op & BF_CMP_LT) {
            return a->sign;
        } else {
            return FALSE;
        }
    } else {
        res = bf_cmpu(a, b);
        if (res == 0) {
            return op & BF_CMP_EQ;
        } else if (op & BF_CMP_LT) {
            return (res < 0) ^ a->sign;
        } else {
            return FALSE;
        }
    }
}

int bf_cmp_eq(const bf_t *a, const bf_t *b)
{
    return bf_cmp(a, b, BF_CMP_EQ);
}

int bf_cmp_le(const bf_t *a, const bf_t *b)
{
    return bf_cmp(a, b, BF_CMP_LE);
}

int bf_cmp_lt(const bf_t *a, const bf_t *b)
{
    return bf_cmp(a, b, BF_CMP_LT);
}

/* Compute the number of bits 'n' matching the pattern:
   a= X1000..0
   b= X0111..1

   When computing a-b, the result will have at least n leading zero
   bits.

   Precondition: a > b and a.expn - b.expn = 0 or 1
*/
static limb_t count_cancelled_bits(const bf_t *a, const bf_t *b)
{
    slimb_t bit_offset, b_offset, n;
    int p, p1;
    limb_t v1, v2, mask;

    bit_offset = a->len * LIMB_BITS - 1;
    b_offset = (b->len - a->len) * LIMB_BITS - (LIMB_BITS - 1) +
        a->expn - b->expn;
    n = 0;

    /* first search for equal bits */
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != v2)
            break;
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
    /* find the position of the first different bit */
    p = clz(v1 ^ v2) + 1;
    n += p;
    /* then search for '0' in a and '1' in b */
    p = LIMB_BITS - p;
    if (p > 0) {
        /* search in the trailing p bits of v1 and v2 */
        mask = limb_mask(0, p - 1);
        p1 = bf_min(clz(v1 & mask), clz((~v2) & mask)) - (LIMB_BITS - p);
        n += p1;
        if (p1 != p)
            goto done;
    }
    bit_offset -= LIMB_BITS;
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != 0 || v2 != -1) {
            /* different: count the matching bits */
            p1 = bf_min(clz(v1), clz(~v2));
            n += p1;
            break;
        }
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
 done:
    return n;
}
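
/* Worked example (illustrative): with equal exponents, a = 0b11011000
   and b = 0b11010111 share the prefix "1101" and then match the
   X1000.../X0111... pattern, so a - b = 0b00000001: the leading bits
   cancel, which is why bf_add_internal() below requests
   'cancelled_bits' extra bits of working precision for subtractions. */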

static int bf_add_internal(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                           bf_flags_t flags, int b_neg)
{
    const bf_t *tmp;
    int is_sub, ret, cmp_res, a_sign, b_sign;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    is_sub = a_sign ^ b_sign;
    cmp_res = bf_cmpu(a, b);
    if (cmp_res < 0) {
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* zero result */
        bf_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bf_set_nan(r);
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, a_sign);
            }
        } else {
            /* at least one zero and not subtract */
            bf_set(r, a);
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_bit_offset, i, cancelled_bits;
        limb_t carry, v1, v2, u, r_len, carry1, precl, tot_len, z, sub_mask;

        r->sign = a_sign;
        r->expn = a->expn;
        d = a->expn - b->expn;
        /* must add more precision for the leading cancelled bits in
           subtraction */
        if (is_sub) {
            if (d <= 1)
                cancelled_bits = count_cancelled_bits(a, b);
            else
                cancelled_bits = 1;
        } else {
            cancelled_bits = 0;
        }

        /* add two extra bits for rounding */
        precl = (cancelled_bits + prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
        tot_len = bf_max(a->len, b->len + (d + LIMB_BITS - 1) / LIMB_BITS);
        r_len = bf_min(precl, tot_len);
        if (bf_resize(r, r_len))
            goto fail;
        a_offset = a->len - r_len;
        b_bit_offset = (b->len - r_len) * LIMB_BITS + d;

        /* compute the bits before for the rounding */
        carry = is_sub;
        z = 0;
        sub_mask = -is_sub;
        i = r_len - tot_len;
        while (i < 0) {
            slimb_t ap, bp;
            BOOL inflag;

            ap = a_offset + i;
            bp = b_bit_offset + i * LIMB_BITS;
            inflag = FALSE;
            if (ap >= 0 && ap < a->len) {
                v1 = a->tab[ap];
                inflag = TRUE;
            } else {
                v1 = 0;
            }
            if (bp + LIMB_BITS > 0 && bp < (slimb_t)(b->len * LIMB_BITS)) {
                v2 = get_bits(b->tab, b->len, bp);
                inflag = TRUE;
            } else {
                v2 = 0;
            }
            if (!inflag) {
                /* outside 'a' and 'b': go directly to the next value
                   inside a or b so that the running time does not
                   depend on the exponent difference */
                i = 0;
                if (ap < 0)
                    i = bf_min(i, -a_offset);
                /* b_bit_offset + i * LIMB_BITS + LIMB_BITS >= 1
                   equivalent to
                   i >= ceil(-b_bit_offset + 1 - LIMB_BITS) / LIMB_BITS)
                */
                if (bp + LIMB_BITS <= 0)
                    i = bf_min(i, (-b_bit_offset) >> LIMB_LOG2_BITS);
            } else {
                i++;
            }
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            z |= u;
        }
        /* and the result */
        for(i = 0; i < r_len; i++) {
            v1 = get_limbz(a, a_offset + i);
            v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS);
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            r->tab[i] = u;
        }
        /* set the extra bits for the rounding */
        r->tab[0] |= (z != 0);

        /* carry is only possible in add case */
        if (!is_sub && carry) {
            if (bf_resize(r, r_len + 1))
                goto fail;
            r->tab[r_len] = 1;
            r->expn += LIMB_BITS;
        }
    renorm:
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

static int __bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 0);
}

static int __bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 1);
}

limb_t mp_add(limb_t *res, const limb_t *op1, const limb_t *op2,
              limb_t n, limb_t carry)
{
    slimb_t i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = op1[i];
        a = v + op2[i];
        k1 = a < v;
        a = a + k;
        k = (a < k) | k1;
        res[i] = a;
    }
    return k;
}

limb_t mp_add_ui(limb_t *tab, limb_t b, size_t n)
{
    size_t i;
    limb_t k, a;

    k = b;
    for(i=0;i<n;i++) {
        if (k == 0)
            break;
        a = tab[i] + k;
        k = (a < k);
        tab[i] = a;
    }
    return k;
}

limb_t mp_sub(limb_t *res, const limb_t *op1, const limb_t *op2,
              mp_size_t n, limb_t carry)
{
    int i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = op1[i];
        a = v - op2[i];
        k1 = a > v;
        v = a - k;
        k = (v > a) | k1;
        res[i] = v;
    }
    return k;
}

/* compute 0 - op2 */
static limb_t mp_neg(limb_t *res, const limb_t *op2, mp_size_t n, limb_t carry)
{
    int i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = 0;
        a = v - op2[i];
        k1 = a > v;
        v = a - k;
        k = (v > a) | k1;
        res[i] = v;
    }
    return k;
}

limb_t mp_sub_ui(limb_t *tab, limb_t b, mp_size_t n)
{
    mp_size_t i;
    limb_t k, a, v;

    k = b;
    for(i=0;i<n;i++) {
        v = tab[i];
        a = v - k;
        k = a > v;
        tab[i] = a;
        if (k == 0)
            break;
    }
    return k;
}

/* r = (a + high*B^n) >> shift. Return the remainder r (0 <= r < 2^shift).
   1 <= shift <= LIMB_BITS - 1 */
static limb_t mp_shr(limb_t *tab_r, const limb_t *tab, mp_size_t n,
                     int shift, limb_t high)
{
    mp_size_t i;
    limb_t l, a;

    assert(shift >= 1 && shift < LIMB_BITS);
    l = high;
    for(i = n - 1; i >= 0; i--) {
        a = tab[i];
        tab_r[i] = (a >> shift) | (l << (LIMB_BITS - shift));
        l = a;
    }
    return l & (((limb_t)1 << shift) - 1);
}
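
/* Example (illustrative): with n = 1 and shift = 4,
   tab_r[0] = (tab[0] >> 4) | (high << (LIMB_BITS - 4)) and the low
   4 bits of tab[0] are returned as the remainder. */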

/* tabr[] = taba[] * b + l. Return the high carry */
static limb_t mp_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                      limb_t b, limb_t l)
{
    limb_t i;
    dlimb_t t;

    for(i = 0; i < n; i++) {
        t = (dlimb_t)taba[i] * (dlimb_t)b + l;
        tabr[i] = t;
        l = t >> LIMB_BITS;
    }
    return l;
}

/* tabr[] += taba[] * b, return the high word. */
static limb_t mp_add_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b)
{
    limb_t i, l;
    dlimb_t t;

    l = 0;
    for(i = 0; i < n; i++) {
        t = (dlimb_t)taba[i] * (dlimb_t)b + l + tabr[i];
        tabr[i] = t;
        l = t >> LIMB_BITS;
    }
    return l;
}

/* size of the result : op1_size + op2_size. */
static void mp_mul_basecase(limb_t *result,
                            const limb_t *op1, limb_t op1_size,
                            const limb_t *op2, limb_t op2_size)
{
    limb_t i, r;

    result[op1_size] = mp_mul1(result, op1, op1_size, op2[0], 0);
    for(i=1;i<op2_size;i++) {
        r = mp_add_mul1(result + i, op1, op1_size, op2[i]);
        result[i + op1_size] = r;
    }
}
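
/* Worked example (illustrative): for op1_size = op2_size = 2,
   result[2..0] first receives op1 * op2[0]; then op1 * op2[1] is
   accumulated at offset 1 and its final carry fills result[3], giving
   the full 4-limb schoolbook product. */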

/* return 0 if OK, -1 if memory error */
/* XXX: change API so that result can be allocated */
int mp_mul(bf_context_t *s, limb_t *result,
           const limb_t *op1, limb_t op1_size,
           const limb_t *op2, limb_t op2_size)
{
#ifdef USE_FFT_MUL
    if (unlikely(bf_min(op1_size, op2_size) >= FFT_MUL_THRESHOLD)) {
        bf_t r_s, *r = &r_s;
        r->tab = result;
        /* XXX: optimize memory usage in API */
        if (fft_mul(s, r, (limb_t *)op1, op1_size,
                    (limb_t *)op2, op2_size, FFT_MUL_R_NORESIZE))
            return -1;
    } else
#endif
    {
        mp_mul_basecase(result, op1, op1_size, op2, op2_size);
    }
    return 0;
}

/* tabr[] -= taba[] * b. Return the value to subtract from the high
   word. */
static limb_t mp_sub_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b)
{
    limb_t i, l;
    dlimb_t t;

    l = 0;
    for(i = 0; i < n; i++) {
        t = tabr[i] - (dlimb_t)taba[i] * (dlimb_t)b - l;
        tabr[i] = t;
        l = -(t >> LIMB_BITS);
    }
    return l;
}

/* WARNING: d must be >= 2^(LIMB_BITS-1) */
static inline limb_t udiv1norm_init(limb_t d)
{
    limb_t a0, a1;
    a1 = -d - 1;
    a0 = -1;
    return (((dlimb_t)a1 << LIMB_BITS) | a0) / d;
}
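
/* Note (illustrative): with B = 2^LIMB_BITS, a1 = B - 1 - d and
   a0 = B - 1, so the value computed above is
   floor((B^2 - 1) / d) - B, the usual normalized reciprocal. */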

/* return the quotient and the remainder in '*pr' of 'a1*2^LIMB_BITS+a0
   / d' with 0 <= a1 < d. */
static inline limb_t udiv1norm(limb_t *pr, limb_t a1, limb_t a0,
                               limb_t d, limb_t d_inv)
{
    limb_t n1m, n_adj, q, r, ah;
    dlimb_t a;
    n1m = ((slimb_t)a0 >> (LIMB_BITS - 1));
    n_adj = a0 + (n1m & d);
    a = (dlimb_t)d_inv * (a1 - n1m) + n_adj;
    q = (a >> LIMB_BITS) + a1;
    /* compute a - q * d and update q so that the remainder is
       between 0 and d - 1 */
    a = ((dlimb_t)a1 << LIMB_BITS) | a0;
    a = a - (dlimb_t)q * d - d;
    ah = a >> LIMB_BITS;
    q += 1 + ah;
    r = (limb_t)a + (ah & d);
    *pr = r;
    return q;
}
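
/* Minimal self-check sketch (illustrative only, not part of the
   library): compares the reciprocal-based quotient of udiv1norm()
   with a direct dlimb_t division for one normalized divisor. */
#if 0
static void udiv1norm_selfcheck(void)
{
    limb_t d, d_inv, a1, a0, q, r;
    dlimb_t num;

    d = ((limb_t)1 << (LIMB_BITS - 1)) + 12345; /* d >= 2^(LIMB_BITS-1) */
    d_inv = udiv1norm_init(d);
    a1 = d - 1; /* requires 0 <= a1 < d */
    a0 = 67890;
    q = udiv1norm(&r, a1, a0, d, d_inv);
    num = ((dlimb_t)a1 << LIMB_BITS) | a0;
    assert(q == (limb_t)(num / d) && r == (limb_t)(num % d));
}
#endif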

/* b must be >= 1 << (LIMB_BITS - 1) */
static limb_t mp_div1norm(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b, limb_t r)
{
    slimb_t i;

    if (n >= UDIV1NORM_THRESHOLD) {
        limb_t b_inv;
        b_inv = udiv1norm_init(b);
        for(i = n - 1; i >= 0; i--) {
            tabr[i] = udiv1norm(&r, r, taba[i], b, b_inv);
        }
    } else {
        dlimb_t a1;
        for(i = n - 1; i >= 0; i--) {
            a1 = ((dlimb_t)r << LIMB_BITS) | taba[i];
            tabr[i] = a1 / b;
            r = a1 % b;
        }
    }
    return r;
}

static int mp_divnorm_large(bf_context_t *s,
                            limb_t *tabq, limb_t *taba, limb_t na,
                            const limb_t *tabb, limb_t nb);

/* base case division: divides taba[0..na-1] by tabb[0..nb-1]. tabb[nb
   - 1] must be >= 1 << (LIMB_BITS - 1). na - nb must be >= 0. 'taba'
   is modified and contains the remainder (nb limbs). tabq[0..na-nb]
   contains the quotient with tabq[na - nb] <= 1. */
static int mp_divnorm(bf_context_t *s, limb_t *tabq, limb_t *taba, limb_t na,
                      const limb_t *tabb, limb_t nb)
{
    limb_t r, a, c, q, v, b1, b1_inv, n, dummy_r;
    slimb_t i, j;

    b1 = tabb[nb - 1];
    if (nb == 1) {
        taba[0] = mp_div1norm(tabq, taba, na, b1, 0);
        return 0;
    }
    n = na - nb;
    if (bf_min(n, nb) >= DIVNORM_LARGE_THRESHOLD) {
        return mp_divnorm_large(s, tabq, taba, na, tabb, nb);
    }

    if (n >= UDIV1NORM_THRESHOLD)
        b1_inv = udiv1norm_init(b1);
    else
        b1_inv = 0;

    /* first iteration: the quotient is only 0 or 1 */
    q = 1;
    for(j = nb - 1; j >= 0; j--) {
        if (taba[n + j] != tabb[j]) {
            if (taba[n + j] < tabb[j])
                q = 0;
            break;
        }
    }
    tabq[n] = q;
    if (q) {
        mp_sub(taba + n, taba + n, tabb, nb, 0);
    }

    for(i = n - 1; i >= 0; i--) {
        if (unlikely(taba[i + nb] >= b1)) {
            q = -1;
        } else if (b1_inv) {
            q = udiv1norm(&dummy_r, taba[i + nb], taba[i + nb - 1], b1, b1_inv);
        } else {
            dlimb_t al;
            al = ((dlimb_t)taba[i + nb] << LIMB_BITS) | taba[i + nb - 1];
            q = al / b1;
            r = al % b1;
        }
        r = mp_sub_mul1(taba + i, tabb, nb, q);

        v = taba[i + nb];
        a = v - r;
        c = (a > v);
        taba[i + nb] = a;

        if (c != 0) {
            /* negative result */
            for(;;) {
                q--;
                c = mp_add(taba + i, taba + i, tabb, nb, 0);
                /* propagate carry and test if positive result */
                if (c != 0) {
                    if (++taba[i + nb] == 0) {
                        break;
                    }
                }
            }
        }
        tabq[i] = q;
    }
    return 0;
}

/* compute r=B^(2*n)/a such that a*r < B^(2*n) < a*r + 2 with n >= 1. 'a'
   has n limbs with a[n-1] >= B/2 and 'r' has n+1 limbs with r[n] = 1.

   See Modern Computer Arithmetic by Richard P. Brent and Paul
   Zimmermann, algorithm 3.5 */
int mp_recip(bf_context_t *s, limb_t *tabr, const limb_t *taba, limb_t n)
{
    mp_size_t l, h, k, i;
    limb_t *tabxh, *tabt, c, *tabu;

    if (n <= 2) {
        /* return ceil(B^(2*n)/a) - 1 */
        /* XXX: could avoid allocation */
        tabu = bf_malloc(s, sizeof(limb_t) * (2 * n + 1));
        tabt = bf_malloc(s, sizeof(limb_t) * (n + 2));
        if (!tabt || !tabu)
            goto fail;
        for(i = 0; i < 2 * n; i++)
            tabu[i] = 0;
        tabu[2 * n] = 1;
        if (mp_divnorm(s, tabt, tabu, 2 * n + 1, taba, n))
            goto fail;
        for(i = 0; i < n + 1; i++)
            tabr[i] = tabt[i];
        if (mp_scan_nz(tabu, n) == 0) {
            /* only happens for a=B^n/2 */
            mp_sub_ui(tabr, 1, n + 1);
        }
    } else {
        l = (n - 1) / 2;
        h = n - l;
        /* n=2p -> l=p-1, h = p + 1, k = p + 3
           n=2p+1-> l=p, h = p + 1; k = p + 2
        */
        tabt = bf_malloc(s, sizeof(limb_t) * (n + h + 1));
        tabu = bf_malloc(s, sizeof(limb_t) * (n + 2 * h - l + 2));
        if (!tabt || !tabu)
            goto fail;
        tabxh = tabr + l;
        if (mp_recip(s, tabxh, taba + l, h))
            goto fail;
        if (mp_mul(s, tabt, taba, n, tabxh, h + 1)) /* n + h + 1 limbs */
            goto fail;
        while (tabt[n + h] != 0) {
            mp_sub_ui(tabxh, 1, h + 1);
            c = mp_sub(tabt, tabt, taba, n, 0);
            mp_sub_ui(tabt + n, c, h + 1);
        }
        /* T = B^(n+h) - T */
        mp_neg(tabt, tabt, n + h + 1, 0);
        tabt[n + h]++;
        if (mp_mul(s, tabu, tabt + l, n + h + 1 - l, tabxh, h + 1))
            goto fail;
        /* n + 2*h - l + 2 limbs */
        k = 2 * h - l;
        for(i = 0; i < l; i++)
            tabr[i] = tabu[i + k];
        mp_add(tabr + l, tabr + l, tabu + 2 * h, h, 0);
    }
    bf_free(s, tabt);
    bf_free(s, tabu);
    return 0;
 fail:
    bf_free(s, tabt);
    bf_free(s, tabu);
    return -1;
}

/* return -1, 0 or 1 */
static int mp_cmp(const limb_t *taba, const limb_t *tabb, mp_size_t n)
{
    mp_size_t i;
    for(i = n - 1; i >= 0; i--) {
        if (taba[i] != tabb[i]) {
            if (taba[i] < tabb[i])
                return -1;
            else
                return 1;
        }
    }
    return 0;
}

//#define DEBUG_DIVNORM_LARGE
//#define DEBUG_DIVNORM_LARGE2

/* subquadratic divnorm */
static int mp_divnorm_large(bf_context_t *s,
                            limb_t *tabq, limb_t *taba, limb_t na,
                            const limb_t *tabb, limb_t nb)
{
    limb_t *tabb_inv, nq, *tabt, i, n;
    nq = na - nb;
#ifdef DEBUG_DIVNORM_LARGE
    printf("na=%d nb=%d nq=%d\n", (int)na, (int)nb, (int)nq);
    mp_print_str("a", taba, na);
    mp_print_str("b", tabb, nb);
#endif
    assert(nq >= 1);
    n = nq;
    if (nq < nb)
        n++;
    tabb_inv = bf_malloc(s, sizeof(limb_t) * (n + 1));
    tabt = bf_malloc(s, sizeof(limb_t) * 2 * (n + 1));
    if (!tabb_inv || !tabt)
        goto fail;

    if (n >= nb) {
        for(i = 0; i < n - nb; i++)
            tabt[i] = 0;
        for(i = 0; i < nb; i++)
            tabt[i + n - nb] = tabb[i];
    } else {
        /* truncate B: need to increment it so that the approximate
           inverse is smaller than the exact inverse */
        for(i = 0; i < n; i++)
            tabt[i] = tabb[i + nb - n];
        if (mp_add_ui(tabt, 1, n)) {
            /* tabt = B^n : tabb_inv = B^n */
            memset(tabb_inv, 0, n * sizeof(limb_t));
            tabb_inv[n] = 1;
            goto recip_done;
        }
    }
    if (mp_recip(s, tabb_inv, tabt, n))
        goto fail;
 recip_done:
    /* Q=A*B^-1 */
    if (mp_mul(s, tabt, tabb_inv, n + 1, taba + na - (n + 1), n + 1))
        goto fail;

    for(i = 0; i < nq + 1; i++)
        tabq[i] = tabt[i + 2 * (n + 1) - (nq + 1)];
#ifdef DEBUG_DIVNORM_LARGE
    mp_print_str("q", tabq, nq + 1);
#endif

    bf_free(s, tabt);
    bf_free(s, tabb_inv);
    tabb_inv = NULL;

    /* R=A-B*Q */
    tabt = bf_malloc(s, sizeof(limb_t) * (na + 1));
    if (!tabt)
        goto fail;
    if (mp_mul(s, tabt, tabq, nq + 1, tabb, nb))
        goto fail;
    /* we add one more limb for the result */
    mp_sub(taba, taba, tabt, nb + 1, 0);
    bf_free(s, tabt);
    /* the approximated quotient is smaller than the exact one,
       hence we may have to increment it */
#ifdef DEBUG_DIVNORM_LARGE2
    int cnt = 0;
    static int cnt_max;
#endif
    for(;;) {
        if (taba[nb] == 0 && mp_cmp(taba, tabb, nb) < 0)
            break;
        taba[nb] -= mp_sub(taba, taba, tabb, nb, 0);
        mp_add_ui(tabq, 1, nq + 1);
#ifdef DEBUG_DIVNORM_LARGE2
        cnt++;
#endif
    }
#ifdef DEBUG_DIVNORM_LARGE2
    if (cnt > cnt_max) {
        cnt_max = cnt;
        printf("\ncnt=%d nq=%d nb=%d\n", cnt_max, (int)nq, (int)nb);
    }
#endif
    return 0;
 fail:
    bf_free(s, tabb_inv);
    bf_free(s, tabt);
    return -1;
}

int bf_mul(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    int ret, r_sign;

    if (a->len < b->len) {
        const bf_t *tmp = a;
        a = b;
        b = tmp;
    }
    r_sign = a->sign ^ b->sign;
    /* here b->len <= a->len */
    if (b->len == 0) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            ret = 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
            if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
                (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, r_sign);
                ret = 0;
            }
        } else {
            bf_set_zero(r, r_sign);
            ret = 0;
        }
    } else {
        bf_t tmp, *r1 = NULL;
        limb_t a_len, b_len, precl;
        limb_t *a_tab, *b_tab;

        a_len = a->len;
        b_len = b->len;

        if ((flags & BF_RND_MASK) == BF_RNDF) {
            /* faithful rounding does not require using the full inputs */
            precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
            a_len = bf_min(a_len, precl);
            b_len = bf_min(b_len, precl);
        }
        a_tab = a->tab + a->len - a_len;
        b_tab = b->tab + b->len - b_len;

#ifdef USE_FFT_MUL
        if (b_len >= FFT_MUL_THRESHOLD) {
            int mul_flags = 0;
            if (r == a)
                mul_flags |= FFT_MUL_R_OVERLAP_A;
            if (r == b)
                mul_flags |= FFT_MUL_R_OVERLAP_B;
            if (fft_mul(r->ctx, r, a_tab, a_len, b_tab, b_len, mul_flags))
                goto fail;
        } else
#endif
        {
            if (r == a || r == b) {
                bf_init(r->ctx, &tmp);
                r1 = r;
                r = &tmp;
            }
            if (bf_resize(r, a_len + b_len)) {
            fail:
                bf_set_nan(r);
                ret = BF_ST_MEM_ERROR;
                goto done;
            }
            mp_mul_basecase(r->tab, a_tab, a_len, b_tab, b_len);
        }
        r->sign = r_sign;
        r->expn = a->expn + b->expn;
        ret = bf_normalize_and_round(r, prec, flags);
    done:
        if (r == &tmp)
            bf_move(r1, &tmp);
    }
    return ret;
}

/* multiply 'r' by 2^e */
int bf_mul_2exp(bf_t *r, slimb_t e, limb_t prec, bf_flags_t flags)
{
    slimb_t e_max;
    if (r->len == 0)
        return 0;
    e_max = ((limb_t)1 << BF_EXP_BITS_MAX) - 1;
    e = bf_max(e, -e_max);
    e = bf_min(e, e_max);
    r->expn += e;
    return __bf_round(r, prec, flags, r->len);
}

/* Return e such that a = m*2^e with m an odd integer. Return 0 if a is
   zero, Infinity or NaN. */
slimb_t bf_get_exp_min(const bf_t *a)
{
    slimb_t i;
    limb_t v;
    int k;

    for(i = 0; i < a->len; i++) {
        v = a->tab[i];
        if (v != 0) {
            k = ctz(v);
            return a->expn - (a->len - i) * LIMB_BITS + k;
        }
    }
    return 0;
}

/* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
   integer defined as floor(a/b) and r = a - q * b. */
static void bf_tdivremu(bf_t *q, bf_t *r,
                        const bf_t *a, const bf_t *b)
{
    if (bf_cmpu(a, b) < 0) {
        bf_set_ui(q, 0);
        bf_set(r, a);
    } else {
        bf_div(q, a, b, bf_max(a->expn - b->expn + 1, 2), BF_RNDZ);
        bf_rint(q, BF_RNDZ);
        bf_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
        bf_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
    }
}

static int __bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    int ret, r_sign;
    limb_t n, nb, precl;

    r_sign = a->sign ^ b->sign;
    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else if (a->expn == BF_EXP_INF) {
            bf_set_inf(r, r_sign);
            return 0;
        } else {
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (b->expn == BF_EXP_ZERO) {
        bf_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* number of limbs of the quotient (2 extra bits for rounding) */
    precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
    nb = b->len;
    n = bf_max(a->len, precl);

    {
        limb_t *taba, na;
        slimb_t d;

        na = n + nb;
        taba = bf_malloc(s, (na + 1) * sizeof(limb_t));
        if (!taba)
            goto fail;
        d = na - a->len;
        memset(taba, 0, d * sizeof(limb_t));
        memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
        if (bf_resize(r, n + 1))
            goto fail;
        if (mp_divnorm(s, r->tab, taba, na, b->tab, nb))
            goto fail;

        /* see if non zero remainder */
        if (mp_scan_nz(taba, nb))
            r->tab[0] |= 1;
        bf_free(r->ctx, taba);
        r->expn = a->expn - b->expn + LIMB_BITS;
        r->sign = r_sign;
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

/* division and remainder.

   rnd_mode is the rounding mode for the quotient. The additional
   rounding mode BF_DIVREM_EUCLIDIAN is supported.

   'q' is an integer. 'r' is rounded with prec and flags (prec can be
   BF_PREC_INF).
*/
int bf_divrem(bf_t *q, bf_t *r, const bf_t *a, const bf_t *b,
              limb_t prec, bf_flags_t flags, int rnd_mode)
{
    bf_t a1_s, *a1 = &a1_s;
    bf_t b1_s, *b1 = &b1_s;
    int q_sign, ret;
    BOOL is_ceil, is_rndn;

    assert(q != a && q != b);
    assert(r != a && r != b);
    assert(q != r);

    if (a->len == 0 || b->len == 0) {
        bf_set_zero(q, 0);
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bf_set(r, a);
            return bf_round(r, prec, flags);
        }
    }

    q_sign = a->sign ^ b->sign;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
               rnd_mode == BF_RNDNU);
    switch(rnd_mode) {
    default:
    case BF_RNDZ:
    case BF_RNDN:
    case BF_RNDNA:
        is_ceil = FALSE;
        break;
    case BF_RNDD:
        is_ceil = q_sign;
        break;
    case BF_RNDU:
        is_ceil = q_sign ^ 1;
        break;
    case BF_DIVREM_EUCLIDIAN:
        is_ceil = a->sign;
        break;
    case BF_RNDNU:
        /* XXX: unsupported yet */
        abort();
    }

    a1->expn = a->expn;
    a1->tab = a->tab;
    a1->len = a->len;
    a1->sign = 0;

    b1->expn = b->expn;
    b1->tab = b->tab;
    b1->len = b->len;
    b1->sign = 0;

    /* XXX: could improve to avoid having a large 'q' */
    bf_tdivremu(q, r, a1, b1);
    if (bf_is_nan(q) || bf_is_nan(r))
        goto fail;

    if (r->len != 0) {
        if (is_rndn) {
            int res;
            b1->expn--;
            res = bf_cmpu(r, b1);
            b1->expn++;
            if (res > 0 ||
                (res == 0 &&
                 (rnd_mode == BF_RNDNA ||
                  get_bit(q->tab, q->len, q->len * LIMB_BITS - q->expn)))) {
                goto do_sub_r;
            }
        } else if (is_ceil) {
        do_sub_r:
            ret = bf_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
            ret |= bf_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
            if (ret & BF_ST_MEM_ERROR)
                goto fail;
        }
    }

    r->sign ^= a->sign;
    q->sign = q_sign;
    return bf_round(r, prec, flags);
 fail:
    bf_set_nan(q);
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

int bf_fmod(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
            bf_flags_t flags)
{
    bf_t q_s, *q = &q_s;
    int ret;

    bf_init(r->ctx, q);
    ret = bf_divrem(q, r, a, b, prec, flags, BF_RNDZ);
    bf_delete(q);
    return ret;
}

int bf_remainder(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                 bf_flags_t flags)
{
    bf_t q_s, *q = &q_s;
    int ret;

    bf_init(r->ctx, q);
    ret = bf_divrem(q, r, a, b, prec, flags, BF_RNDN);
    bf_delete(q);
    return ret;
}

static inline int bf_get_limb(slimb_t *pres, const bf_t *a, int flags)
{
#if LIMB_BITS == 32
    return bf_get_int32(pres, a, flags);
#else
    return bf_get_int64(pres, a, flags);
#endif
}

int bf_remquo(slimb_t *pq, bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
              bf_flags_t flags)
{
    bf_t q_s, *q = &q_s;
    int ret;

    bf_init(r->ctx, q);
    ret = bf_divrem(q, r, a, b, prec, flags, BF_RNDN);
    bf_get_limb(pq, q, BF_GET_INT_MOD);
    bf_delete(q);
    return ret;
}

static __maybe_unused inline limb_t mul_mod(limb_t a, limb_t b, limb_t m)
{
    dlimb_t t;
    t = (dlimb_t)a * (dlimb_t)b;
    return t % m;
}

#if defined(USE_MUL_CHECK)
static limb_t mp_mod1(const limb_t *tab, limb_t n, limb_t m, limb_t r)
{
    slimb_t i;
    dlimb_t t;

    for(i = n - 1; i >= 0; i--) {
        t = ((dlimb_t)r << LIMB_BITS) | tab[i];
        r = t % m;
    }
    return r;
}
#endif

static const uint16_t sqrt_table[192] = {
    128,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,
    143,144,144,145,146,147,148,149,150,150,151,152,153,154,155,155,
    156,157,158,159,160,160,161,162,163,163,164,165,166,167,167,168,
    169,170,170,171,172,173,173,174,175,176,176,177,178,178,179,180,
    181,181,182,183,183,184,185,185,186,187,187,188,189,189,190,191,
    192,192,193,193,194,195,195,196,197,197,198,199,199,200,201,201,
    202,203,203,204,204,205,206,206,207,208,208,209,209,210,211,211,
    212,212,213,214,214,215,215,216,217,217,218,218,219,219,220,221,
    221,222,222,223,224,224,225,225,226,226,227,227,228,229,229,230,
    230,231,231,232,232,233,234,234,235,235,236,236,237,237,238,238,
    239,240,240,241,241,242,242,243,243,244,244,245,245,246,246,247,
    247,248,248,249,249,250,250,251,251,252,252,253,253,254,254,255,
};
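
/* The table is indexed by the top 8 bits of a normalized argument
   (values 64..255, i.e. the input scaled into [0.25, 1)), hence the
   '- 64' in mp_sqrtrem1() below. */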

/* a >= 2^(LIMB_BITS - 2). Return (s, r) with s=floor(sqrt(a)) and
   r=a-s^2. 0 <= r <= 2 * s */
static limb_t mp_sqrtrem1(limb_t *pr, limb_t a)
{
    limb_t s1, r1, s, r, q, u, num;

    /* use a table for the 16 -> 8 bit sqrt */
    s1 = sqrt_table[(a >> (LIMB_BITS - 8)) - 64];
    r1 = (a >> (LIMB_BITS - 16)) - s1 * s1;
    if (r1 > 2 * s1) {
        r1 -= 2 * s1 + 1;
        s1++;
    }

    /* one iteration to get a 32 -> 16 bit sqrt */
    num = (r1 << 8) | ((a >> (LIMB_BITS - 32 + 8)) & 0xff);
    q = num / (2 * s1); /* q <= 2^8 */
    u = num % (2 * s1);
    s = (s1 << 8) + q;
    r = (u << 8) | ((a >> (LIMB_BITS - 32)) & 0xff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        s--;
        r += 2 * s + 1;
    }

#if LIMB_BITS == 64
    s1 = s;
    r1 = r;
    /* one more iteration for 64 -> 32 bit sqrt */
    num = (r1 << 16) | ((a >> (LIMB_BITS - 64 + 16)) & 0xffff);
    q = num / (2 * s1); /* q <= 2^16 */
    u = num % (2 * s1);
    s = (s1 << 16) + q;
    r = (u << 16) | ((a >> (LIMB_BITS - 64)) & 0xffff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        s--;
        r += 2 * s + 1;
    }
#endif
    *pr = r;
    return s;
}
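
/* Example (illustrative): for a = 2^(LIMB_BITS - 2), mp_sqrtrem1()
   returns s = 2^(LIMB_BITS/2 - 1) with remainder r = 0, since s * s
   equals a exactly. */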

/* return floor(sqrt(a)) */
limb_t bf_isqrt(limb_t a)
{
    limb_t s, r;
    int k;

    if (a == 0)
        return 0;
    k = clz(a) & ~1;
    s = mp_sqrtrem1(&r, a << k);
    s >>= (k >> 1);
    return s;
}
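
/* Example (illustrative): bf_isqrt(10) = 3 and bf_isqrt(16) = 4. */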

static limb_t mp_sqrtrem2(limb_t *tabs, limb_t *taba)
{
    limb_t s1, r1, s, q, u, a0, a1;
    dlimb_t r, num;
    int l;

    a0 = taba[0];
    a1 = taba[1];
    s1 = mp_sqrtrem1(&r1, a1);
    l = LIMB_BITS / 2;
    num = ((dlimb_t)r1 << l) | (a0 >> l);
    q = num / (2 * s1);
    u = num % (2 * s1);
    s = (s1 << l) + q;
    r = ((dlimb_t)u << l) | (a0 & (((limb_t)1 << l) - 1));
    if (unlikely((q >> l) != 0))
        r -= (dlimb_t)1 << LIMB_BITS; /* special case when q=2^l */
    else
        r -= q * q;
    if ((slimb_t)(r >> LIMB_BITS) < 0) {
        s--;
        r += 2 * (dlimb_t)s + 1;
    }
    tabs[0] = s;
    taba[0] = r;
    return r >> LIMB_BITS;
}

//#define DEBUG_SQRTREM

/* tmp_buf must contain (n / 2 + 1) limbs. *prh contains the highest
   limb of the remainder. */
static int mp_sqrtrem_rec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n,
                          limb_t *tmp_buf, limb_t *prh)
{
    limb_t l, h, rh, ql, qh, c, i;

    if (n == 1) {
        *prh = mp_sqrtrem2(tabs, taba);
        return 0;
    }
#ifdef DEBUG_SQRTREM
    mp_print_str("a", taba, 2 * n);
#endif
    l = n / 2;
    h = n - l;
    if (mp_sqrtrem_rec(s, tabs + l, taba + 2 * l, h, tmp_buf, &qh))
        return -1;
#ifdef DEBUG_SQRTREM
    mp_print_str("s1", tabs + l, h);
    mp_print_str_h("r1", taba + 2 * l, h, qh);
    mp_print_str_h("r2", taba + l, n, qh);
#endif

    /* the remainder is in taba + 2 * l. Its high bit is in qh */
    if (qh) {
        mp_sub(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
    }
    /* instead of dividing by 2*s, divide by s (which is normalized)
       and update q and r */
    if (mp_divnorm(s, tmp_buf, taba + l, n, tabs + l, h))
        return -1;
    qh += tmp_buf[l];
    for(i = 0; i < l; i++)
        tabs[i] = tmp_buf[i];
    ql = mp_shr(tabs, tabs, l, 1, qh & 1);
    qh = qh >> 1; /* 0 or 1 */
    if (ql)
        rh = mp_add(taba + l, taba + l, tabs + l, h, 0);
    else
        rh = 0;
#ifdef DEBUG_SQRTREM
    mp_print_str_h("q", tabs, l, qh);
    mp_print_str_h("u", taba + l, h, rh);
#endif

    mp_add_ui(tabs + l, qh, h);
#ifdef DEBUG_SQRTREM
    mp_print_str("s2", tabs, n);
#endif

    /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
    /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
    if (qh) {
        c = qh;
    } else {
        if (mp_mul(s, taba + n, tabs, l, tabs, l))
            return -1;
        c = mp_sub(taba, taba, taba + n, 2 * l, 0);
    }
    rh -= mp_sub_ui(taba + 2 * l, c, n - 2 * l);
    if ((slimb_t)rh < 0) {
        mp_sub_ui(tabs, 1, n);
        rh += mp_add_mul1(taba, tabs, n, 2);
        rh += mp_add_ui(taba, 1, n);
    }
    *prh = rh;
    return 0;
}

/* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= 2 ^ (LIMB_BITS
   - 2). Return (s, r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2
   * s. tabs has n limbs. r is returned in the lower n limbs of
   taba. r[n] is the returned value of the function. */
/* Algorithm from the article "Karatsuba Square Root" by Paul Zimmermann
   and inspired by its GMP implementation */
mp_sqrtrem(bf_context_t * s,limb_t * tabs,limb_t * taba,limb_t n)2086 int mp_sqrtrem(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
2087 {
2088 limb_t tmp_buf1[8];
2089 limb_t *tmp_buf;
2090 mp_size_t n2;
2091 int ret;
2092 n2 = n / 2 + 1;
2093 if (n2 <= countof(tmp_buf1)) {
2094 tmp_buf = tmp_buf1;
2095 } else {
2096 tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
2097 if (!tmp_buf)
2098 return -1;
2099 }
2100 ret = mp_sqrtrem_rec(s, tabs, taba, n, tmp_buf, taba + n);
2101 if (tmp_buf != tmp_buf1)
2102 bf_free(s, tmp_buf);
2103 return ret;
2104 }
2105
2106 /* Integer square root with remainder. 'a' must be an integer. r =
2107 floor(sqrt(a)) and rem = a - r^2. BF_ST_INEXACT is set if the result
2108 is inexact. 'rem' can be NULL if the remainder is not needed. */
2109 int bf_sqrtrem(bf_t *r, bf_t *rem1, const bf_t *a)
2110 {
2111 int ret;
2112
2113 if (a->len == 0) {
2114 if (a->expn == BF_EXP_NAN) {
2115 bf_set_nan(r);
2116 } else if (a->expn == BF_EXP_INF && a->sign) {
2117 goto invalid_op;
2118 } else {
2119 bf_set(r, a);
2120 }
2121 if (rem1)
2122 bf_set_ui(rem1, 0);
2123 ret = 0;
2124 } else if (a->sign) {
2125 invalid_op:
2126 bf_set_nan(r);
2127 if (rem1)
2128 bf_set_ui(rem1, 0);
2129 ret = BF_ST_INVALID_OP;
2130 } else {
2131 bf_t rem_s, *rem;
2132
2133 bf_sqrt(r, a, (a->expn + 1) / 2, BF_RNDZ);
2134 bf_rint(r, BF_RNDZ);
2135 /* see if the result is exact by computing the remainder */
2136 if (rem1) {
2137 rem = rem1;
2138 } else {
2139 rem = &rem_s;
2140 bf_init(r->ctx, rem);
2141 }
2142 /* XXX: could avoid recomputing the remainder */
2143 bf_mul(rem, r, r, BF_PREC_INF, BF_RNDZ);
2144 bf_neg(rem);
2145 bf_add(rem, rem, a, BF_PREC_INF, BF_RNDZ);
2146 if (bf_is_nan(rem)) {
2147 ret = BF_ST_MEM_ERROR;
2148 goto done;
2149 }
2150 if (rem->len != 0) {
2151 ret = BF_ST_INEXACT;
2152 } else {
2153 ret = 0;
2154 }
2155 done:
2156 if (!rem1)
2157 bf_delete(rem);
2158 }
2159 return ret;
2160 }
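
/* A minimal usage sketch for bf_sqrtrem() (illustrative only; 'my_realloc'
   stands for any caller-provided bf_realloc_func_t):

     bf_context_t ctx;
     bf_t x, s, rem;
     bf_context_init(&ctx, my_realloc, NULL);
     bf_init(&ctx, &x);
     bf_init(&ctx, &s);
     bf_init(&ctx, &rem);
     bf_set_ui(&x, 10);
     bf_sqrtrem(&s, &rem, &x);  // s = 3, rem = 1, returns BF_ST_INEXACT
     bf_delete(&x); bf_delete(&s); bf_delete(&rem);
*/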
2161
2162 int bf_sqrt(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
2163 {
2164 bf_context_t *s = a->ctx;
2165 int ret;
2166
2167 assert(r != a);
2168
2169 if (a->len == 0) {
2170 if (a->expn == BF_EXP_NAN) {
2171 bf_set_nan(r);
2172 } else if (a->expn == BF_EXP_INF && a->sign) {
2173 goto invalid_op;
2174 } else {
2175 bf_set(r, a);
2176 }
2177 ret = 0;
2178 } else if (a->sign) {
2179 invalid_op:
2180 bf_set_nan(r);
2181 ret = BF_ST_INVALID_OP;
2182 } else {
2183 limb_t *a1;
2184 slimb_t n, n1;
2185 limb_t res;
2186
2187 /* convert the mantissa to an integer with at least 2 *
2188 prec + 4 bits */
2189 n = (2 * (prec + 2) + 2 * LIMB_BITS - 1) / (2 * LIMB_BITS);
2190 if (bf_resize(r, n))
2191 goto fail;
2192 a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
2193 if (!a1)
2194 goto fail;
2195 n1 = bf_min(2 * n, a->len);
2196 memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
2197 memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
2198 if (a->expn & 1) {
2199 res = mp_shr(a1, a1, 2 * n, 1, 0);
2200 } else {
2201 res = 0;
2202 }
2203 if (mp_sqrtrem(s, r->tab, a1, n)) {
2204 bf_free(s, a1);
2205 goto fail;
2206 }
2207 if (!res) {
2208 res = mp_scan_nz(a1, n + 1);
2209 }
2210 bf_free(s, a1);
2211 if (!res) {
2212 res = mp_scan_nz(a->tab, a->len - n1);
2213 }
2214 if (res != 0)
2215 r->tab[0] |= 1;
2216 r->sign = 0;
2217 r->expn = (a->expn + 1) >> 1;
2218 ret = bf_round(r, prec, flags);
2219 }
2220 return ret;
2221 fail:
2222 bf_set_nan(r);
2223 return BF_ST_MEM_ERROR;
2224 }
2225
2226 static no_inline int bf_op2(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2227 bf_flags_t flags, bf_op2_func_t *func)
2228 {
2229 bf_t tmp;
2230 int ret;
2231
2232 if (r == a || r == b) {
2233 bf_init(r->ctx, &tmp);
2234 ret = func(&tmp, a, b, prec, flags);
2235 bf_move(r, &tmp);
2236 } else {
2237 ret = func(r, a, b, prec, flags);
2238 }
2239 return ret;
2240 }
2241
2242 int bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2243 bf_flags_t flags)
2244 {
2245 return bf_op2(r, a, b, prec, flags, __bf_add);
2246 }
2247
2248 int bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2249 bf_flags_t flags)
2250 {
2251 return bf_op2(r, a, b, prec, flags, __bf_sub);
2252 }
2253
2254 int bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2255 bf_flags_t flags)
2256 {
2257 return bf_op2(r, a, b, prec, flags, __bf_div);
2258 }
2259
2260 int bf_mul_ui(bf_t *r, const bf_t *a, uint64_t b1, limb_t prec,
2261 bf_flags_t flags)
2262 {
2263 bf_t b;
2264 int ret;
2265 bf_init(r->ctx, &b);
2266 ret = bf_set_ui(&b, b1);
2267 ret |= bf_mul(r, a, &b, prec, flags);
2268 bf_delete(&b);
2269 return ret;
2270 }
2271
2272 int bf_mul_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
2273 bf_flags_t flags)
2274 {
2275 bf_t b;
2276 int ret;
2277 bf_init(r->ctx, &b);
2278 ret = bf_set_si(&b, b1);
2279 ret |= bf_mul(r, a, &b, prec, flags);
2280 bf_delete(&b);
2281 return ret;
2282 }
2283
2284 int bf_add_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
2285 bf_flags_t flags)
2286 {
2287 bf_t b;
2288 int ret;
2289
2290 bf_init(r->ctx, &b);
2291 ret = bf_set_si(&b, b1);
2292 ret |= bf_add(r, a, &b, prec, flags);
2293 bf_delete(&b);
2294 return ret;
2295 }
2296
2297 static int bf_pow_ui(bf_t *r, const bf_t *a, limb_t b, limb_t prec,
2298 bf_flags_t flags)
2299 {
2300 int ret, n_bits, i;
2301
2302 assert(r != a);
2303 if (b == 0)
2304 return bf_set_ui(r, 1);
2305 ret = bf_set(r, a);
2306 n_bits = LIMB_BITS - clz(b);
2307 for(i = n_bits - 2; i >= 0; i--) {
2308 ret |= bf_mul(r, r, r, prec, flags);
2309 if ((b >> i) & 1)
2310 ret |= bf_mul(r, r, a, prec, flags);
2311 }
2312 return ret;
2313 }
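
/* bf_pow_ui() is plain left-to-right binary exponentiation. For example,
   b = 13 = 0b1101: starting from r = a (top bit), scanning the remaining
   bits 101 gives square -> a^2, multiply -> a^3, square -> a^6,
   square -> a^12, multiply -> a^13, i.e. floor(log2(b)) squarings plus
   one multiplication per set bit below the leading one. */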
2314
2315 static int bf_pow_ui_ui(bf_t *r, limb_t a1, limb_t b,
2316 limb_t prec, bf_flags_t flags)
2317 {
2318 bf_t a;
2319 int ret;
2320
2321 if (a1 == 10 && b <= LIMB_DIGITS) {
2322 /* use precomputed powers. We do not round at this point
2323 because we expect the caller to do it */
2324 ret = bf_set_ui(r, mp_pow_dec[b]);
2325 } else {
2326 bf_init(r->ctx, &a);
2327 ret = bf_set_ui(&a, a1);
2328 ret |= bf_pow_ui(r, &a, b, prec, flags);
2329 bf_delete(&a);
2330 }
2331 return ret;
2332 }
2333
2334 /* convert to integer (infinite precision) */
2335 int bf_rint(bf_t *r, int rnd_mode)
2336 {
2337 return bf_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
2338 }
2339
2340 /* logical operations */
2341 #define BF_LOGIC_OR 0
2342 #define BF_LOGIC_XOR 1
2343 #define BF_LOGIC_AND 2
2344
2345 static inline limb_t bf_logic_op1(limb_t a, limb_t b, int op)
2346 {
2347 switch(op) {
2348 case BF_LOGIC_OR:
2349 return a | b;
2350 case BF_LOGIC_XOR:
2351 return a ^ b;
2352 default:
2353 case BF_LOGIC_AND:
2354 return a & b;
2355 }
2356 }
2357
2358 static int bf_logic_op(bf_t *r, const bf_t *a1, const bf_t *b1, int op)
2359 {
2360 bf_t b1_s, a1_s, *a, *b;
2361 limb_t a_sign, b_sign, r_sign;
2362 slimb_t l, i, a_bit_offset, b_bit_offset;
2363 limb_t v1, v2, v1_mask, v2_mask, r_mask;
2364 int ret;
2365
2366 assert(r != a1 && r != b1);
2367
2368 if (a1->expn <= 0)
2369 a_sign = 0; /* minus zero is considered positive */
2370 else
2371 a_sign = a1->sign;
2372
2373 if (b1->expn <= 0)
2374 b_sign = 0; /* minus zero is considered positive */
2375 else
2376 b_sign = b1->sign;
2377
2378 if (a_sign) {
2379 a = &a1_s;
2380 bf_init(r->ctx, a);
2381 if (bf_add_si(a, a1, 1, BF_PREC_INF, BF_RNDZ)) {
2382 b = NULL;
2383 goto fail;
2384 }
2385 } else {
2386 a = (bf_t *)a1;
2387 }
2388
2389 if (b_sign) {
2390 b = &b1_s;
2391 bf_init(r->ctx, b);
2392 if (bf_add_si(b, b1, 1, BF_PREC_INF, BF_RNDZ))
2393 goto fail;
2394 } else {
2395 b = (bf_t *)b1;
2396 }
2397
2398 r_sign = bf_logic_op1(a_sign, b_sign, op);
2399 if (op == BF_LOGIC_AND && r_sign == 0) {
2400 /* no need to compute extra zeros for AND */
2401 if (a_sign == 0 && b_sign == 0)
2402 l = bf_min(a->expn, b->expn);
2403 else if (a_sign == 0)
2404 l = a->expn;
2405 else
2406 l = b->expn;
2407 } else {
2408 l = bf_max(a->expn, b->expn);
2409 }
2410 /* Note: a or b can be zero */
2411 l = (bf_max(l, 1) + LIMB_BITS - 1) / LIMB_BITS;
2412 if (bf_resize(r, l))
2413 goto fail;
2414 a_bit_offset = a->len * LIMB_BITS - a->expn;
2415 b_bit_offset = b->len * LIMB_BITS - b->expn;
2416 v1_mask = -a_sign;
2417 v2_mask = -b_sign;
2418 r_mask = -r_sign;
2419 for(i = 0; i < l; i++) {
2420 v1 = get_bits(a->tab, a->len, a_bit_offset + i * LIMB_BITS) ^ v1_mask;
2421 v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS) ^ v2_mask;
2422 r->tab[i] = bf_logic_op1(v1, v2, op) ^ r_mask;
2423 }
2424 r->expn = l * LIMB_BITS;
2425 r->sign = r_sign;
2426 bf_normalize_and_round(r, BF_PREC_INF, BF_RNDZ); /* cannot fail */
2427 if (r_sign) {
2428 if (bf_add_si(r, r, -1, BF_PREC_INF, BF_RNDZ))
2429 goto fail;
2430 }
2431 ret = 0;
2432 done:
2433 if (a == &a1_s)
2434 bf_delete(a);
2435 if (b == &b1_s)
2436 bf_delete(b);
2437 return ret;
2438 fail:
2439 bf_set_nan(r);
2440 ret = BF_ST_MEM_ERROR;
2441 goto done;
2442 }
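
/* Negative operands are emulated with the two's complement identity
   ~x = -(x + 1): bf_add_si(.., 1) plus the v1_mask/v2_mask XORs above
   complement the limbs on the fly, and the final bf_add_si(r, r, -1)
   converts a negative result back. Worked example for AND:
   -3 -> ...11101 (since -3 = -(2 + 1) and ~2 = ...11101), so
   -3 & 5 = ...11101 & 000101 = 000101 = 5. */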
2443
2444 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2445 int bf_logic_or(bf_t *r, const bf_t *a, const bf_t *b)
2446 {
2447 return bf_logic_op(r, a, b, BF_LOGIC_OR);
2448 }
2449
2450 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2451 int bf_logic_xor(bf_t *r, const bf_t *a, const bf_t *b)
2452 {
2453 return bf_logic_op(r, a, b, BF_LOGIC_XOR);
2454 }
2455
2456 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2457 int bf_logic_and(bf_t *r, const bf_t *a, const bf_t *b)
2458 {
2459 return bf_logic_op(r, a, b, BF_LOGIC_AND);
2460 }
2461
2462 /* conversion between fixed size types */
2463
2464 typedef union {
2465 double d;
2466 uint64_t u;
2467 } Float64Union;
2468
2469 int bf_get_float64(const bf_t *a, double *pres, bf_rnd_t rnd_mode)
2470 {
2471 Float64Union u;
2472 int e, ret;
2473 uint64_t m;
2474
2475 ret = 0;
2476 if (a->expn == BF_EXP_NAN) {
2477 u.u = 0x7ff8000000000000; /* quiet nan */
2478 } else {
2479 bf_t b_s, *b = &b_s;
2480
2481 bf_init(a->ctx, b);
2482 bf_set(b, a);
2483 if (bf_is_finite(b)) {
2484 ret = bf_round(b, 53, rnd_mode | BF_FLAG_SUBNORMAL | bf_set_exp_bits(11));
2485 }
2486 if (b->expn == BF_EXP_INF) {
2487 e = (1 << 11) - 1;
2488 m = 0;
2489 } else if (b->expn == BF_EXP_ZERO) {
2490 e = 0;
2491 m = 0;
2492 } else {
2493 e = b->expn + 1023 - 1;
2494 #if LIMB_BITS == 32
2495 if (b->len == 2) {
2496 m = ((uint64_t)b->tab[1] << 32) | b->tab[0];
2497 } else {
2498 m = ((uint64_t)b->tab[0] << 32);
2499 }
2500 #else
2501 m = b->tab[0];
2502 #endif
2503 if (e <= 0) {
2504 /* subnormal */
2505 m = m >> (12 - e);
2506 e = 0;
2507 } else {
2508 m = (m << 1) >> 12;
2509 }
2510 }
2511 u.u = m | ((uint64_t)e << 52) | ((uint64_t)b->sign << 63);
2512 bf_delete(b);
2513 }
2514 *pres = u.d;
2515 return ret;
2516 }
2517
2518 int bf_set_float64(bf_t *a, double d)
2519 {
2520 Float64Union u;
2521 uint64_t m;
2522 int shift, e, sgn;
2523
2524 u.d = d;
2525 sgn = u.u >> 63;
2526 e = (u.u >> 52) & ((1 << 11) - 1);
2527 m = u.u & (((uint64_t)1 << 52) - 1);
2528 if (e == ((1 << 11) - 1)) {
2529 if (m != 0) {
2530 bf_set_nan(a);
2531 } else {
2532 bf_set_inf(a, sgn);
2533 }
2534 } else if (e == 0) {
2535 if (m == 0) {
2536 bf_set_zero(a, sgn);
2537 } else {
2538 /* subnormal number */
2539 m <<= 12;
2540 shift = clz64(m);
2541 m <<= shift;
2542 e = -shift;
2543 goto norm;
2544 }
2545 } else {
2546 m = (m << 11) | ((uint64_t)1 << 63);
2547 norm:
2548 a->expn = e - 1023 + 1;
2549 #if LIMB_BITS == 32
2550 if (bf_resize(a, 2))
2551 goto fail;
2552 a->tab[0] = m;
2553 a->tab[1] = m >> 32;
2554 #else
2555 if (bf_resize(a, 1))
2556 goto fail;
2557 a->tab[0] = m;
2558 #endif
2559 a->sign = sgn;
2560 }
2561 return 0;
2562 fail:
2563 bf_set_nan(a);
2564 return BF_ST_MEM_ERROR;
2565 }
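
/* Illustrative round trip (using 'ctx' and conventions from the sketch
   after bf_sqrtrem() above): any finite double converts exactly, since
   its 53-bit significand fits in one or two limbs:

     bf_t x;
     double d2;
     bf_init(&ctx, &x);
     bf_set_float64(&x, 0.1);            // exact copy of the double 0.1
     bf_get_float64(&x, &d2, BF_RNDN);   // d2 == 0.1, returns 0 (exact)
     bf_delete(&x);

   The "+ 1023 - 1" / "- 1023 + 1" exponent offsets come from libbf's
   convention of a mantissa in [0.5, 1), versus IEEE-754's [1, 2). */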
2566
2567 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
2568 is an overflow and 0 otherwise. */
2569 int bf_get_int32(int *pres, const bf_t *a, int flags)
2570 {
2571 uint32_t v;
2572 int ret;
2573 if (a->expn >= BF_EXP_INF) {
2574 ret = 0;
2575 if (flags & BF_GET_INT_MOD) {
2576 v = 0;
2577 } else if (a->expn == BF_EXP_INF) {
2578 v = (uint32_t)INT32_MAX + a->sign;
2579 /* XXX: return overflow ? */
2580 } else {
2581 v = INT32_MAX;
2582 }
2583 } else if (a->expn <= 0) {
2584 v = 0;
2585 ret = 0;
2586 } else if (a->expn <= 31) {
2587 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2588 if (a->sign)
2589 v = -v;
2590 ret = 0;
2591 } else if (!(flags & BF_GET_INT_MOD)) {
2592 ret = BF_ST_OVERFLOW;
2593 if (a->sign) {
2594 v = (uint32_t)INT32_MAX + 1;
2595 if (a->expn == 32 &&
2596 (a->tab[a->len - 1] >> (LIMB_BITS - 32)) == v) {
2597 ret = 0;
2598 }
2599 } else {
2600 v = INT32_MAX;
2601 }
2602 } else {
2603 v = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
2604 if (a->sign)
2605 v = -v;
2606 ret = 0;
2607 }
2608 *pres = v;
2609 return ret;
2610 }
2611
2612 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
2613 is an overflow and 0 otherwise. */
2614 int bf_get_int64(int64_t *pres, const bf_t *a, int flags)
2615 {
2616 uint64_t v;
2617 int ret;
2618 if (a->expn >= BF_EXP_INF) {
2619 ret = 0;
2620 if (flags & BF_GET_INT_MOD) {
2621 v = 0;
2622 } else if (a->expn == BF_EXP_INF) {
2623 v = (uint64_t)INT64_MAX + a->sign;
2624 } else {
2625 v = INT64_MAX;
2626 }
2627 } else if (a->expn <= 0) {
2628 v = 0;
2629 ret = 0;
2630 } else if (a->expn <= 63) {
2631 #if LIMB_BITS == 32
2632 if (a->expn <= 32)
2633 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2634 else
2635 v = (((uint64_t)a->tab[a->len - 1] << 32) |
2636 get_limbz(a, a->len - 2)) >> (64 - a->expn);
2637 #else
2638 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2639 #endif
2640 if (a->sign)
2641 v = -v;
2642 ret = 0;
2643 } else if (!(flags & BF_GET_INT_MOD)) {
2644 ret = BF_ST_OVERFLOW;
2645 if (a->sign) {
2646 uint64_t v1;
2647 v = (uint64_t)INT64_MAX + 1;
2648 if (a->expn == 64) {
2649 v1 = a->tab[a->len - 1];
2650 #if LIMB_BITS == 32
2651 v1 = (v1 << 32) | get_limbz(a, a->len - 2);
2652 #endif
2653 if (v1 == v)
2654 ret = 0;
2655 }
2656 } else {
2657 v = INT64_MAX;
2658 }
2659 } else {
2660 slimb_t bit_pos = a->len * LIMB_BITS - a->expn;
2661 v = get_bits(a->tab, a->len, bit_pos);
2662 #if LIMB_BITS == 32
2663 v |= (uint64_t)get_bits(a->tab, a->len, bit_pos + 32) << 32;
2664 #endif
2665 if (a->sign)
2666 v = -v;
2667 ret = 0;
2668 }
2669 *pres = v;
2670 return ret;
2671 }
2672
2673 /* base conversion from radix */
2674
2675 static const uint8_t digits_per_limb_table[BF_RADIX_MAX - 1] = {
2676 #if LIMB_BITS == 32
2677 32,20,16,13,12,11,10,10, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
2678 #else
2679 64,40,32,27,24,22,21,20,19,18,17,17,16,16,16,15,15,15,14,14,14,14,13,13,13,13,13,13,13,12,12,12,12,12,12,
2680 #endif
2681 };
2682
2683 static limb_t get_limb_radix(int radix)
2684 {
2685 int i, k;
2686 limb_t radixl;
2687
2688 k = digits_per_limb_table[radix - 2];
2689 radixl = radix;
2690 for(i = 1; i < k; i++)
2691 radixl *= radix;
2692 return radixl;
2693 }
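
/* Example: with 64-bit limbs, digits_per_limb_table[10 - 2] = 19, so
   get_limb_radix(10) = 10^19 -- the largest power of ten that fits in a
   limb. A non power-of-two base is thus handled 19 decimal digits per
   limb. */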
2694
2695 /* return != 0 if error */
2696 static int bf_integer_from_radix_rec(bf_t *r, const limb_t *tab,
2697 limb_t n, int level, limb_t n0,
2698 limb_t radix, bf_t *pow_tab)
2699 {
2700 int ret;
2701 if (n == 1) {
2702 ret = bf_set_ui(r, tab[0]);
2703 } else {
2704 bf_t T_s, *T = &T_s, *B;
2705 limb_t n1, n2;
2706
2707 n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
2708 n1 = n - n2;
2709 // printf("level=%d n0=%ld n1=%ld n2=%ld\n", level, n0, n1, n2);
2710 B = &pow_tab[level];
2711 if (B->len == 0) {
2712 ret = bf_pow_ui_ui(B, radix, n2, BF_PREC_INF, BF_RNDZ);
2713 if (ret)
2714 return ret;
2715 }
2716 ret = bf_integer_from_radix_rec(r, tab + n2, n1, level + 1, n0,
2717 radix, pow_tab);
2718 if (ret)
2719 return ret;
2720 ret = bf_mul(r, r, B, BF_PREC_INF, BF_RNDZ);
2721 if (ret)
2722 return ret;
2723 bf_init(r->ctx, T);
2724 ret = bf_integer_from_radix_rec(T, tab, n2, level + 1, n0,
2725 radix, pow_tab);
2726 if (!ret)
2727 ret = bf_add(r, r, T, BF_PREC_INF, BF_RNDZ);
2728 bf_delete(T);
2729 }
2730 return ret;
2731 // bf_print_str(" r=", r);
2732 }
2733
2734 /* return 0 if OK, != 0 if memory error */
2735 static int bf_integer_from_radix(bf_t *r, const limb_t *tab,
2736 limb_t n, limb_t radix)
2737 {
2738 bf_context_t *s = r->ctx;
2739 int pow_tab_len, i, ret;
2740 limb_t radixl;
2741 bf_t *pow_tab;
2742
2743 radixl = get_limb_radix(radix);
2744 pow_tab_len = ceil_log2(n) + 2; /* XXX: check */
2745 pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
2746 if (!pow_tab)
2747 return -1;
2748 for(i = 0; i < pow_tab_len; i++)
2749 bf_init(r->ctx, &pow_tab[i]);
2750 ret = bf_integer_from_radix_rec(r, tab, n, 0, n, radixl, pow_tab);
2751 for(i = 0; i < pow_tab_len; i++) {
2752 bf_delete(&pow_tab[i]);
2753 }
2754 bf_free(s, pow_tab);
2755 return ret;
2756 }
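
/* The recursion above is the classic divide-and-conquer base conversion:
   the n superdigits (base radixl) are split as hi * radixl^n2 + lo, each
   half is converted recursively and the halves are recombined with one
   big multiplication and one addition, for O(M(n) log n) total work
   instead of the O(n^2) of the naive multiply-and-add loop. pow_tab
   caches radixl^n2, which is the same for every node of a given
   recursion level. */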
2757
2758 /* compute and round T * radix^expn. */
2759 int bf_mul_pow_radix(bf_t *r, const bf_t *T, limb_t radix,
2760 slimb_t expn, limb_t prec, bf_flags_t flags)
2761 {
2762 int ret, expn_sign, overflow;
2763 slimb_t e, extra_bits, prec1, ziv_extra_bits;
2764 bf_t B_s, *B = &B_s;
2765
2766 if (T->len == 0) {
2767 return bf_set(r, T);
2768 } else if (expn == 0) {
2769 ret = bf_set(r, T);
2770 ret |= bf_round(r, prec, flags);
2771 return ret;
2772 }
2773
2774 e = expn;
2775 expn_sign = 0;
2776 if (e < 0) {
2777 e = -e;
2778 expn_sign = 1;
2779 }
2780 bf_init(r->ctx, B);
2781 if (prec == BF_PREC_INF) {
2782 /* infinite precision: only used if the result is known to be exact */
2783 ret = bf_pow_ui_ui(B, radix, e, BF_PREC_INF, BF_RNDN);
2784 if (expn_sign) {
2785 ret |= bf_div(r, T, B, T->len * LIMB_BITS, BF_RNDN);
2786 } else {
2787 ret |= bf_mul(r, T, B, BF_PREC_INF, BF_RNDN);
2788 }
2789 } else {
2790 ziv_extra_bits = 16;
2791 for(;;) {
2792 prec1 = prec + ziv_extra_bits;
2793 /* XXX: correct overflow/underflow handling */
2794 /* XXX: rigorous error analysis needed */
2795 extra_bits = ceil_log2(e) * 2 + 1;
2796 ret = bf_pow_ui_ui(B, radix, e, prec1 + extra_bits, BF_RNDN);
2797 overflow = !bf_is_finite(B);
2798 /* XXX: if bf_pow_ui_ui returns an exact result, can stop
2799 after the next operation */
2800 if (expn_sign)
2801 ret |= bf_div(r, T, B, prec1 + extra_bits, BF_RNDN);
2802 else
2803 ret |= bf_mul(r, T, B, prec1 + extra_bits, BF_RNDN);
2804 if (ret & BF_ST_MEM_ERROR)
2805 break;
2806 if ((ret & BF_ST_INEXACT) &&
2807 !bf_can_round(r, prec, flags & BF_RND_MASK, prec1) &&
2808 !overflow) {
2809 /* add more precision and retry */
2810 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
2811 } else {
2812 ret = bf_round(r, prec, flags) | (ret & BF_ST_INEXACT);
2813 break;
2814 }
2815 }
2816 }
2817 bf_delete(B);
2818 return ret;
2819 }
2820
2821 static inline int to_digit(int c)
2822 {
2823 if (c >= '0' && c <= '9')
2824 return c - '0';
2825 else if (c >= 'A' && c <= 'Z')
2826 return c - 'A' + 10;
2827 else if (c >= 'a' && c <= 'z')
2828 return c - 'a' + 10;
2829 else
2830 return 36;
2831 }
2832
2833 /* add a limb at 'pos' and decrement pos. New space is created if
2834 needed. Return 0 if OK, -1 if memory error */
2835 static int bf_add_limb(bf_t *a, slimb_t *ppos, limb_t v)
2836 {
2837 slimb_t pos;
2838 pos = *ppos;
2839 if (unlikely(pos < 0)) {
2840 limb_t new_size, d, *new_tab;
2841 new_size = bf_max(a->len + 1, a->len * 3 / 2);
2842 new_tab = bf_realloc(a->ctx, a->tab, sizeof(limb_t) * new_size);
2843 if (!new_tab)
2844 return -1;
2845 a->tab = new_tab;
2846 d = new_size - a->len;
2847 memmove(a->tab + d, a->tab, a->len * sizeof(limb_t));
2848 a->len = new_size;
2849 pos += d;
2850 }
2851 a->tab[pos--] = v;
2852 *ppos = pos;
2853 return 0;
2854 }
2855
2856 static int bf_tolower(int c)
2857 {
2858 if (c >= 'A' && c <= 'Z')
2859 c = c - 'A' + 'a';
2860 return c;
2861 }
2862
2863 static int strcasestart(const char *str, const char *val, const char **ptr)
2864 {
2865 const char *p, *q;
2866 p = str;
2867 q = val;
2868 while (*q != '\0') {
2869 if (bf_tolower(*p) != *q)
2870 return 0;
2871 p++;
2872 q++;
2873 }
2874 if (ptr)
2875 *ptr = p;
2876 return 1;
2877 }
2878
2879 static int bf_atof_internal(bf_t *r, slimb_t *pexponent,
2880 const char *str, const char **pnext, int radix,
2881 limb_t prec, bf_flags_t flags, BOOL is_dec)
2882 {
2883 const char *p, *p_start;
2884 int is_neg, radix_bits, exp_is_neg, ret, digits_per_limb, shift;
2885 limb_t cur_limb;
2886 slimb_t pos, expn, int_len, digit_count;
2887 BOOL has_decpt, is_bin_exp;
2888 bf_t a_s, *a;
2889
2890 *pexponent = 0;
2891 p = str;
2892 if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
2893 strcasestart(p, "nan", &p)) {
2894 bf_set_nan(r);
2895 ret = 0;
2896 goto done;
2897 }
2898 is_neg = 0;
2899
2900 if (p[0] == '+') {
2901 p++;
2902 p_start = p;
2903 } else if (p[0] == '-') {
2904 is_neg = 1;
2905 p++;
2906 p_start = p;
2907 } else {
2908 p_start = p;
2909 }
2910 if (p[0] == '0') {
2911 if ((p[1] == 'x' || p[1] == 'X') &&
2912 (radix == 0 || radix == 16) &&
2913 !(flags & BF_ATOF_NO_HEX)) {
2914 radix = 16;
2915 p += 2;
2916 } else if ((p[1] == 'o' || p[1] == 'O') &&
2917 radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
2918 p += 2;
2919 radix = 8;
2920 } else if ((p[1] == 'b' || p[1] == 'B') &&
2921 radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
2922 p += 2;
2923 radix = 2;
2924 } else {
2925 goto no_prefix;
2926 }
2927 /* there must be a digit after the prefix */
2928 if (to_digit((uint8_t)*p) >= radix) {
2929 bf_set_nan(r);
2930 ret = 0;
2931 goto done;
2932 }
2933 no_prefix: ;
2934 } else {
2935 if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
2936 strcasestart(p, "inf", &p)) {
2937 bf_set_inf(r, is_neg);
2938 ret = 0;
2939 goto done;
2940 }
2941 }
2942
2943 if (radix == 0)
2944 radix = 10;
2945 if (is_dec) {
2946 assert(radix == 10);
2947 radix_bits = 0;
2948 a = r;
2949 } else if ((radix & (radix - 1)) != 0) {
2950 radix_bits = 0; /* base is not a power of two */
2951 a = &a_s;
2952 bf_init(r->ctx, a);
2953 } else {
2954 radix_bits = ceil_log2(radix);
2955 a = r;
2956 }
2957
2958 /* skip leading zeros */
2959 /* XXX: could also skip zeros after the decimal point */
2960 while (*p == '0')
2961 p++;
2962
2963 if (radix_bits) {
2964 shift = digits_per_limb = LIMB_BITS;
2965 } else {
2966 radix_bits = 0;
2967 shift = digits_per_limb = digits_per_limb_table[radix - 2];
2968 }
2969 cur_limb = 0;
2970 bf_resize(a, 1);
2971 pos = 0;
2972 has_decpt = FALSE;
2973 int_len = digit_count = 0;
2974 for(;;) {
2975 limb_t c;
2976 if (*p == '.' && (p > p_start || to_digit(p[1]) < radix)) {
2977 if (has_decpt)
2978 break;
2979 has_decpt = TRUE;
2980 int_len = digit_count;
2981 p++;
2982 }
2983 c = to_digit(*p);
2984 if (c >= radix)
2985 break;
2986 digit_count++;
2987 p++;
2988 if (radix_bits) {
2989 shift -= radix_bits;
2990 if (shift <= 0) {
2991 cur_limb |= c >> (-shift);
2992 if (bf_add_limb(a, &pos, cur_limb))
2993 goto mem_error;
2994 if (shift < 0)
2995 cur_limb = c << (LIMB_BITS + shift);
2996 else
2997 cur_limb = 0;
2998 shift += LIMB_BITS;
2999 } else {
3000 cur_limb |= c << shift;
3001 }
3002 } else {
3003 cur_limb = cur_limb * radix + c;
3004 shift--;
3005 if (shift == 0) {
3006 if (bf_add_limb(a, &pos, cur_limb))
3007 goto mem_error;
3008 shift = digits_per_limb;
3009 cur_limb = 0;
3010 }
3011 }
3012 }
3013 if (!has_decpt)
3014 int_len = digit_count;
3015
3016 /* add the last limb and pad with zeros */
3017 if (shift != digits_per_limb) {
3018 if (radix_bits == 0) {
3019 while (shift != 0) {
3020 cur_limb *= radix;
3021 shift--;
3022 }
3023 }
3024 if (bf_add_limb(a, &pos, cur_limb)) {
3025 mem_error:
3026 ret = BF_ST_MEM_ERROR;
3027 if (!radix_bits)
3028 bf_delete(a);
3029 bf_set_nan(r);
3030 goto done;
3031 }
3032 }
3033
3034 /* reset the next limbs to zero (we prefer to reallocate in the
3035 renormalization) */
3036 memset(a->tab, 0, (pos + 1) * sizeof(limb_t));
3037
3038 if (p == p_start) {
3039 ret = 0;
3040 if (!radix_bits)
3041 bf_delete(a);
3042 bf_set_nan(r);
3043 goto done;
3044 }
3045
3046 /* parse the exponent, if any */
3047 expn = 0;
3048 is_bin_exp = FALSE;
3049 if (((radix == 10 && (*p == 'e' || *p == 'E')) ||
3050 (radix != 10 && (*p == '@' ||
3051 (radix_bits && (*p == 'p' || *p == 'P'))))) &&
3052 p > p_start) {
3053 is_bin_exp = (*p == 'p' || *p == 'P');
3054 p++;
3055 exp_is_neg = 0;
3056 if (*p == '+') {
3057 p++;
3058 } else if (*p == '-') {
3059 exp_is_neg = 1;
3060 p++;
3061 }
3062 for(;;) {
3063 int c;
3064 c = to_digit(*p);
3065 if (c >= 10)
3066 break;
3067 if (unlikely(expn > ((EXP_MAX - 2 - 9) / 10))) {
3068 /* exponent overflow */
3069 if (exp_is_neg) {
3070 bf_set_zero(r, is_neg);
3071 ret = BF_ST_UNDERFLOW | BF_ST_INEXACT;
3072 } else {
3073 bf_set_inf(r, is_neg);
3074 ret = BF_ST_OVERFLOW | BF_ST_INEXACT;
3075 }
3076 goto done;
3077 }
3078 p++;
3079 expn = expn * 10 + c;
3080 }
3081 if (exp_is_neg)
3082 expn = -expn;
3083 }
3084 if (is_dec) {
3085 a->expn = expn + int_len;
3086 a->sign = is_neg;
3087 ret = bfdec_normalize_and_round((bfdec_t *)a, prec, flags);
3088 } else if (radix_bits) {
3089 /* XXX: may overflow */
3090 if (!is_bin_exp)
3091 expn *= radix_bits;
3092 a->expn = expn + (int_len * radix_bits);
3093 a->sign = is_neg;
3094 ret = bf_normalize_and_round(a, prec, flags);
3095 } else {
3096 limb_t l;
3097 pos++;
3098 l = a->len - pos; /* number of limbs */
3099 if (l == 0) {
3100 bf_set_zero(r, is_neg);
3101 ret = 0;
3102 } else {
3103 bf_t T_s, *T = &T_s;
3104
3105 expn -= l * digits_per_limb - int_len;
3106 bf_init(r->ctx, T);
3107 if (bf_integer_from_radix(T, a->tab + pos, l, radix)) {
3108 bf_set_nan(r);
3109 ret = BF_ST_MEM_ERROR;
3110 } else {
3111 T->sign = is_neg;
3112 if (flags & BF_ATOF_EXPONENT) {
3113 /* return the exponent */
3114 *pexponent = expn;
3115 ret = bf_set(r, T);
3116 } else {
3117 ret = bf_mul_pow_radix(r, T, radix, expn, prec, flags);
3118 }
3119 }
3120 bf_delete(T);
3121 }
3122 bf_delete(a);
3123 }
3124 done:
3125 if (pnext)
3126 *pnext = p;
3127 return ret;
3128 }
3129
3130 /*
3131 Return (status, n, exp). 'status' is the floating point status. 'n'
3132 is the parsed number.
3133
3134 If (flags & BF_ATOF_EXPONENT) and if the radix is not a power of
3135 two, the parsed number is equal to r *
3136 radix^(*pexponent). Otherwise *pexponent = 0.
3137 */
3138 int bf_atof2(bf_t *r, slimb_t *pexponent,
3139 const char *str, const char **pnext, int radix,
3140 limb_t prec, bf_flags_t flags)
3141 {
3142 return bf_atof_internal(r, pexponent, str, pnext, radix, prec, flags,
3143 FALSE);
3144 }
3145
3146 int bf_atof(bf_t *r, const char *str, const char **pnext, int radix,
3147 limb_t prec, bf_flags_t flags)
3148 {
3149 slimb_t dummy_exp;
3150 return bf_atof_internal(r, &dummy_exp, str, pnext, radix, prec, flags, FALSE);
3151 }
3152
3153 /* base conversion to radix */
3154
3155 #if LIMB_BITS == 64
3156 #define RADIXL_10 UINT64_C(10000000000000000000)
3157 #else
3158 #define RADIXL_10 UINT64_C(1000000000)
3159 #endif
3160
3161 static const uint32_t inv_log2_radix[BF_RADIX_MAX - 1][LIMB_BITS / 32 + 1] = {
3162 #if LIMB_BITS == 32
3163 { 0x80000000, 0x00000000,},
3164 { 0x50c24e60, 0xd4d4f4a7,},
3165 { 0x40000000, 0x00000000,},
3166 { 0x372068d2, 0x0a1ee5ca,},
3167 { 0x3184648d, 0xb8153e7a,},
3168 { 0x2d983275, 0x9d5369c4,},
3169 { 0x2aaaaaaa, 0xaaaaaaab,},
3170 { 0x28612730, 0x6a6a7a54,},
3171 { 0x268826a1, 0x3ef3fde6,},
3172 { 0x25001383, 0xbac8a744,},
3173 { 0x23b46706, 0x82c0c709,},
3174 { 0x229729f1, 0xb2c83ded,},
3175 { 0x219e7ffd, 0xa5ad572b,},
3176 { 0x20c33b88, 0xda7c29ab,},
3177 { 0x20000000, 0x00000000,},
3178 { 0x1f50b57e, 0xac5884b3,},
3179 { 0x1eb22cc6, 0x8aa6e26f,},
3180 { 0x1e21e118, 0x0c5daab2,},
3181 { 0x1d9dcd21, 0x439834e4,},
3182 { 0x1d244c78, 0x367a0d65,},
3183 { 0x1cb40589, 0xac173e0c,},
3184 { 0x1c4bd95b, 0xa8d72b0d,},
3185 { 0x1bead768, 0x98f8ce4c,},
3186 { 0x1b903469, 0x050f72e5,},
3187 { 0x1b3b433f, 0x2eb06f15,},
3188 { 0x1aeb6f75, 0x9c46fc38,},
3189 { 0x1aa038eb, 0x0e3bfd17,},
3190 { 0x1a593062, 0xb38d8c56,},
3191 { 0x1a15f4c3, 0x2b95a2e6,},
3192 { 0x19d630dc, 0xcc7ddef9,},
3193 { 0x19999999, 0x9999999a,},
3194 { 0x195fec80, 0x8a609431,},
3195 { 0x1928ee7b, 0x0b4f22f9,},
3196 { 0x18f46acf, 0x8c06e318,},
3197 { 0x18c23246, 0xdc0a9f3d,},
3198 #else
3199 { 0x80000000, 0x00000000, 0x00000000,},
3200 { 0x50c24e60, 0xd4d4f4a7, 0x021f57bc,},
3201 { 0x40000000, 0x00000000, 0x00000000,},
3202 { 0x372068d2, 0x0a1ee5ca, 0x19ea911b,},
3203 { 0x3184648d, 0xb8153e7a, 0x7fc2d2e1,},
3204 { 0x2d983275, 0x9d5369c4, 0x4dec1661,},
3205 { 0x2aaaaaaa, 0xaaaaaaaa, 0xaaaaaaab,},
3206 { 0x28612730, 0x6a6a7a53, 0x810fabde,},
3207 { 0x268826a1, 0x3ef3fde6, 0x23e2566b,},
3208 { 0x25001383, 0xbac8a744, 0x385a3349,},
3209 { 0x23b46706, 0x82c0c709, 0x3f891718,},
3210 { 0x229729f1, 0xb2c83ded, 0x15fba800,},
3211 { 0x219e7ffd, 0xa5ad572a, 0xe169744b,},
3212 { 0x20c33b88, 0xda7c29aa, 0x9bddee52,},
3213 { 0x20000000, 0x00000000, 0x00000000,},
3214 { 0x1f50b57e, 0xac5884b3, 0x70e28eee,},
3215 { 0x1eb22cc6, 0x8aa6e26f, 0x06d1a2a2,},
3216 { 0x1e21e118, 0x0c5daab1, 0x81b4f4bf,},
3217 { 0x1d9dcd21, 0x439834e3, 0x81667575,},
3218 { 0x1d244c78, 0x367a0d64, 0xc8204d6d,},
3219 { 0x1cb40589, 0xac173e0c, 0x3b7b16ba,},
3220 { 0x1c4bd95b, 0xa8d72b0d, 0x5879f25a,},
3221 { 0x1bead768, 0x98f8ce4c, 0x66cc2858,},
3222 { 0x1b903469, 0x050f72e5, 0x0cf5488e,},
3223 { 0x1b3b433f, 0x2eb06f14, 0x8c89719c,},
3224 { 0x1aeb6f75, 0x9c46fc37, 0xab5fc7e9,},
3225 { 0x1aa038eb, 0x0e3bfd17, 0x1bd62080,},
3226 { 0x1a593062, 0xb38d8c56, 0x7998ab45,},
3227 { 0x1a15f4c3, 0x2b95a2e6, 0x46aed6a0,},
3228 { 0x19d630dc, 0xcc7ddef9, 0x5aadd61b,},
3229 { 0x19999999, 0x99999999, 0x9999999a,},
3230 { 0x195fec80, 0x8a609430, 0xe1106014,},
3231 { 0x1928ee7b, 0x0b4f22f9, 0x5f69791d,},
3232 { 0x18f46acf, 0x8c06e318, 0x4d2aeb2c,},
3233 { 0x18c23246, 0xdc0a9f3d, 0x3fe16970,},
3234 #endif
3235 };
3236
3237 static const limb_t log2_radix[BF_RADIX_MAX - 1] = {
3238 #if LIMB_BITS == 32
3239 0x20000000,
3240 0x32b80347,
3241 0x40000000,
3242 0x4a4d3c26,
3243 0x52b80347,
3244 0x59d5d9fd,
3245 0x60000000,
3246 0x6570068e,
3247 0x6a4d3c26,
3248 0x6eb3a9f0,
3249 0x72b80347,
3250 0x766a008e,
3251 0x79d5d9fd,
3252 0x7d053f6d,
3253 0x80000000,
3254 0x82cc7edf,
3255 0x8570068e,
3256 0x87ef05ae,
3257 0x8a4d3c26,
3258 0x8c8ddd45,
3259 0x8eb3a9f0,
3260 0x90c10501,
3261 0x92b80347,
3262 0x949a784c,
3263 0x966a008e,
3264 0x982809d6,
3265 0x99d5d9fd,
3266 0x9b74948f,
3267 0x9d053f6d,
3268 0x9e88c6b3,
3269 0xa0000000,
3270 0xa16bad37,
3271 0xa2cc7edf,
3272 0xa4231623,
3273 0xa570068e,
3274 #else
3275 0x2000000000000000,
3276 0x32b803473f7ad0f4,
3277 0x4000000000000000,
3278 0x4a4d3c25e68dc57f,
3279 0x52b803473f7ad0f4,
3280 0x59d5d9fd5010b366,
3281 0x6000000000000000,
3282 0x6570068e7ef5a1e8,
3283 0x6a4d3c25e68dc57f,
3284 0x6eb3a9f01975077f,
3285 0x72b803473f7ad0f4,
3286 0x766a008e4788cbcd,
3287 0x79d5d9fd5010b366,
3288 0x7d053f6d26089673,
3289 0x8000000000000000,
3290 0x82cc7edf592262d0,
3291 0x8570068e7ef5a1e8,
3292 0x87ef05ae409a0289,
3293 0x8a4d3c25e68dc57f,
3294 0x8c8ddd448f8b845a,
3295 0x8eb3a9f01975077f,
3296 0x90c10500d63aa659,
3297 0x92b803473f7ad0f4,
3298 0x949a784bcd1b8afe,
3299 0x966a008e4788cbcd,
3300 0x982809d5be7072dc,
3301 0x99d5d9fd5010b366,
3302 0x9b74948f5532da4b,
3303 0x9d053f6d26089673,
3304 0x9e88c6b3626a72aa,
3305 0xa000000000000000,
3306 0xa16bad3758efd873,
3307 0xa2cc7edf592262d0,
3308 0xa4231623369e78e6,
3309 0xa570068e7ef5a1e8,
3310 #endif
3311 };
3312
3313 /* compute floor(a*b) or ceil(a*b) with b = log2(radix) or
3314 b=1/log2(radix). For is_inv = 0, strict accuracy is not guaranteed
3315 when radix is not a power of two. */
3316 slimb_t bf_mul_log2_radix(slimb_t a1, unsigned int radix, int is_inv,
3317 int is_ceil1)
3318 {
3319 int is_neg;
3320 limb_t a;
3321 BOOL is_ceil;
3322
3323 is_ceil = is_ceil1;
3324 a = a1;
3325 if (a1 < 0) {
3326 a = -a;
3327 is_neg = 1;
3328 } else {
3329 is_neg = 0;
3330 }
3331 is_ceil ^= is_neg;
3332 if ((radix & (radix - 1)) == 0) {
3333 int radix_bits;
3334 /* radix is a power of two */
3335 radix_bits = ceil_log2(radix);
3336 if (is_inv) {
3337 if (is_ceil)
3338 a += radix_bits - 1;
3339 a = a / radix_bits;
3340 } else {
3341 a = a * radix_bits;
3342 }
3343 } else {
3344 const uint32_t *tab;
3345 limb_t b0, b1;
3346 dlimb_t t;
3347
3348 if (is_inv) {
3349 tab = inv_log2_radix[radix - 2];
3350 #if LIMB_BITS == 32
3351 b1 = tab[0];
3352 b0 = tab[1];
3353 #else
3354 b1 = ((limb_t)tab[0] << 32) | tab[1];
3355 b0 = (limb_t)tab[2] << 32;
3356 #endif
3357 t = (dlimb_t)b0 * (dlimb_t)a;
3358 t = (dlimb_t)b1 * (dlimb_t)a + (t >> LIMB_BITS);
3359 a = t >> (LIMB_BITS - 1);
3360 } else {
3361 b0 = log2_radix[radix - 2];
3362 t = (dlimb_t)b0 * (dlimb_t)a;
3363 a = t >> (LIMB_BITS - 3);
3364 }
3365 /* a = floor(result) and 'result' cannot be an integer */
3366 a += is_ceil;
3367 }
3368 if (is_neg)
3369 a = -a;
3370 return a;
3371 }
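
/* Worked example (radix 10, is_inv = 1): inv_log2_radix[8] stores the
   bits of 1/log2(10) ~ 0.30103 in fixed point, so a1 = 53 yields
   floor(53 / log2(10)) = 15, or 16 with is_ceil1 = 1 -- the familiar
   "number of decimal digits needed for 53 bits" bound. */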
3372
3373 /* 'n' is the number of output limbs */
3374 static void bf_integer_to_radix_rec(bf_t *pow_tab,
3375 limb_t *out, const bf_t *a, limb_t n,
3376 int level, limb_t n0, limb_t radixl,
3377 unsigned int radixl_bits)
3378 {
3379 limb_t n1, n2, q_prec;
3380 assert(n >= 1);
3381 if (n == 1) {
3382 out[0] = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
3383 } else if (n == 2) {
3384 dlimb_t t;
3385 slimb_t pos;
3386 pos = a->len * LIMB_BITS - a->expn;
3387 t = ((dlimb_t)get_bits(a->tab, a->len, pos + LIMB_BITS) << LIMB_BITS) |
3388 get_bits(a->tab, a->len, pos);
3389 if (likely(radixl == RADIXL_10)) {
3390 /* use division by a constant when possible */
3391 out[0] = t % RADIXL_10;
3392 out[1] = t / RADIXL_10;
3393 } else {
3394 out[0] = t % radixl;
3395 out[1] = t / radixl;
3396 }
3397 } else {
3398 bf_t Q, R, *B, *B_inv;
3399 int q_add;
3400 bf_init(a->ctx, &Q);
3401 bf_init(a->ctx, &R);
3402 n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
3403 n1 = n - n2;
3404 B = &pow_tab[2 * level];
3405 B_inv = &pow_tab[2 * level + 1];
3406 if (B->len == 0) {
3407 /* compute BASE^n2 */
3408 bf_pow_ui_ui(B, radixl, n2, BF_PREC_INF, BF_RNDZ);
3409 /* we use enough bits for the maximum possible 'n1' value,
3410 i.e. n2 + 1 */
3411 bf_set_ui(&R, 1);
3412 bf_div(B_inv, &R, B, (n2 + 1) * radixl_bits + 2, BF_RNDN);
3413 }
3414 // printf("%d: n1=% " PRId64 " n2=%" PRId64 "\n", level, n1, n2);
3415 q_prec = n1 * radixl_bits;
3416 bf_mul(&Q, a, B_inv, q_prec, BF_RNDN);
3417 bf_rint(&Q, BF_RNDZ);
3418
3419 bf_mul(&R, &Q, B, BF_PREC_INF, BF_RNDZ);
3420 bf_sub(&R, a, &R, BF_PREC_INF, BF_RNDZ);
3421 /* adjust if necessary */
3422 q_add = 0;
3423 while (R.sign && R.len != 0) {
3424 bf_add(&R, &R, B, BF_PREC_INF, BF_RNDZ);
3425 q_add--;
3426 }
3427 while (bf_cmpu(&R, B) >= 0) {
3428 bf_sub(&R, &R, B, BF_PREC_INF, BF_RNDZ);
3429 q_add++;
3430 }
3431 if (q_add != 0) {
3432 bf_add_si(&Q, &Q, q_add, BF_PREC_INF, BF_RNDZ);
3433 }
3434 bf_integer_to_radix_rec(pow_tab, out + n2, &Q, n1, level + 1, n0,
3435 radixl, radixl_bits);
3436 bf_integer_to_radix_rec(pow_tab, out, &R, n2, level + 1, n0,
3437 radixl, radixl_bits);
3438 bf_delete(&Q);
3439 bf_delete(&R);
3440 }
3441 }
3442
3443 static void bf_integer_to_radix(bf_t *r, const bf_t *a, limb_t radixl)
3444 {
3445 bf_context_t *s = r->ctx;
3446 limb_t r_len;
3447 bf_t *pow_tab;
3448 int i, pow_tab_len;
3449
3450 r_len = r->len;
3451 pow_tab_len = (ceil_log2(r_len) + 2) * 2; /* XXX: check */
3452 pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
3453 for(i = 0; i < pow_tab_len; i++)
3454 bf_init(r->ctx, &pow_tab[i]);
3455
3456 bf_integer_to_radix_rec(pow_tab, r->tab, a, r_len, 0, r_len, radixl,
3457 ceil_log2(radixl));
3458
3459 for(i = 0; i < pow_tab_len; i++) {
3460 bf_delete(&pow_tab[i]);
3461 }
3462 bf_free(s, pow_tab);
3463 }
3464
3465 /* a must be >= 0. 'P' is the wanted number of digits in radix
3466 'radix'. 'r' is the mantissa represented as an integer. *pE
3467 contains the exponent. Return != 0 if memory error. */
3468 static int bf_convert_to_radix(bf_t *r, slimb_t *pE,
3469 const bf_t *a, int radix,
3470 limb_t P, bf_rnd_t rnd_mode,
3471 BOOL is_fixed_exponent)
3472 {
3473 slimb_t E, e, prec, extra_bits, ziv_extra_bits, prec0;
3474 bf_t B_s, *B = &B_s;
3475 int e_sign, ret, res;
3476
3477 if (a->len == 0) {
3478 /* zero case */
3479 *pE = 0;
3480 return bf_set(r, a);
3481 }
3482
3483 if (is_fixed_exponent) {
3484 E = *pE;
3485 } else {
3486 /* compute the new exponent */
3487 E = 1 + bf_mul_log2_radix(a->expn - 1, radix, TRUE, FALSE);
3488 }
3489 // bf_print_str("a", a);
3490 // printf("E=%ld P=%ld radix=%d\n", E, P, radix);
3491
3492 for(;;) {
3493 e = P - E;
3494 e_sign = 0;
3495 if (e < 0) {
3496 e = -e;
3497 e_sign = 1;
3498 }
3499 /* Note: precision for log2(radix) is not critical here */
3500 prec0 = bf_mul_log2_radix(P, radix, FALSE, TRUE);
3501 ziv_extra_bits = 16;
3502 for(;;) {
3503 prec = prec0 + ziv_extra_bits;
3504 /* XXX: rigorous error analysis needed */
3505 extra_bits = ceil_log2(e) * 2 + 1;
3506 ret = bf_pow_ui_ui(r, radix, e, prec + extra_bits, BF_RNDN);
3507 if (!e_sign)
3508 ret |= bf_mul(r, r, a, prec + extra_bits, BF_RNDN);
3509 else
3510 ret |= bf_div(r, a, r, prec + extra_bits, BF_RNDN);
3511 if (ret & BF_ST_MEM_ERROR)
3512 return BF_ST_MEM_ERROR;
3513 /* if the result is not exact, check that it can be safely
3514 rounded to an integer */
3515 if ((ret & BF_ST_INEXACT) &&
3516 !bf_can_round(r, r->expn, rnd_mode, prec)) {
3517 /* add more precision and retry */
3518 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
3519 continue;
3520 } else {
3521 ret = bf_rint(r, rnd_mode);
3522 if (ret & BF_ST_MEM_ERROR)
3523 return BF_ST_MEM_ERROR;
3524 break;
3525 }
3526 }
3527 if (is_fixed_exponent)
3528 break;
3529 /* check that the result is < B^P */
3530 /* XXX: do a fast approximate test first ? */
3531 bf_init(r->ctx, B);
3532 ret = bf_pow_ui_ui(B, radix, P, BF_PREC_INF, BF_RNDZ);
3533 if (ret) {
3534 bf_delete(B);
3535 return ret;
3536 }
3537 res = bf_cmpu(r, B);
3538 bf_delete(B);
3539 if (res < 0)
3540 break;
3541 /* try a larger exponent */
3542 E++;
3543 }
3544 *pE = E;
3545 return 0;
3546 }
3547
3548 static void limb_to_a(char *buf, limb_t n, unsigned int radix, int len)
3549 {
3550 int digit, i;
3551
3552 if (radix == 10) {
3553 /* specific case with constant divisor */
3554 for(i = len - 1; i >= 0; i--) {
3555 digit = (limb_t)n % 10;
3556 n = (limb_t)n / 10;
3557 buf[i] = digit + '0';
3558 }
3559 } else {
3560 for(i = len - 1; i >= 0; i--) {
3561 digit = (limb_t)n % radix;
3562 n = (limb_t)n / radix;
3563 if (digit < 10)
3564 digit += '0';
3565 else
3566 digit += 'a' - 10;
3567 buf[i] = digit;
3568 }
3569 }
3570 }
3571
3572 /* for power of 2 radixes */
3573 static void limb_to_a2(char *buf, limb_t n, unsigned int radix_bits, int len)
3574 {
3575 int digit, i;
3576 unsigned int mask;
3577
3578 mask = (1 << radix_bits) - 1;
3579 for(i = len - 1; i >= 0; i--) {
3580 digit = n & mask;
3581 n >>= radix_bits;
3582 if (digit < 10)
3583 digit += '0';
3584 else
3585 digit += 'a' - 10;
3586 buf[i] = digit;
3587 }
3588 }
3589
3590 /* 'a' must be an integer if is_dec = FALSE or if the radix is not
3591 a power of two. A dot is added before the 'dot_pos' digit. dot_pos
3592 = n_digits does not display the dot. 0 <= dot_pos <=
3593 n_digits. n_digits >= 1. */
3594 static void output_digits(DynBuf *s, const bf_t *a1, int radix, limb_t n_digits,
3595 limb_t dot_pos, BOOL is_dec)
3596 {
3597 limb_t i, v, l;
3598 slimb_t pos, pos_incr;
3599 int digits_per_limb, buf_pos, radix_bits, first_buf_pos;
3600 char buf[65];
3601 bf_t a_s, *a;
3602
3603 if (is_dec) {
3604 digits_per_limb = LIMB_DIGITS;
3605 a = (bf_t *)a1;
3606 radix_bits = 0;
3607 pos = a->len;
3608 pos_incr = 1;
3609 first_buf_pos = 0;
3610 } else if ((radix & (radix - 1)) != 0) {
3611 limb_t n, radixl;
3612
3613 digits_per_limb = digits_per_limb_table[radix - 2];
3614 radixl = get_limb_radix(radix);
3615 a = &a_s;
3616 bf_init(a1->ctx, a);
3617 n = (n_digits + digits_per_limb - 1) / digits_per_limb;
3618 bf_resize(a, n);
3619 bf_integer_to_radix(a, a1, radixl);
3620 radix_bits = 0;
3621 pos = n;
3622 pos_incr = 1;
3623 first_buf_pos = pos * digits_per_limb - n_digits;
3624 } else {
3625 a = (bf_t *)a1;
3626 radix_bits = ceil_log2(radix);
3627 digits_per_limb = LIMB_BITS / radix_bits;
3628 pos_incr = digits_per_limb * radix_bits;
3629 pos = a->len * LIMB_BITS - a->expn + n_digits * radix_bits;
3630 first_buf_pos = 0;
3631 }
3632 buf_pos = digits_per_limb;
3633 i = 0;
3634 while (i < n_digits) {
3635 if (buf_pos == digits_per_limb) {
3636 pos -= pos_incr;
3637 if (radix_bits == 0) {
3638 v = get_limbz(a, pos);
3639 limb_to_a(buf, v, radix, digits_per_limb);
3640 } else {
3641 v = get_bits(a->tab, a->len, pos);
3642 limb_to_a2(buf, v, radix_bits, digits_per_limb);
3643 }
3644 buf_pos = first_buf_pos;
3645 first_buf_pos = 0;
3646 }
3647 if (i < dot_pos) {
3648 l = dot_pos;
3649 } else {
3650 if (i == dot_pos)
3651 dbuf_putc(s, '.');
3652 l = n_digits;
3653 }
3654 l = bf_min(digits_per_limb - buf_pos, l - i);
3655 dbuf_put(s, (uint8_t *)(buf + buf_pos), l);
3656 buf_pos += l;
3657 i += l;
3658 }
3659 if (a != a1)
3660 bf_delete(a);
3661 }
3662
3663 static void *bf_dbuf_realloc(void *opaque, void *ptr, size_t size)
3664 {
3665 bf_context_t *s = opaque;
3666 return bf_realloc(s, ptr, size);
3667 }
3668
3669 /* return the length in bytes. A trailing '\0' is added */
3670 static char *bf_ftoa_internal(size_t *plen, const bf_t *a2, int radix,
3671 limb_t prec, bf_flags_t flags, BOOL is_dec)
3672 {
3673 DynBuf s_s, *s = &s_s;
3674 int radix_bits;
3675
3676 // bf_print_str("ftoa", a2);
3677 // printf("radix=%d\n", radix);
3678 dbuf_init2(s, a2->ctx, bf_dbuf_realloc);
3679 if (a2->expn == BF_EXP_NAN) {
3680 dbuf_putstr(s, "NaN");
3681 } else {
3682 if (a2->sign)
3683 dbuf_putc(s, '-');
3684 if (a2->expn == BF_EXP_INF) {
3685 if (flags & BF_FTOA_JS_QUIRKS)
3686 dbuf_putstr(s, "Infinity");
3687 else
3688 dbuf_putstr(s, "Inf");
3689 } else {
3690 int fmt, ret;
3691 slimb_t n_digits, n, i, n_max, n1;
3692 bf_t a1_s, *a1;
3693 bf_t a_s, *a = &a_s;
3694
3695 /* make a positive number */
3696 a->tab = a2->tab;
3697 a->len = a2->len;
3698 a->expn = a2->expn;
3699 a->sign = 0;
3700
3701 if ((radix & (radix - 1)) != 0)
3702 radix_bits = 0;
3703 else
3704 radix_bits = ceil_log2(radix);
3705
3706 fmt = flags & BF_FTOA_FORMAT_MASK;
3707 a1 = &a1_s;
3708 bf_init(a2->ctx, a1);
3709 if (fmt == BF_FTOA_FORMAT_FRAC) {
3710 size_t pos, start;
3711 assert(!is_dec);
3712 /* one more digit for the rounding */
3713 n = 1 + bf_mul_log2_radix(bf_max(a->expn, 0), radix, TRUE, TRUE);
3714 n_digits = n + prec;
3715 n1 = n;
3716 if (bf_convert_to_radix(a1, &n1, a, radix, n_digits,
3717 flags & BF_RND_MASK, TRUE))
3718 goto fail1;
3719 start = s->size;
3720 output_digits(s, a1, radix, n_digits, n, is_dec);
3721 /* remove leading zeros because we allocated one more digit */
3722 pos = start;
3723 while ((pos + 1) < s->size && s->buf[pos] == '0' &&
3724 s->buf[pos + 1] != '.')
3725 pos++;
3726 if (pos > start) {
3727 memmove(s->buf + start, s->buf + pos, s->size - pos);
3728 s->size -= (pos - start);
3729 }
3730 } else {
3731 #ifdef USE_BF_DEC
3732 if (is_dec) {
3733 if (fmt == BF_FTOA_FORMAT_FIXED) {
3734 n_digits = prec;
3735 n_max = n_digits;
3736 } else {
3737 /* prec is ignored */
3738 prec = n_digits = a->len * LIMB_DIGITS;
3739 /* remove trailing zero digits */
3740 while (n_digits > 1 &&
3741 get_digit(a->tab, a->len, prec - n_digits) == 0) {
3742 n_digits--;
3743 }
3744 n_max = n_digits + 4;
3745 }
3746 bf_init(a2->ctx, a1);
3747 bf_set(a1, a);
3748 n = a1->expn;
3749 } else
3750 #endif
3751 {
3752 if (fmt == BF_FTOA_FORMAT_FIXED) {
3753 n_digits = prec;
3754 n_max = n_digits;
3755 } else {
3756 slimb_t n_digits_max, n_digits_min;
3757
3758 if (prec == BF_PREC_INF) {
3759 assert(radix_bits != 0);
3760 /* XXX: could use the exact number of bits */
3761 prec = a->len * LIMB_BITS;
3762 }
3763 n_digits = 1 + bf_mul_log2_radix(prec, radix, TRUE, TRUE);
3764 /* max number of digits for non exponential
3765 notation. The rationale is to have the same rule
3766 as JS i.e. n_max = 21 for 64 bit float in base 10. */
3767 n_max = n_digits + 4;
3768 if (fmt == BF_FTOA_FORMAT_FREE_MIN) {
3769 bf_t b_s, *b = &b_s;
3770
3771 /* find the minimum number of digits by
3772 dichotomy. */
3773 /* XXX: inefficient */
3774 n_digits_max = n_digits;
3775 n_digits_min = 1;
3776 bf_init(a2->ctx, b);
3777 while (n_digits_min < n_digits_max) {
3778 n_digits = (n_digits_min + n_digits_max) / 2;
3779 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3780 flags & BF_RND_MASK, FALSE)) {
3781 bf_delete(b);
3782 goto fail1;
3783 }
3784 /* convert back to a number and compare */
3785 ret = bf_mul_pow_radix(b, a1, radix, n - n_digits,
3786 prec,
3787 (flags & ~BF_RND_MASK) |
3788 BF_RNDN);
3789 if (ret & BF_ST_MEM_ERROR) {
3790 bf_delete(b);
3791 goto fail1;
3792 }
3793 if (bf_cmpu(b, a) == 0) {
3794 n_digits_max = n_digits;
3795 } else {
3796 n_digits_min = n_digits + 1;
3797 }
3798 }
3799 bf_delete(b);
3800 n_digits = n_digits_max;
3801 }
3802 }
3803 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3804 flags & BF_RND_MASK, FALSE)) {
3805 fail1:
3806 bf_delete(a1);
3807 goto fail;
3808 }
3809 }
3810 if (a1->expn == BF_EXP_ZERO &&
3811 fmt != BF_FTOA_FORMAT_FIXED &&
3812 !(flags & BF_FTOA_FORCE_EXP)) {
3813 /* just output zero */
3814 dbuf_putstr(s, "0");
3815 } else {
3816 if (flags & BF_FTOA_ADD_PREFIX) {
3817 if (radix == 16)
3818 dbuf_putstr(s, "0x");
3819 else if (radix == 8)
3820 dbuf_putstr(s, "0o");
3821 else if (radix == 2)
3822 dbuf_putstr(s, "0b");
3823 }
3824 if (a1->expn == BF_EXP_ZERO)
3825 n = 1;
3826 if ((flags & BF_FTOA_FORCE_EXP) ||
3827 n <= -6 || n > n_max) {
3828 const char *fmt;
3829 /* exponential notation */
3830 output_digits(s, a1, radix, n_digits, 1, is_dec);
3831 if (radix_bits != 0 && radix <= 16) {
3832 if (flags & BF_FTOA_JS_QUIRKS)
3833 fmt = "p%+" PRId_LIMB;
3834 else
3835 fmt = "p%" PRId_LIMB;
3836 dbuf_printf(s, fmt, (n - 1) * radix_bits);
3837 } else {
3838 if (flags & BF_FTOA_JS_QUIRKS)
3839 fmt = "%c%+" PRId_LIMB;
3840 else
3841 fmt = "%c%" PRId_LIMB;
3842 dbuf_printf(s, fmt,
3843 radix <= 10 ? 'e' : '@', n - 1);
3844 }
3845 } else if (n <= 0) {
3846 /* 0.x */
3847 dbuf_putstr(s, "0.");
3848 for(i = 0; i < -n; i++) {
3849 dbuf_putc(s, '0');
3850 }
3851 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3852 } else {
3853 if (n_digits <= n) {
3854 /* no dot */
3855 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3856 for(i = 0; i < (n - n_digits); i++)
3857 dbuf_putc(s, '0');
3858 } else {
3859 output_digits(s, a1, radix, n_digits, n, is_dec);
3860 }
3861 }
3862 }
3863 }
3864 bf_delete(a1);
3865 }
3866 }
3867 dbuf_putc(s, '\0');
3868 if (dbuf_error(s))
3869 goto fail;
3870 if (plen)
3871 *plen = s->size - 1;
3872 return (char *)s->buf;
3873 fail:
3874 bf_free(a2->ctx, s->buf);
3875 if (plen)
3876 *plen = 0;
3877 return NULL;
3878 }
3879
3880 char *bf_ftoa(size_t *plen, const bf_t *a, int radix, limb_t prec,
3881 bf_flags_t flags)
3882 {
3883 return bf_ftoa_internal(plen, a, radix, prec, flags, FALSE);
3884 }
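
/* Minimal usage sketch for bf_ftoa() (illustrative; 'x' is a bf_t holding
   a value with 53-bit precision):

     size_t len;
     char *str = bf_ftoa(&len, &x, 10, 53,
                         BF_RNDN | BF_FTOA_FORMAT_FREE_MIN);
     if (str) {
         printf("%s\n", str);  // shortest decimal string reading back to x
         bf_free(x.ctx, str);  // buffer comes from the bf context allocator
     }
*/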
3885
3886 /***************************************************************/
3887 /* transcendental functions */
3888
3889 /* Note: the algorithm is from MPFR */
3890 static void bf_const_log2_rec(bf_t *T, bf_t *P, bf_t *Q, limb_t n1,
3891 limb_t n2, BOOL need_P)
3892 {
3893 bf_context_t *s = T->ctx;
3894 if ((n2 - n1) == 1) {
3895 if (n1 == 0) {
3896 bf_set_ui(P, 3);
3897 } else {
3898 bf_set_ui(P, n1);
3899 P->sign = 1;
3900 }
3901 bf_set_ui(Q, 2 * n1 + 1);
3902 Q->expn += 2;
3903 bf_set(T, P);
3904 } else {
3905 limb_t m;
3906 bf_t T1_s, *T1 = &T1_s;
3907 bf_t P1_s, *P1 = &P1_s;
3908 bf_t Q1_s, *Q1 = &Q1_s;
3909
3910 m = n1 + ((n2 - n1) >> 1);
3911 bf_const_log2_rec(T, P, Q, n1, m, TRUE);
3912 bf_init(s, T1);
3913 bf_init(s, P1);
3914 bf_init(s, Q1);
3915 bf_const_log2_rec(T1, P1, Q1, m, n2, need_P);
3916 bf_mul(T, T, Q1, BF_PREC_INF, BF_RNDZ);
3917 bf_mul(T1, T1, P, BF_PREC_INF, BF_RNDZ);
3918 bf_add(T, T, T1, BF_PREC_INF, BF_RNDZ);
3919 if (need_P)
3920 bf_mul(P, P, P1, BF_PREC_INF, BF_RNDZ);
3921 bf_mul(Q, Q, Q1, BF_PREC_INF, BF_RNDZ);
3922 bf_delete(T1);
3923 bf_delete(P1);
3924 bf_delete(Q1);
3925 }
3926 }
3927
3928 /* compute log(2) with faithful rounding at precision 'prec' */
3929 static void bf_const_log2_internal(bf_t *T, limb_t prec)
3930 {
3931 limb_t w, N;
3932 bf_t P_s, *P = &P_s;
3933 bf_t Q_s, *Q = &Q_s;
3934
3935 w = prec + 15;
3936 N = w / 3 + 1;
3937 bf_init(T->ctx, P);
3938 bf_init(T->ctx, Q);
3939 bf_const_log2_rec(T, P, Q, 0, N, FALSE);
3940 bf_div(T, T, Q, prec, BF_RNDN);
3941 bf_delete(P);
3942 bf_delete(Q);
3943 }
3944
3945 /* PI constant */
3946
3947 #define CHUD_A 13591409
3948 #define CHUD_B 545140134
3949 #define CHUD_C 640320
3950 #define CHUD_BITS_PER_TERM 47
3951
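/* The constants encode the Chudnovsky series
     1/pi = 12 * sum_{k>=0} (-1)^k (6k)! (CHUD_A + CHUD_B*k)
                 / ((3k)! (k!)^3 CHUD_C^(3k + 3/2))
   where each term contributes about 47.11 bits (14.18 decimal digits),
   hence CHUD_BITS_PER_TERM. chud_bs() evaluates partial sums by binary
   splitting: P/Q accumulates the rational value and G the running
   product of the (2b-1)(6b-1)(6b-5) factors. */
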
3952 static void chud_bs(bf_t *P, bf_t *Q, bf_t *G, int64_t a, int64_t b, int need_g,
3953 limb_t prec)
3954 {
3955 bf_context_t *s = P->ctx;
3956 int64_t c;
3957
3958 if (a == (b - 1)) {
3959 bf_t T0, T1;
3960
3961 bf_init(s, &T0);
3962 bf_init(s, &T1);
3963 bf_set_ui(G, 2 * b - 1);
3964 bf_mul_ui(G, G, 6 * b - 1, prec, BF_RNDN);
3965 bf_mul_ui(G, G, 6 * b - 5, prec, BF_RNDN);
3966 bf_set_ui(&T0, CHUD_B);
3967 bf_mul_ui(&T0, &T0, b, prec, BF_RNDN);
3968 bf_set_ui(&T1, CHUD_A);
3969 bf_add(&T0, &T0, &T1, prec, BF_RNDN);
3970 bf_mul(P, G, &T0, prec, BF_RNDN);
3971 P->sign = b & 1;
3972
3973 bf_set_ui(Q, b);
3974 bf_mul_ui(Q, Q, b, prec, BF_RNDN);
3975 bf_mul_ui(Q, Q, b, prec, BF_RNDN);
3976 bf_mul_ui(Q, Q, (uint64_t)CHUD_C * CHUD_C * CHUD_C / 24, prec, BF_RNDN);
3977 bf_delete(&T0);
3978 bf_delete(&T1);
3979 } else {
3980 bf_t P2, Q2, G2;
3981
3982 bf_init(s, &P2);
3983 bf_init(s, &Q2);
3984 bf_init(s, &G2);
3985
3986 c = (a + b) / 2;
3987 chud_bs(P, Q, G, a, c, 1, prec);
3988 chud_bs(&P2, &Q2, &G2, c, b, need_g, prec);
3989
3990 /* Q = Q1 * Q2 */
3991 /* G = G1 * G2 */
3992 /* P = P1 * Q2 + P2 * G1 */
3993 bf_mul(&P2, &P2, G, prec, BF_RNDN);
3994 if (!need_g)
3995 bf_set_ui(G, 0);
3996 bf_mul(P, P, &Q2, prec, BF_RNDN);
3997 bf_add(P, P, &P2, prec, BF_RNDN);
3998 bf_delete(&P2);
3999
4000 bf_mul(Q, Q, &Q2, prec, BF_RNDN);
4001 bf_delete(&Q2);
4002 if (need_g)
4003 bf_mul(G, G, &G2, prec, BF_RNDN);
4004 bf_delete(&G2);
4005 }
4006 }
4007
4008 /* compute Pi with faithful rounding at precision 'prec' using the
4009 Chudnovsky formula */
4010 static void bf_const_pi_internal(bf_t *Q, limb_t prec)
4011 {
4012 bf_context_t *s = Q->ctx;
4013 int64_t n, prec1;
4014 bf_t P, G;
4015
4016 /* number of series terms */
4017 n = prec / CHUD_BITS_PER_TERM + 1;
4018 /* XXX: precision analysis */
4019 prec1 = prec + 32;
4020
4021 bf_init(s, &P);
4022 bf_init(s, &G);
4023
4024 chud_bs(&P, Q, &G, 0, n, 0, BF_PREC_INF);
4025
4026 bf_mul_ui(&G, Q, CHUD_A, prec1, BF_RNDN);
4027 bf_add(&P, &G, &P, prec1, BF_RNDN);
4028 bf_div(Q, Q, &P, prec1, BF_RNDF);
4029
4030 bf_set_ui(&P, CHUD_C);
4031 bf_sqrt(&G, &P, prec1, BF_RNDF);
4032 bf_mul_ui(&G, &G, (uint64_t)CHUD_C / 12, prec1, BF_RNDF);
4033 bf_mul(Q, Q, &G, prec, BF_RNDN);
4034 bf_delete(&P);
4035 bf_delete(&G);
4036 }
4037
4038 static int bf_const_get(bf_t *T, limb_t prec, bf_flags_t flags,
4039 BFConstCache *c,
4040 void (*func)(bf_t *res, limb_t prec))
4041 {
4042 limb_t ziv_extra_bits, prec1;
4043
4044 ziv_extra_bits = 32;
4045 for(;;) {
4046 prec1 = prec + ziv_extra_bits;
4047 if (c->prec < prec1) {
4048 if (c->val.len == 0)
4049 bf_init(T->ctx, &c->val);
4050 func(&c->val, prec1);
4051 c->prec = prec1;
4052 } else {
4053 prec1 = c->prec;
4054 }
4055 bf_set(T, &c->val);
4056 if (!bf_can_round(T, prec, flags & BF_RND_MASK, prec1)) {
4057 /* add more precision and retry */
4058 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
4059 } else {
4060 break;
4061 }
4062 }
4063 return bf_round(T, prec, flags);
4064 }
4065
4066 static void bf_const_free(BFConstCache *c)
4067 {
4068 bf_delete(&c->val);
4069 memset(c, 0, sizeof(*c));
4070 }
4071
4072 int bf_const_log2(bf_t *T, limb_t prec, bf_flags_t flags)
4073 {
4074 bf_context_t *s = T->ctx;
4075 return bf_const_get(T, prec, flags, &s->log2_cache, bf_const_log2_internal);
4076 }
4077
4078 int bf_const_pi(bf_t *T, limb_t prec, bf_flags_t flags)
4079 {
4080 bf_context_t *s = T->ctx;
4081 return bf_const_get(T, prec, flags, &s->pi_cache, bf_const_pi_internal);
4082 }
4083
4084 void bf_clear_cache(bf_context_t *s)
4085 {
4086 #ifdef USE_FFT_MUL
4087 fft_clear_cache(s);
4088 #endif
4089 bf_const_free(&s->log2_cache);
4090 bf_const_free(&s->pi_cache);
4091 }
4092
4093 /* ZivFunc should compute the result 'r' with faithful rounding at
4094 precision 'prec'. For efficiency purposes, the final bf_round()
4095 does not need to be done in the function. */
4096 typedef int ZivFunc(bf_t *r, const bf_t *a, limb_t prec, void *opaque);
4097
4098 static int bf_ziv_rounding(bf_t *r, const bf_t *a,
4099 limb_t prec, bf_flags_t flags,
4100 ZivFunc *f, void *opaque)
4101 {
4102 int rnd_mode, ret;
4103 slimb_t prec1, ziv_extra_bits;
4104
4105 rnd_mode = flags & BF_RND_MASK;
4106 if (rnd_mode == BF_RNDF) {
4107 /* no need to iterate */
4108 f(r, a, prec, opaque);
4109 ret = 0;
4110 } else {
4111 ziv_extra_bits = 32;
4112 for(;;) {
4113 prec1 = prec + ziv_extra_bits;
4114 ret = f(r, a, prec1, opaque);
4115 if (ret & (BF_ST_OVERFLOW | BF_ST_UNDERFLOW | BF_ST_MEM_ERROR)) {
4116 /* overflow or underflow should never happen here because
4117 it would mean the rounding cannot be done correctly,
4118 but we do not catch all the cases */
4119 return ret;
4120 }
4121 /* if the result is exact, we can stop */
4122 if (!(ret & BF_ST_INEXACT)) {
4123 ret = 0;
4124 break;
4125 }
4126 if (bf_can_round(r, prec, rnd_mode, prec1)) {
4127 ret = BF_ST_INEXACT;
4128 break;
4129 }
4130 ziv_extra_bits = ziv_extra_bits * 2;
4131 }
4132 }
4133 return bf_round(r, prec, flags) | ret;
4134 }
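
/* Ziv's strategy in a nutshell: evaluate with prec + m guard bits; if the
   approximation is provably on one side of a rounding boundary
   (bf_can_round()), rounding it to 'prec' bits gives the correctly
   rounded result; otherwise retry with more guard bits. Hard-to-round
   cases are rare for the functions used here, so m starts at 32 and
   doubles on each retry. */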
4135
4136 /* Compute the exponential using faithful rounding at precision 'prec'.
4137 Note: the algorithm is from MPFR */
4138 static int bf_exp_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4139 {
4140 bf_context_t *s = r->ctx;
4141 bf_t T_s, *T = &T_s;
4142 slimb_t n, K, l, i, prec1;
4143
4144 assert(r != a);
4145
4146 /* argument reduction:
4147 T = a - n*log(2) with 0 <= T < log(2) and n integer.
4148 */
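/* The full scheme is: exp(a) = exp(T)*2^n, then with T' = T/2^K,
   exp(T) = exp(T')^(2^K); exp(T') is evaluated with l terms of the
   Taylor series in Horner form. K ~ sqrt(prec/2) balances the cost
   of the K squarings against the length of the series. */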
4149 bf_init(s, T);
4150 if (a->expn <= -1) {
4151 /* 0 <= abs(a) <= 0.5 */
4152 if (a->sign)
4153 n = -1;
4154 else
4155 n = 0;
4156 } else {
4157 bf_const_log2(T, LIMB_BITS, BF_RNDZ);
4158 bf_div(T, a, T, LIMB_BITS, BF_RNDD);
4159 bf_get_limb(&n, T, 0);
4160 }
4161
4162 K = bf_isqrt((prec + 1) / 2);
4163 l = (prec - 1) / K + 1;
4164 /* XXX: precision analysis ? */
4165 prec1 = prec + (K + 2 * l + 18) + K + 8;
4166 if (a->expn > 0)
4167 prec1 += a->expn;
4168 // printf("n=%ld K=%ld prec1=%ld\n", n, K, prec1);
4169
4170 bf_const_log2(T, prec1, BF_RNDF);
4171 bf_mul_si(T, T, n, prec1, BF_RNDN);
4172 bf_sub(T, a, T, prec1, BF_RNDN);
4173
4174 /* reduce the range of T */
4175 bf_mul_2exp(T, -K, BF_PREC_INF, BF_RNDZ);
4176
4177 /* Taylor expansion around zero:
4178 1 + x + x^2/2 + ... + x^n/n!
4179 = (1 + x * (1 + x/2 * (1 + ... (x/n))))
4180 */
4181 {
4182 bf_t U_s, *U = &U_s;
4183
4184 bf_init(s, U);
4185 bf_set_ui(r, 1);
4186 for(i = l ; i >= 1; i--) {
4187 bf_set_ui(U, i);
4188 bf_div(U, T, U, prec1, BF_RNDN);
4189 bf_mul(r, r, U, prec1, BF_RNDN);
4190 bf_add_si(r, r, 1, prec1, BF_RNDN);
4191 }
4192 bf_delete(U);
4193 }
4194 bf_delete(T);
4195
4196 /* undo the range reduction */
4197 for(i = 0; i < K; i++) {
4198 bf_mul(r, r, r, prec1, BF_RNDN);
4199 }
4200
4201 /* undo the argument reduction */
4202 bf_mul_2exp(r, n, BF_PREC_INF, BF_RNDZ);
4203
4204 return BF_ST_INEXACT;
4205 }
4206
4207 int bf_exp(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4208 {
4209 bf_context_t *s = r->ctx;
4210 assert(r != a);
4211 if (a->len == 0) {
4212 if (a->expn == BF_EXP_NAN) {
4213 bf_set_nan(r);
4214 } else if (a->expn == BF_EXP_INF) {
4215 if (a->sign)
4216 bf_set_zero(r, 0);
4217 else
4218 bf_set_inf(r, 0);
4219 } else {
4220 bf_set_ui(r, 1);
4221 }
4222 return 0;
4223 }
4224
4225 /* crude overflow and underflow tests */
4226 if (a->expn > 0) {
4227 bf_t T_s, *T = &T_s;
4228 bf_t log2_s, *log2 = &log2_s;
4229 slimb_t e_min, e_max;
4230 e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
4231 e_min = -e_max + 3;
4232 if (flags & BF_FLAG_SUBNORMAL)
4233 e_min -= (prec - 1);
4234
4235 bf_init(s, T);
4236 bf_init(s, log2);
4237 bf_const_log2(log2, LIMB_BITS, BF_RNDU);
4238 bf_mul_ui(T, log2, e_max, LIMB_BITS, BF_RNDU);
4239 /* a > e_max * log(2) implies exp(a) > e_max */
4240 if (bf_cmp_lt(T, a) > 0) {
4241 /* overflow */
4242 bf_delete(T);
4243 bf_delete(log2);
4244 return bf_set_overflow(r, 0, prec, flags);
4245 }
4246 /* a < e_min * log(2) implies exp(a) < e_min */
4247 bf_mul_si(T, log2, e_min, LIMB_BITS, BF_RNDD);
4248 if (bf_cmp_lt(a, T)) {
4249 int rnd_mode = flags & BF_RND_MASK;
4250
4251 /* underflow */
4252 bf_delete(T);
4253 bf_delete(log2);
4254 if (rnd_mode == BF_RNDU) {
4255 /* set the smallest value */
4256 bf_set_ui(r, 1);
4257 r->expn = e_min;
4258 } else {
4259 bf_set_zero(r, 0);
4260 }
4261 return BF_ST_UNDERFLOW | BF_ST_INEXACT;
4262 }
4263 bf_delete(log2);
4264 bf_delete(T);
4265 }
4266
4267 return bf_ziv_rounding(r, a, prec, flags, bf_exp_internal, NULL);
4268 }
4269
4270 static int bf_log_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4271 {
4272 bf_context_t *s = r->ctx;
4273 bf_t T_s, *T = &T_s;
4274 bf_t U_s, *U = &U_s;
4275 bf_t V_s, *V = &V_s;
4276 slimb_t n, prec1, l, i, K;
4277
4278 assert(r != a);
4279
4280 bf_init(s, T);
4281 /* argument reduction 1 */
4282 /* a = T*2^n with 2/3 <= T <= 4/3 */
4283 {
4284 bf_t U_s, *U = &U_s;
4285 bf_set(T, a);
4286 n = T->expn;
4287 T->expn = 0;
4288 /* U= ~ 2/3 */
4289 bf_init(s, U);
4290 bf_set_ui(U, 0xaaaaaaaa);
4291 U->expn = 0;
4292 if (bf_cmp_lt(T, U)) {
4293 T->expn++;
4294 n--;
4295 }
4296 bf_delete(U);
4297 }
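/* 0xaaaaaaaa interpreted with exponent 0 is ~2/3, so T ends up in
   [2/3, 4/3) and |T - 1| <= 1/3, which helps the convergence of
   the series below */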
4298 // printf("n=%ld\n", n);
4299 // bf_print_str("T", T);
4300
4301 /* XXX: precision analysis */
4302 /* number of iterations for argument reduction 2 */
4303 K = bf_isqrt((prec + 1) / 2);
4304 /* order of Taylor expansion */
4305 l = prec / (2 * K) + 1;
4306 /* precision of the intermediate computations */
4307 prec1 = prec + K + 2 * l + 32;
4308
4309 bf_init(s, U);
4310 bf_init(s, V);
4311
4312 /* Note: cancellation occurs here, so we use more precision (XXX:
4313 reduce the precision by computing the exact cancellation) */
4314 bf_add_si(T, T, -1, BF_PREC_INF, BF_RNDN);
4315
4316 /* argument reduction 2 */
4317 for(i = 0; i < K; i++) {
4318 /* T = T / (1 + sqrt(1 + T)) */
4319 bf_add_si(U, T, 1, prec1, BF_RNDN);
4320 bf_sqrt(V, U, prec1, BF_RNDF);
4321 bf_add_si(U, V, 1, prec1, BF_RNDN);
4322 bf_div(T, T, U, prec1, BF_RNDN);
4323 }
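/* with T = x - 1, T/(1 + sqrt(1 + T)) = sqrt(x) - 1, so each
   iteration replaces x by sqrt(x); the missing factor 2^K is
   restored below, folded with the factor 2 of the series into the
   2^(K+1) shift */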
4324
4325 {
4326 bf_t Y_s, *Y = &Y_s;
4327 bf_t Y2_s, *Y2 = &Y2_s;
4328 bf_init(s, Y);
4329 bf_init(s, Y2);
4330
4331 /* compute ln(1+x) = ln((1+y)/(1-y)) with y=x/(2+x);
4332 ln((1+y)/(1-y)) = 2*(y + y^3/3 + ... + y^(2*l + 1) / (2*l+1))
4333 (the factor 2 is applied later); with Y=y^2:
4334 y + y^3/3 + ... = y*(1+Y/3+Y^2/5+...) = y*(1+Y*(1/3+Y*(1/5 + ...)))
4335 */
4336 bf_add_si(Y, T, 2, prec1, BF_RNDN);
4337 bf_div(Y, T, Y, prec1, BF_RNDN);
4338
4339 bf_mul(Y2, Y, Y, prec1, BF_RNDN);
4340 bf_set_ui(r, 0);
4341 for(i = l; i >= 1; i--) {
4342 bf_set_ui(U, 1);
4343 bf_set_ui(V, 2 * i + 1);
4344 bf_div(U, U, V, prec1, BF_RNDN);
4345 bf_add(r, r, U, prec1, BF_RNDN);
4346 bf_mul(r, r, Y2, prec1, BF_RNDN);
4347 }
4348 bf_add_si(r, r, 1, prec1, BF_RNDN);
4349 bf_mul(r, r, Y, prec1, BF_RNDN);
4350 bf_delete(Y);
4351 bf_delete(Y2);
4352 }
4353 bf_delete(V);
4354 bf_delete(U);
4355
4356 /* multiplication by 2 for the Taylor expansion and undo the
4357    argument reduction 2 */
4358 bf_mul_2exp(r, K + 1, BF_PREC_INF, BF_RNDZ);
4359
4360 /* undo the argument reduction 1 */
4361 bf_const_log2(T, prec1, BF_RNDF);
4362 bf_mul_si(T, T, n, prec1, BF_RNDN);
4363 bf_add(r, r, T, prec1, BF_RNDN);
4364
4365 bf_delete(T);
4366 return BF_ST_INEXACT;
4367 }
4368
4369 int bf_log(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4370 {
4371 bf_context_t *s = r->ctx;
4372 bf_t T_s, *T = &T_s;
4373
4374 assert(r != a);
4375 if (a->len == 0) {
4376 if (a->expn == BF_EXP_NAN) {
4377 bf_set_nan(r);
4378 return 0;
4379 } else if (a->expn == BF_EXP_INF) {
4380 if (a->sign) {
4381 bf_set_nan(r);
4382 return BF_ST_INVALID_OP;
4383 } else {
4384 bf_set_inf(r, 0);
4385 return 0;
4386 }
4387 } else {
4388 bf_set_inf(r, 1);
4389 return 0;
4390 }
4391 }
4392 if (a->sign) {
4393 bf_set_nan(r);
4394 return BF_ST_INVALID_OP;
4395 }
4396 bf_init(s, T);
4397 bf_set_ui(T, 1);
4398 if (bf_cmp_eq(a, T)) {
4399 bf_set_zero(r, 0);
4400 bf_delete(T);
4401 return 0;
4402 }
4403 bf_delete(T);
4404
4405 return bf_ziv_rounding(r, a, prec, flags, bf_log_internal, NULL);
4406 }
4407
4408 /* x and y finite and x > 0 */
4409 /* XXX: overflow/underflow handling */
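/* x^y is computed as exp(y*log(x)); the 32 extra bits give a
   faithfully rounded result which the caller's Ziv loop turns into
   a correctly rounded one */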
4410 static int bf_pow_generic(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4411 {
4412 bf_context_t *s = r->ctx;
4413 const bf_t *y = opaque;
4414 bf_t T_s, *T = &T_s;
4415 limb_t prec1;
4416
4417 bf_init(s, T);
4418 /* XXX: proof for the added precision */
4419 prec1 = prec + 32;
4420 bf_log(T, x, prec1, BF_RNDF);
4421 bf_mul(T, T, y, prec1, BF_RNDF);
4422 bf_exp(r, T, prec1, BF_RNDF);
4423 bf_delete(T);
4424 return BF_ST_INEXACT;
4425 }
4426
4427 /* x and y finite, x > 0, y integer and y fits on one limb */
4428 /* XXX: overflow/underflow handling */
4429 static int bf_pow_int(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4430 {
4431 bf_context_t *s = r->ctx;
4432 const bf_t *y = opaque;
4433 bf_t T_s, *T = &T_s;
4434 limb_t prec1;
4435 int ret;
4436 slimb_t y1;
4437
4438 bf_get_limb(&y1, y, 0);
4439 if (y1 < 0)
4440 y1 = -y1;
4441 /* XXX: proof for the added precision */
4442 prec1 = prec + ceil_log2(y1) * 2 + 8;
4443 ret = bf_pow_ui(r, x, y1 < 0 ? -y1 : y1, prec1, BF_RNDN);
4444 if (y->sign) {
4445 bf_init(s, T);
4446 bf_set_ui(T, 1);
4447 ret |= bf_div(r, T, r, prec1, BF_RNDN);
4448 bf_delete(T);
4449 }
4450 return ret;
4451 }
4452
4453 /* x must be a finite non zero float. Return TRUE if there is a
4454    floating point number r such that x=r^(2^n), and store it in
4455    'r'. Otherwise return FALSE and r is undefined. */
4456 static BOOL check_exact_power2n(bf_t *r, const bf_t *x, slimb_t n)
4457 {
4458 bf_context_t *s = r->ctx;
4459 bf_t T_s, *T = &T_s;
4460 slimb_t e, i, er;
4461 limb_t v;
4462
4463 /* x = m*2^e with m odd integer */
4464 e = bf_get_exp_min(x);
4465 /* fast check on the exponent */
4466 if (n > (LIMB_BITS - 1)) {
4467 if (e != 0)
4468 return FALSE;
4469 er = 0;
4470 } else {
4471 if ((e & (((limb_t)1 << n) - 1)) != 0)
4472 return FALSE;
4473 er = e >> n;
4474 }
4475 /* every perfect odd square = 1 modulo 8 */
4476 v = get_bits(x->tab, x->len, x->len * LIMB_BITS - x->expn + e);
4477 if ((v & 7) != 1)
4478 return FALSE;
4479
4480 bf_init(s, T);
4481 bf_set(T, x);
4482 T->expn -= e;
4483 for(i = 0; i < n; i++) {
4484 if (i != 0)
4485 bf_set(T, r);
4486 if (bf_sqrtrem(r, NULL, T) != 0)
4487 return FALSE;
4488 }
4489 r->expn += er;
4490 return TRUE;
4491 }
4492
4493 /* prec = BF_PREC_INF is accepted for x and y integers and y >= 0 */
4494 int bf_pow(bf_t *r, const bf_t *x, const bf_t *y, limb_t prec, bf_flags_t flags)
4495 {
4496 bf_context_t *s = r->ctx;
4497 bf_t T_s, *T = &T_s;
4498 bf_t ytmp_s;
4499 BOOL y_is_int, y_is_odd;
4500 int r_sign, ret, rnd_mode;
4501 slimb_t y_emin;
4502
4503 if (x->len == 0 || y->len == 0) {
4504 if (y->expn == BF_EXP_ZERO) {
4505 /* pow(x, 0) = 1 */
4506 bf_set_ui(r, 1);
4507 } else if (x->expn == BF_EXP_NAN) {
4508 bf_set_nan(r);
4509 } else {
4510 int cmp_x_abs_1;
4511 bf_set_ui(r, 1);
4512 cmp_x_abs_1 = bf_cmpu(x, r);
4513 if (cmp_x_abs_1 == 0 && (flags & BF_POW_JS_QUIRKS) &&
4514 (y->expn >= BF_EXP_INF)) {
4515 bf_set_nan(r);
4516 } else if (cmp_x_abs_1 == 0 &&
4517 (!x->sign || y->expn != BF_EXP_NAN)) {
4518 /* pow(1, y) = 1 even if y = NaN */
4519 /* pow(-1, +/-inf) = 1 */
4520 } else if (y->expn == BF_EXP_NAN) {
4521 bf_set_nan(r);
4522 } else if (y->expn == BF_EXP_INF) {
4523 if (y->sign == (cmp_x_abs_1 > 0)) {
4524 bf_set_zero(r, 0);
4525 } else {
4526 bf_set_inf(r, 0);
4527 }
4528 } else {
4529 y_emin = bf_get_exp_min(y);
4530 y_is_odd = (y_emin == 0);
4531 if (y->sign == (x->expn == BF_EXP_ZERO)) {
4532 bf_set_inf(r, y_is_odd & x->sign);
4533 if (y->sign) {
4534 /* pow(0, y) with y < 0 */
4535 return BF_ST_DIVIDE_ZERO;
4536 }
4537 } else {
4538 bf_set_zero(r, y_is_odd & x->sign);
4539 }
4540 }
4541 }
4542 return 0;
4543 }
4544 bf_init(s, T);
4545 bf_set(T, x);
4546 y_emin = bf_get_exp_min(y);
4547 y_is_int = (y_emin >= 0);
4548 rnd_mode = flags & BF_RND_MASK;
4549 if (x->sign) {
4550 if (!y_is_int) {
4551 bf_set_nan(r);
4552 bf_delete(T);
4553 return BF_ST_INVALID_OP;
4554 }
4555 y_is_odd = (y_emin == 0);
4556 r_sign = y_is_odd;
4557 /* change the directed rounding mode if the sign of the result
4558 is changed */
4559 if (r_sign && (rnd_mode == BF_RNDD || rnd_mode == BF_RNDU))
4560 flags ^= 1;
4561 bf_neg(T);
4562 } else {
4563 r_sign = 0;
4564 }
4565
4566 bf_set_ui(r, 1);
4567 if (bf_cmp_eq(T, r)) {
4568 /* abs(x) = 1: nothing more to do */
4569 ret = 0;
4570 } else if (y_is_int) {
4571 slimb_t T_bits, e;
4572 int_pow:
4573 T_bits = T->expn - bf_get_exp_min(T);
4574 if (T_bits == 1) {
4575 /* pow(2^b, y) = 2^(b*y) */
4576 bf_mul_si(T, y, T->expn - 1, LIMB_BITS, BF_RNDZ);
4577 bf_get_limb(&e, T, 0);
4578 bf_set_ui(r, 1);
4579 ret = bf_mul_2exp(r, e, prec, flags);
4580 } else if (prec == BF_PREC_INF) {
4581 slimb_t y1;
4582 /* specific case for infinite precision (integer case) */
4583 bf_get_limb(&y1, y, 0);
4584 assert(!y->sign);
4585 /* x must be an integer, so abs(x) >= 2 */
4586 if (y1 >= ((slimb_t)1 << BF_EXP_BITS_MAX)) {
4587 bf_delete(T);
4588 return bf_set_overflow(r, 0, BF_PREC_INF, flags);
4589 }
4590 ret = bf_pow_ui(r, T, y1, BF_PREC_INF, BF_RNDZ);
4591 } else {
4592 if (y->expn <= 31) {
4593 /* small enough power: use exponentiation in all cases */
4594 } else if (y->sign) {
4595 /* cannot be exact */
4596 goto general_case;
4597 } else {
4598 if (rnd_mode == BF_RNDF)
4599 goto general_case; /* no need to track exact results */
4600 /* see if the result has a chance to be exact:
4601 if x=a*2^b (a odd), x^y=a^y*2^(b*y)
4602 x^y needs a precision of at least floor_log2(a)*y bits
4603 */
4604 bf_mul_si(r, y, T_bits - 1, LIMB_BITS, BF_RNDZ);
4605 bf_get_limb(&e, r, 0);
4606 if (prec < e)
4607 goto general_case;
4608 }
4609 ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_int, (void *)y);
4610 }
4611 } else {
4612 if (rnd_mode != BF_RNDF) {
4613 bf_t *y1;
4614 if (y_emin < 0 && check_exact_power2n(r, T, -y_emin)) {
4615 /* the problem is reduced to a power to an integer */
4616 #if 0
4617 printf("\nn=%ld\n", -y_emin);
4618 bf_print_str("T", T);
4619 bf_print_str("r", r);
4620 #endif
4621 bf_set(T, r);
4622 y1 = &ytmp_s;
4623 y1->tab = y->tab;
4624 y1->len = y->len;
4625 y1->sign = y->sign;
4626 y1->expn = y->expn - y_emin;
4627 y = y1;
4628 goto int_pow;
4629 }
4630 }
4631 general_case:
4632 ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_generic, (void *)y);
4633 }
4634 bf_delete(T);
4635 r->sign = r_sign;
4636 return ret;
4637 }
4638
4639 /* compute sqrt(-2*x-x^2) to get |sin(x)| from cos(x) - 1. */
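/* derivation: with c = cos(u) - 1, sin(u)^2 = 1 - (1 + c)^2
   = -(2*c + c^2), hence |sin(u)| = sqrt(-2*c - c^2) */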
4640 static void bf_sqrt_sin(bf_t *r, const bf_t *x, limb_t prec1)
4641 {
4642 bf_context_t *s = r->ctx;
4643 bf_t T_s, *T = &T_s;
4644 bf_init(s, T);
4645 bf_set(T, x);
4646 bf_mul(r, T, T, prec1, BF_RNDN);
4647 bf_mul_2exp(T, 1, BF_PREC_INF, BF_RNDZ);
4648 bf_add(T, T, r, prec1, BF_RNDN);
4649 bf_neg(T);
4650 bf_sqrt(r, T, prec1, BF_RNDF);
4651 bf_delete(T);
4652 }
4653
4654 int bf_sincos(bf_t *s, bf_t *c, const bf_t *a, limb_t prec)
4655 {
4656 bf_context_t *s1 = a->ctx;
4657 bf_t T_s, *T = &T_s;
4658 bf_t U_s, *U = &U_s;
4659 bf_t r_s, *r = &r_s;
4660 slimb_t K, prec1, i, l, mod, prec2;
4661 int is_neg;
4662
4663 assert(c != a && s != a);
4664 if (a->len == 0) {
4665 if (a->expn == BF_EXP_NAN) {
4666 if (c)
4667 bf_set_nan(c);
4668 if (s)
4669 bf_set_nan(s);
4670 return 0;
4671 } else if (a->expn == BF_EXP_INF) {
4672 if (c)
4673 bf_set_nan(c);
4674 if (s)
4675 bf_set_nan(s);
4676 return BF_ST_INVALID_OP;
4677 } else {
4678 if (c)
4679 bf_set_ui(c, 1);
4680 if (s)
4681 bf_set_zero(s, a->sign);
4682 return 0;
4683 }
4684 }
4685
4686 bf_init(s1, T);
4687 bf_init(s1, U);
4688 bf_init(s1, r);
4689
4690 /* XXX: precision analysis */
4691 K = bf_isqrt(prec / 2);
4692 l = prec / (2 * K) + 1;
4693 prec1 = prec + 2 * K + l + 8;
4694
4695 /* after the modulo reduction, -pi/4 <= T <= pi/4 */
4696 if (a->expn <= -1) {
4697 /* abs(a) <= 0.25: no modulo reduction needed */
4698 bf_set(T, a);
4699 mod = 0;
4700 } else {
4701 slimb_t cancel;
4702 cancel = 0;
4703 for(;;) {
4704 prec2 = prec1 + cancel;
4705 bf_const_pi(U, prec2, BF_RNDF);
4706 bf_mul_2exp(U, -1, BF_PREC_INF, BF_RNDZ);
4707 bf_remquo(&mod, T, a, U, prec2, BF_RNDN);
4708 // printf("T.expn=%ld prec2=%ld\n", T->expn, prec2);
4709 if (mod == 0 || (T->expn != BF_EXP_ZERO &&
4710 (T->expn + prec2) >= (prec1 - 1)))
4711 break;
4712 /* increase the number of bits until the precision is good enough */
4713 cancel = bf_max(-T->expn, (cancel + 1) * 3 / 2);
4714 }
4715 mod &= 3;
4716 }
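/* 'mod' is the quadrant index: a ~ T + mod*(pi/2), so bit 0 of
   mod selects between the sin and cos expressions and bit 1 flips
   the sign (see the sign handling after the Taylor evaluation) */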
4717
4718 is_neg = T->sign;
4719
4720 /* compute cosm1(x) = cos(x) - 1 */
4721 bf_mul(T, T, T, prec1, BF_RNDN);
4722 bf_mul_2exp(T, -2 * K, BF_PREC_INF, BF_RNDZ);
4723
4724 /* Taylor expansion:
4725 -x^2/2 + x^4/4! - x^6/6! + ...
4726 */
4727 bf_set_ui(r, 1);
4728 for(i = l ; i >= 1; i--) {
4729 bf_set_ui(U, 2 * i - 1);
4730 bf_mul_ui(U, U, 2 * i, BF_PREC_INF, BF_RNDZ);
4731 bf_div(U, T, U, prec1, BF_RNDN);
4732 bf_mul(r, r, U, prec1, BF_RNDN);
4733 bf_neg(r);
4734 if (i != 1)
4735 bf_add_si(r, r, 1, prec1, BF_RNDN);
4736 }
4737 bf_delete(U);
4738
4739 /* undo argument reduction:
4740 cosm1(2*x)= 2*(2*cosm1(x)+cosm1(x)^2)
4741 */
4742 for(i = 0; i < K; i++) {
4743 bf_mul(T, r, r, prec1, BF_RNDN);
4744 bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
4745 bf_add(r, r, T, prec1, BF_RNDN);
4746 bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
4747 }
4748 bf_delete(T);
4749
4750 if (c) {
4751 if ((mod & 1) == 0) {
4752 bf_add_si(c, r, 1, prec1, BF_RNDN);
4753 } else {
4754 bf_sqrt_sin(c, r, prec1);
4755 c->sign = is_neg ^ 1;
4756 }
4757 c->sign ^= mod >> 1;
4758 }
4759 if (s) {
4760 if ((mod & 1) == 0) {
4761 bf_sqrt_sin(s, r, prec1);
4762 s->sign = is_neg;
4763 } else {
4764 bf_add_si(s, r, 1, prec1, BF_RNDN);
4765 }
4766 s->sign ^= mod >> 1;
4767 }
4768 bf_delete(r);
4769 return BF_ST_INEXACT;
4770 }
4771
4772 static int bf_cos_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4773 {
4774 return bf_sincos(NULL, r, a, prec);
4775 }
4776
4777 int bf_cos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4778 {
4779 return bf_ziv_rounding(r, a, prec, flags, bf_cos_internal, NULL);
4780 }
4781
4782 static int bf_sin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4783 {
4784 return bf_sincos(r, NULL, a, prec);
4785 }
4786
4787 int bf_sin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4788 {
4789 return bf_ziv_rounding(r, a, prec, flags, bf_sin_internal, NULL);
4790 }
4791
4792 static int bf_tan_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4793 {
4794 bf_context_t *s = r->ctx;
4795 bf_t T_s, *T = &T_s;
4796 limb_t prec1;
4797
4798 if (a->len == 0) {
4799 if (a->expn == BF_EXP_NAN) {
4800 bf_set_nan(r);
4801 return 0;
4802 } else if (a->expn == BF_EXP_INF) {
4803 bf_set_nan(r);
4804 return BF_ST_INVALID_OP;
4805 } else {
4806 bf_set_zero(r, a->sign);
4807 return 0;
4808 }
4809 }
4810
4811 /* XXX: precision analysis */
4812 prec1 = prec + 8;
4813 bf_init(s, T);
4814 bf_sincos(r, T, a, prec1);
4815 bf_div(r, r, T, prec1, BF_RNDF);
4816 bf_delete(T);
4817 return BF_ST_INEXACT;
4818 }
4819
4820 int bf_tan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4821 {
4822 return bf_ziv_rounding(r, a, prec, flags, bf_tan_internal, NULL);
4823 }
4824
4825 /* if add_pi2 is true, add pi/2 to the result (used for acos(x) to
4826 avoid cancellation) */
4827 static int bf_atan_internal(bf_t *r, const bf_t *a, limb_t prec,
4828 void *opaque)
4829 {
4830 bf_context_t *s = r->ctx;
4831 BOOL add_pi2 = (BOOL)(intptr_t)opaque;
4832 bf_t T_s, *T = &T_s;
4833 bf_t U_s, *U = &U_s;
4834 bf_t V_s, *V = &V_s;
4835 bf_t X2_s, *X2 = &X2_s;
4836 int cmp_1;
4837 slimb_t prec1, i, K, l;
4838
4839 if (a->len == 0) {
4840 if (a->expn == BF_EXP_NAN) {
4841 bf_set_nan(r);
4842 return 0;
4843 } else {
4844 if (a->expn == BF_EXP_INF)
4845 i = 1 - 2 * a->sign;
4846 else
4847 i = 0;
4848 i += add_pi2;
4849 /* return i*(pi/2) with -1 <= i <= 2 */
4850 if (i == 0) {
4851 bf_set_zero(r, add_pi2 ? 0 : a->sign);
4852 return 0;
4853 } else {
4854 /* PI or PI/2 */
4855 bf_const_pi(r, prec, BF_RNDF);
4856 if (i != 2)
4857 bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
4858 r->sign = (i < 0);
4859 return BF_ST_INEXACT;
4860 }
4861 }
4862 }
4863
4864 bf_init(s, T);
4865 bf_set_ui(T, 1);
4866 cmp_1 = bf_cmpu(a, T);
4867 if (cmp_1 == 0 && !add_pi2) {
4868 /* short cut: abs(a) == 1 -> +/-pi/4 */
4869 bf_const_pi(r, prec, BF_RNDF);
4870 bf_mul_2exp(r, -2, BF_PREC_INF, BF_RNDZ);
4871 r->sign = a->sign;
4872 bf_delete(T);
4873 return BF_ST_INEXACT;
4874 }
4875
4876 /* XXX: precision analysis */
4877 K = bf_isqrt((prec + 1) / 2);
4878 l = prec / (2 * K) + 1;
4879 prec1 = prec + K + 2 * l + 32;
4880 // printf("prec=%ld K=%ld l=%ld prec1=%ld\n", prec, K, l, prec1);
4881
4882 if (cmp_1 > 0) {
4883 bf_set_ui(T, 1);
4884 bf_div(T, T, a, prec1, BF_RNDN);
4885 } else {
4886 bf_set(T, a);
4887 }
4888
4889 /* abs(T) <= 1 */
4890
4891 /* argument reduction */
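/* each step applies the half-angle identity
   tan(u/2) = tan(u)/(1 + sqrt(1 + tan(u)^2)), so after K steps the
   angle is divided by 2^K; it is multiplied back below */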
4892
4893 bf_init(s, U);
4894 bf_init(s, V);
4895 bf_init(s, X2);
4896 for(i = 0; i < K; i++) {
4897 /* T = T / (1 + sqrt(1 + T^2)) */
4898 bf_mul(U, T, T, prec1, BF_RNDN);
4899 bf_add_si(U, U, 1, prec1, BF_RNDN);
4900 bf_sqrt(V, U, prec1, BF_RNDN);
4901 bf_add_si(V, V, 1, prec1, BF_RNDN);
4902 bf_div(T, T, V, prec1, BF_RNDN);
4903 }
4904
4905 /* Taylor series:
4906 x - x^3/3 + ... + (-1)^l * x^(2*l + 1) / (2*l+1)
4907 */
4908 bf_mul(X2, T, T, prec1, BF_RNDN);
4909 bf_set_ui(r, 0);
4910 for(i = l; i >= 1; i--) {
4911 bf_set_si(U, 1);
4912 bf_set_ui(V, 2 * i + 1);
4913 bf_div(U, U, V, prec1, BF_RNDN);
4914 bf_neg(r);
4915 bf_add(r, r, U, prec1, BF_RNDN);
4916 bf_mul(r, r, X2, prec1, BF_RNDN);
4917 }
4918 bf_neg(r);
4919 bf_add_si(r, r, 1, prec1, BF_RNDN);
4920 bf_mul(r, r, T, prec1, BF_RNDN);
4921
4922 /* undo the argument reduction */
4923 bf_mul_2exp(r, K, BF_PREC_INF, BF_RNDZ);
4924
4925 bf_delete(U);
4926 bf_delete(V);
4927 bf_delete(X2);
4928
4929 i = add_pi2;
4930 if (cmp_1 > 0) {
4931 /* undo the inversion : r = sign(a)*PI/2 - r */
4932 bf_neg(r);
4933 i += 1 - 2 * a->sign;
4934 }
4935 /* add i*(pi/2) with -1 <= i <= 2 */
4936 if (i != 0) {
4937 bf_const_pi(T, prec1, BF_RNDF);
4938 if (i != 2)
4939 bf_mul_2exp(T, -1, BF_PREC_INF, BF_RNDZ);
4940 T->sign = (i < 0);
4941 bf_add(r, T, r, prec1, BF_RNDN);
4942 }
4943
4944 bf_delete(T);
4945 return BF_ST_INEXACT;
4946 }
4947
4948 int bf_atan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4949 {
4950 return bf_ziv_rounding(r, a, prec, flags, bf_atan_internal, (void *)FALSE);
4951 }
4952
4953 static int bf_atan2_internal(bf_t *r, const bf_t *y, limb_t prec, void *opaque)
4954 {
4955 bf_context_t *s = r->ctx;
4956 const bf_t *x = opaque;
4957 bf_t T_s, *T = &T_s;
4958 limb_t prec1;
4959 int ret;
4960
4961 if (y->expn == BF_EXP_NAN || x->expn == BF_EXP_NAN) {
4962 bf_set_nan(r);
4963 return 0;
4964 }
4965
4966 /* compute atan(y/x) assuming inf/inf = 1 and 0/0 = 0 */
4967 bf_init(s, T);
4968 prec1 = prec + 32;
4969 if (y->expn == BF_EXP_INF && x->expn == BF_EXP_INF) {
4970 bf_set_ui(T, 1);
4971 T->sign = y->sign ^ x->sign;
4972 } else if (y->expn == BF_EXP_ZERO && x->expn == BF_EXP_ZERO) {
4973 bf_set_zero(T, y->sign ^ x->sign);
4974 } else {
4975 bf_div(T, y, x, prec1, BF_RNDF);
4976 }
4977 ret = bf_atan(r, T, prec1, BF_RNDF);
4978
4979 if (x->sign) {
4980 /* if x < 0 (it includes -0), return sign(y)*pi + atan(y/x) */
4981 bf_const_pi(T, prec1, BF_RNDF);
4982 T->sign = y->sign;
4983 bf_add(r, r, T, prec1, BF_RNDN);
4984 ret |= BF_ST_INEXACT;
4985 }
4986
4987 bf_delete(T);
4988 return ret;
4989 }
4990
4991 int bf_atan2(bf_t *r, const bf_t *y, const bf_t *x,
4992 limb_t prec, bf_flags_t flags)
4993 {
4994 return bf_ziv_rounding(r, y, prec, flags, bf_atan2_internal, (void *)x);
4995 }
4996
4997 static int bf_asin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4998 {
4999 bf_context_t *s = r->ctx;
5000 BOOL is_acos = (BOOL)(intptr_t)opaque;
5001 bf_t T_s, *T = &T_s;
5002 limb_t prec1, prec2;
5003 int res;
5004
5005 if (a->len == 0) {
5006 if (a->expn == BF_EXP_NAN) {
5007 bf_set_nan(r);
5008 return 0;
5009 } else if (a->expn == BF_EXP_INF) {
5010 bf_set_nan(r);
5011 return BF_ST_INVALID_OP;
5012 } else {
5013 if (is_acos) {
5014 bf_const_pi(r, prec, BF_RNDF);
5015 bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
5016 return BF_ST_INEXACT;
5017 } else {
5018 bf_set_zero(r, a->sign);
5019 return 0;
5020 }
5021 }
5022 }
5023 bf_init(s, T);
5024 bf_set_ui(T, 1);
5025 res = bf_cmpu(a, T);
5026 if (res > 0) {
5027 bf_delete(T);
5028 bf_set_nan(r);
5029 return BF_ST_INVALID_OP;
5030 } else if (res == 0 && a->sign == 0 && is_acos) {
5031 bf_set_zero(r, 0);
5032 bf_delete(T);
5033 return 0;
5034 }
5035
5036 /* asin(x) = atan(x/sqrt(1-x^2))
5037 acos(x) = pi/2 - asin(x) */
5038 prec1 = prec + 8;
5039 /* increase the precision in x^2 to compensate the cancellation in
5040 (1-x^2) if x is close to 1 */
5041 /* XXX: use less precision when possible */
5042 if (a->expn >= 0)
5043 prec2 = BF_PREC_INF;
5044 else
5045 prec2 = prec1;
5046 bf_mul(T, a, a, prec2, BF_RNDN);
5047 bf_neg(T);
5048 bf_add_si(T, T, 1, prec2, BF_RNDN);
5049
5050 bf_sqrt(r, T, prec1, BF_RNDN);
5051 bf_div(T, a, r, prec1, BF_RNDN);
5052 if (is_acos)
5053 bf_neg(T);
5054 bf_atan_internal(r, T, prec1, (void *)(intptr_t)is_acos);
5055 bf_delete(T);
5056 return BF_ST_INEXACT;
5057 }
5058
5059 int bf_asin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5060 {
5061 return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)FALSE);
5062 }
5063
5064 int bf_acos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5065 {
5066 return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)TRUE);
5067 }
5068
5069 /***************************************************************/
5070 /* decimal floating point numbers */
5071
5072 #ifdef USE_BF_DEC
5073
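/* two-limb (double word) add/sub: the carry out of the low limb is
   detected by an unsigned comparison with the previous value (r0
   wraps around iff a carry/borrow occurred) */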
5074 #define adddq(r1, r0, a1, a0) \
5075 do { \
5076 limb_t __t = r0; \
5077 r0 += (a0); \
5078 r1 += (a1) + (r0 < __t); \
5079 } while (0)
5080
5081 #define subdq(r1, r0, a1, a0) \
5082 do { \
5083 limb_t __t = r0; \
5084 r0 -= (a0); \
5085 r1 -= (a1) + (r0 > __t); \
5086 } while (0)
5087
5088 #if LIMB_BITS == 64
5089
5090 /* Note: we assume __int128 is available */
5091 #define muldq(r1, r0, a, b) \
5092 do { \
5093 unsigned __int128 __t; \
5094 __t = (unsigned __int128)(a) * (unsigned __int128)(b); \
5095 r0 = __t; \
5096 r1 = __t >> 64; \
5097 } while (0)
5098
5099 #define divdq(q, r, a1, a0, b) \
5100 do { \
5101 unsigned __int128 __t; \
5102 limb_t __b = (b); \
5103 __t = ((unsigned __int128)(a1) << 64) | (a0); \
5104 q = __t / __b; \
5105 r = __t % __b; \
5106 } while (0)
5107
5108 #else
5109
5110 #define muldq(r1, r0, a, b) \
5111 do { \
5112 uint64_t __t; \
5113 __t = (uint64_t)(a) * (uint64_t)(b); \
5114 r0 = __t; \
5115 r1 = __t >> 32; \
5116 } while (0)
5117
5118 #define divdq(q, r, a1, a0, b) \
5119 do { \
5120 uint64_t __t; \
5121 limb_t __b = (b); \
5122 __t = ((uint64_t)(a1) << 32) | (a0); \
5123 q = __t / __b; \
5124 r = __t % __b; \
5125 } while (0)
5126
5127 #endif /* LIMB_BITS != 64 */
5128
5129 #if 0 //unused
5130 static inline limb_t shrd(limb_t low, limb_t high, long shift)
5131 {
5132 if (shift != 0)
5133 low = (low >> shift) | (high << (LIMB_BITS - shift));
5134 return low;
5135 }
5136 #endif
5137
5138 static inline limb_t shld(limb_t a1, limb_t a0, long shift)
5139 {
5140 if (shift != 0)
5141 return (a1 << shift) | (a0 >> (LIMB_BITS - shift));
5142 else
5143 return a1;
5144 }
5145
5146 #if LIMB_DIGITS == 19
5147
5148 /* WARNING: hardcoded for b = 1e19. It is assumed that:
5149 0 <= a1 < 2^63 */
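/* 17014118346046923173 = floor(2^127 / 10^19): multiplying the top
   64 bits of 2*(a1:a0) by it gives a first estimate of the
   quotient, corrected by the following steps */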
5150 #define divdq_base(q, r, a1, a0)\
5151 do {\
5152 uint64_t __a0, __a1, __t0, __t1, __b = BF_DEC_BASE; \
5153 __a0 = a0;\
5154 __a1 = a1;\
5155 __t0 = __a1;\
5156 __t0 = shld(__t0, __a0, 1);\
5157 muldq(q, __t1, __t0, UINT64_C(17014118346046923173)); \
5158 muldq(__t1, __t0, q, __b);\
5159 subdq(__a1, __a0, __t1, __t0);\
5160 subdq(__a1, __a0, 1, __b * 2); \
5161 __t0 = (slimb_t)__a1 >> 1; \
5162 q += 2 + __t0;\
5163 adddq(__a1, __a0, 0, __b & __t0);\
5164 q += __a1; \
5165 __a0 += __b & __a1; \
5166 r = __a0;\
5167 } while(0)
5168
5169 #elif LIMB_DIGITS == 9
5170
5171 /* WARNING: hardcoded for b = 1e9. It is assumed that:
5172 0 <= a1 < 2^29 */
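/* 2305843009 = floor(2^61 / 10^9): the dividend is pre-shifted by
   3 bits so that the 32x32->64 bit multiply gives a quotient
   estimate off by at most one */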
5173 #define divdq_base(q, r, a1, a0)\
5174 do {\
5175 uint32_t __t0, __t1, __b = BF_DEC_BASE; \
5176 __t0 = a1;\
5177 __t1 = a0;\
5178 __t0 = (__t0 << 3) | (__t1 >> (32 - 3)); \
5179 muldq(q, __t1, __t0, 2305843009U);\
5180 r = a0 - q * __b;\
5181 __t1 = (r >= __b);\
5182 q += __t1;\
5183 if (__t1)\
5184 r -= __b;\
5185 } while(0)
5186
5187 #endif
5188
5189 /* fast integer division by a fixed constant */
5190
5191 typedef struct FastDivData {
5192 limb_t d; /* divisor (only user visible field) */
5193 limb_t m1; /* multiplier */
5194 int shift1;
5195 int shift2;
5196 } FastDivData;
5197
5198 #if 1
5199 /* From "Division by Invariant Integers using Multiplication" by
5200 Torborn Granlund and Peter L. Montgomery */
5201 /* d must be != 0 */
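/* precompute m1 = floor((2^l - d) * 2^LIMB_BITS / d) + 1 with
   l = ceil(log2(d)); then for any limb a:
   t1 = hi(m1 * a), q = (t1 + ((a - t1) >> shift1)) >> shift2
   equals floor(a / d) */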
5202 static inline void fast_udiv_init(FastDivData *s, limb_t d)
5203 {
5204 int l;
5205 limb_t q, r, m1;
5206 s->d = d;
5207 if (d == 1)
5208 l = 0;
5209 else
5210 l = 64 - clz64(d - 1);
5211 divdq(q, r, ((limb_t)1 << l) - d, 0, d);
5212 (void)r;
5213 m1 = q + 1;
5214 // printf("d=%lu l=%d m1=0x%016lx\n", d, l, m1);
5215 s->m1 = m1;
5216 s->shift1 = l;
5217 if (s->shift1 > 1)
5218 s->shift1 = 1;
5219 s->shift2 = l - 1;
5220 if (s->shift2 < 0)
5221 s->shift2 = 0;
5222 }
5223
5224 static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
5225 {
5226 limb_t t0, t1;
5227 muldq(t1, t0, s->m1, a);
5228 t0 = (a - t1) >> s->shift1;
5229 return (t1 + t0) >> s->shift2;
5230 }
5231
5232 #if 0 //unused
5233 static inline limb_t fast_urem(limb_t a, const FastDivData *s)
5234 {
5235 limb_t q;
5236 q = fast_udiv(a, s);
5237 return a - q * s->d;
5238 }
5239 #endif
5240
5241 #define fast_udivrem(q, r, a, s) q = fast_udiv(a, s), r = a - q * (s)->d
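/* Usage sketch (illustration only, not part of the library): the
   hypothetical helper below extracts the low decimal digit of a
   limb. The library itself precomputes the FastDivData for the
   powers of 10 in mp_pow_div[] (see mp_pow_init() below). */
#if 0 //example
static limb_t low_digit_example(limb_t a)
{
    FastDivData dd;
    limb_t q, r;
    fast_udiv_init(&dd, 10); /* precompute the magic multiplier */
    fast_udivrem(q, r, a, &dd); /* q = a / 10, r = a % 10 */
    (void)q;
    return r;
}
#endif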
5242
5243 #else
5244
5245 static inline void fast_udiv_init(FastDivData *s, limb_t d)
5246 {
5247 s->d = d;
5248 }
5249 static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
5250 {
5251 return a / s->d;
5252 }
5253 static inline limb_t fast_urem(limb_t a, const FastDivData *s)
5254 {
5255 return a % s->d;
5256 }
5257
5258 #define fast_udivrem(q, r, a, s) q = a / (s)->d, r = a % (s)->d
5259
5260 #endif
5261
5262 /* contains 10^i */
5263 /* XXX: make it const */
5264 limb_t mp_pow_dec[LIMB_DIGITS + 1];
5265 static FastDivData mp_pow_div[LIMB_DIGITS + 1];
5266
5267 static void mp_pow_init(void)
5268 {
5269 limb_t a;
5270 int i;
5271
5272 a = 1;
5273 for(i = 0; i <= LIMB_DIGITS; i++) {
5274 mp_pow_dec[i] = a;
5275 fast_udiv_init(&mp_pow_div[i], a);
5276 a = a * 10;
5277 }
5278 }
5279
5280 limb_t mp_add_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
5281 mp_size_t n, limb_t carry)
5282 {
5283 limb_t base = BF_DEC_BASE;
5284 mp_size_t i;
5285 limb_t k, a, v;
5286
5287 k=carry;
5288 for(i=0;i<n;i++) {
5289 /* XXX: reuse the trick in add_mod */
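/* carry trick: a = v + op2[i] + k - base is computed modulo
   2^LIMB_BITS; if the true sum is >= base, a lands in [0, base)
   with a <= v, otherwise it wraps far above v (k = 0) and base is
   added back */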
5290 v = op1[i];
5291 a = v + op2[i] + k - base;
5292 k = a <= v;
5293 if (!k)
5294 a += base;
5295 res[i]=a;
5296 }
5297 return k;
5298 }
5299
5300 limb_t mp_add_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
5301 {
5302 limb_t base = BF_DEC_BASE;
5303 mp_size_t i;
5304 limb_t k, a, v;
5305
5306 k=b;
5307 for(i=0;i<n;i++) {
5308 v = tab[i];
5309 a = v + k - base;
5310 k = a <= v;
5311 if (!k)
5312 a += base;
5313 tab[i] = a;
5314 if (k == 0)
5315 break;
5316 }
5317 return k;
5318 }
5319
5320 limb_t mp_sub_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
5321 mp_size_t n, limb_t carry)
5322 {
5323 limb_t base = BF_DEC_BASE;
5324 mp_size_t i;
5325 limb_t k, v, a;
5326
5327 k=carry;
5328 for(i=0;i<n;i++) {
5329 v = op1[i];
5330 a = v - op2[i] - k;
5331 k = a > v;
5332 if (k)
5333 a += base;
5334 res[i] = a;
5335 }
5336 return k;
5337 }
5338
5339 limb_t mp_sub_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
5340 {
5341 limb_t base = BF_DEC_BASE;
5342 mp_size_t i;
5343 limb_t k, v, a;
5344
5345 k=b;
5346 for(i=0;i<n;i++) {
5347 v = tab[i];
5348 a = v - k;
5349 k = a > v;
5350 if (k)
5351 a += base;
5352 tab[i]=a;
5353 if (k == 0)
5354 break;
5355 }
5356 return k;
5357 }
5358
5359 /* tabr[] = taba[] * b + l. 0 <= b, l <= base - 1. Return the high carry */
5360 limb_t mp_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5361 limb_t b, limb_t l)
5362 {
5363 mp_size_t i;
5364 limb_t t0, t1, r;
5365
5366 for(i = 0; i < n; i++) {
5367 muldq(t1, t0, taba[i], b);
5368 adddq(t1, t0, 0, l);
5369 divdq_base(l, r, t1, t0);
5370 tabr[i] = r;
5371 }
5372 return l;
5373 }
5374
5375 /* tabr[] += taba[] * b. 0 <= b <= base - 1. Return the value to add
5376 to the high word */
5377 limb_t mp_add_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5378 limb_t b)
5379 {
5380 mp_size_t i;
5381 limb_t l, t0, t1, r;
5382
5383 l = 0;
5384 for(i = 0; i < n; i++) {
5385 muldq(t1, t0, taba[i], b);
5386 adddq(t1, t0, 0, l);
5387 adddq(t1, t0, 0, tabr[i]);
5388 divdq_base(l, r, t1, t0);
5389 tabr[i] = r;
5390 }
5391 return l;
5392 }
5393
5394 /* tabr[] -= taba[] * b. 0 <= b <= base - 1. Return the value to
5395    subtract from the high word. */
5396 limb_t mp_sub_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5397 limb_t b)
5398 {
5399 limb_t base = BF_DEC_BASE;
5400 mp_size_t i;
5401 limb_t l, t0, t1, r, a, v, c;
5402
5403 /* XXX: optimize */
5404 l = 0;
5405 for(i = 0; i < n; i++) {
5406 muldq(t1, t0, taba[i], b);
5407 adddq(t1, t0, 0, l);
5408 divdq_base(l, r, t1, t0);
5409 v = tabr[i];
5410 a = v - r;
5411 c = a > v;
5412 if (c)
5413 a += base;
5414 /* never bigger than base because r = 0 when l = base - 1 */
5415 l += c;
5416 tabr[i] = a;
5417 }
5418 return l;
5419 }
5420
5421 /* size of the result : op1_size + op2_size. */
5422 void mp_mul_basecase_dec(limb_t *result,
5423 const limb_t *op1, mp_size_t op1_size,
5424 const limb_t *op2, mp_size_t op2_size)
5425 {
5426 mp_size_t i;
5427 limb_t r;
5428
5429 result[op1_size] = mp_mul1_dec(result, op1, op1_size, op2[0], 0);
5430
5431 for(i=1;i<op2_size;i++) {
5432 r = mp_add_mul1_dec(result + i, op1, op1_size, op2[i]);
5433 result[i + op1_size] = r;
5434 }
5435 }
5436
5437 /* tabr[] = (taba[] + r*base^na) / b. 0 <= b < base. 0 <= r <
5438    b. Return the remainder. */
5439 limb_t mp_div1_dec(limb_t *tabr, const limb_t *taba, mp_size_t na,
5440 limb_t b, limb_t r)
5441 {
5442 limb_t base = BF_DEC_BASE;
5443 mp_size_t i;
5444 limb_t t0, t1, q;
5445 int shift;
5446
5447 #if (BF_DEC_BASE % 2) == 0
5448 if (b == 2) {
5449 limb_t base_div2;
5450 /* Note: only works if base is even */
5451 base_div2 = base >> 1;
5452 if (r)
5453 r = base_div2;
5454 for(i = na - 1; i >= 0; i--) {
5455 t0 = taba[i];
5456 tabr[i] = (t0 >> 1) + r;
5457 r = 0;
5458 if (t0 & 1)
5459 r = base_div2;
5460 }
5461 if (r)
5462 r = 1;
5463 } else
5464 #endif
5465 if (na >= UDIV1NORM_THRESHOLD) {
5466 shift = clz(b);
5467 if (shift == 0) {
5468 /* normalized case: b >= 2^(LIMB_BITS-1) */
5469 limb_t b_inv;
5470 b_inv = udiv1norm_init(b);
5471 for(i = na - 1; i >= 0; i--) {
5472 muldq(t1, t0, r, base);
5473 adddq(t1, t0, 0, taba[i]);
5474 q = udiv1norm(&r, t1, t0, b, b_inv);
5475 tabr[i] = q;
5476 }
5477 } else {
5478 limb_t b_inv;
5479 b <<= shift;
5480 b_inv = udiv1norm_init(b);
5481 for(i = na - 1; i >= 0; i--) {
5482 muldq(t1, t0, r, base);
5483 adddq(t1, t0, 0, taba[i]);
5484 t1 = (t1 << shift) | (t0 >> (LIMB_BITS - shift));
5485 t0 <<= shift;
5486 q = udiv1norm(&r, t1, t0, b, b_inv);
5487 r >>= shift;
5488 tabr[i] = q;
5489 }
5490 }
5491 } else {
5492 for(i = na - 1; i >= 0; i--) {
5493 muldq(t1, t0, r, base);
5494 adddq(t1, t0, 0, taba[i]);
5495 divdq(q, r, t1, t0, b);
5496 tabr[i] = q;
5497 }
5498 }
5499 return r;
5500 }
5501
5502 static __maybe_unused void mp_print_str_dec(const char *str,
5503 const limb_t *tab, slimb_t n)
5504 {
5505 slimb_t i;
5506 printf("%s=", str);
5507 for(i = n - 1; i >= 0; i--) {
5508 if (i != n - 1)
5509 printf("_");
5510 printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5511 }
5512 printf("\n");
5513 }
5514
5515 static __maybe_unused void mp_print_str_h_dec(const char *str,
5516 const limb_t *tab, slimb_t n,
5517 limb_t high)
5518 {
5519 slimb_t i;
5520 printf("%s=", str);
5521 printf("%0*" PRIu_LIMB, LIMB_DIGITS, high);
5522 for(i = n - 1; i >= 0; i--) {
5523 printf("_");
5524 printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5525 }
5526 printf("\n");
5527 }
5528
5529 //#define DEBUG_DIV_SLOW
5530
5531 #define DIV_STATIC_ALLOC_LEN 16
5532
5533 /* return q = a / b and r = a % b.
5534
5535 taba[na] must be allocated if tabb1[nb - 1] < B / 2. tabb1[nb - 1]
5536 must be nonzero. na must be >= nb. 's' can be NULL if tabb1[nb - 1]
5537 >= B / 2.
5538
5539 The remainder is returned in taba and contains nb limbs. tabq
5540 contains na - nb + 1 limbs. No overlap is permitted.
5541
5542 Running time of the standard method: (na - nb + 1) * nb
5543 Return 0 if OK, -1 if memory alloc error
5544 */
5545 /* XXX: optimize */
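/* schoolbook division (Knuth's algorithm D) in base
   10^LIMB_DIGITS: the divisor is normalized so that its top limb is
   >= base/2, which makes the two-limb quotient estimate below off
   by at most 2, fixed by the correction loop */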
5546 static int mp_div_dec(bf_context_t *s, limb_t *tabq,
5547 limb_t *taba, mp_size_t na,
5548 const limb_t *tabb1, mp_size_t nb)
5549 {
5550 limb_t base = BF_DEC_BASE;
5551 limb_t r, mult, t0, t1, a, c, q, v, *tabb;
5552 mp_size_t i, j;
5553 limb_t static_tabb[DIV_STATIC_ALLOC_LEN];
5554
5555 #ifdef DEBUG_DIV_SLOW
5556 mp_print_str_dec("a", taba, na);
5557 mp_print_str_dec("b", tabb1, nb);
5558 #endif
5559
5560 /* normalize tabb */
5561 r = tabb1[nb - 1];
5562 assert(r != 0);
5563 i = na - nb;
5564 if (r >= BF_DEC_BASE / 2) {
5565 mult = 1;
5566 tabb = (limb_t *)tabb1;
5567 q = 1;
5568 for(j = nb - 1; j >= 0; j--) {
5569 if (taba[i + j] != tabb[j]) {
5570 if (taba[i + j] < tabb[j])
5571 q = 0;
5572 break;
5573 }
5574 }
5575 tabq[i] = q;
5576 if (q) {
5577 mp_sub_dec(taba + i, taba + i, tabb, nb, 0);
5578 }
5579 i--;
5580 } else {
5581 mult = base / (r + 1);
5582 if (likely(nb <= DIV_STATIC_ALLOC_LEN)) {
5583 tabb = static_tabb;
5584 } else {
5585 tabb = bf_malloc(s, sizeof(limb_t) * nb);
5586 if (!tabb)
5587 return -1;
5588 }
5589 mp_mul1_dec(tabb, tabb1, nb, mult, 0);
5590 taba[na] = mp_mul1_dec(taba, taba, na, mult, 0);
5591 }
5592
5593 #ifdef DEBUG_DIV_SLOW
5594 printf("mult=" FMT_LIMB "\n", mult);
5595 mp_print_str_dec("a_norm", taba, na + 1);
5596 mp_print_str_dec("b_norm", tabb, nb);
5597 #endif
5598
5599 for(; i >= 0; i--) {
5600 if (unlikely(taba[i + nb] >= tabb[nb - 1])) {
5601 /* XXX: check if it is really possible */
5602 q = base - 1;
5603 } else {
5604 muldq(t1, t0, taba[i + nb], base);
5605 adddq(t1, t0, 0, taba[i + nb - 1]);
5606 divdq(q, r, t1, t0, tabb[nb - 1]);
5607 }
5608 // printf("i=%d q1=%ld\n", i, q);
5609
5610 r = mp_sub_mul1_dec(taba + i, tabb, nb, q);
5611 // mp_dump("r1", taba + i, nb, bd);
5612 // printf("r2=%ld\n", r);
5613
5614 v = taba[i + nb];
5615 a = v - r;
5616 c = a > v;
5617 if (c)
5618 a += base;
5619 taba[i + nb] = a;
5620
5621 if (c != 0) {
5622 /* negative result */
5623 for(;;) {
5624 q--;
5625 c = mp_add_dec(taba + i, taba + i, tabb, nb, 0);
5626 /* propagate carry and test if positive result */
5627 if (c != 0) {
5628 if (++taba[i + nb] == base) {
5629 break;
5630 }
5631 }
5632 }
5633 }
5634 tabq[i] = q;
5635 }
5636
5637 #ifdef DEBUG_DIV_SLOW
5638 mp_print_str_dec("q", tabq, na - nb + 1);
5639 mp_print_str_dec("r", taba, nb);
5640 #endif
5641
5642 /* remove the normalization */
5643 if (mult != 1) {
5644 mp_div1_dec(taba, taba, nb, mult, 0);
5645 if (unlikely(tabb != static_tabb))
5646 bf_free(s, tabb);
5647 }
5648 return 0;
5649 }
5650
5651 /* divide by 10^shift */
5652 static limb_t mp_shr_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5653 limb_t shift, limb_t high)
5654 {
5655 mp_size_t i;
5656 limb_t l, a, q, r;
5657
5658 assert(shift >= 1 && shift < LIMB_DIGITS);
5659 l = high;
5660 for(i = n - 1; i >= 0; i--) {
5661 a = tab[i];
5662 fast_udivrem(q, r, a, &mp_pow_div[shift]);
5663 tab_r[i] = q + l * mp_pow_dec[LIMB_DIGITS - shift];
5664 l = r;
5665 }
5666 return l;
5667 }
5668
5669 /* multiply by 10^shift */
5670 static limb_t mp_shl_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5671 limb_t shift, limb_t low)
5672 {
5673 mp_size_t i;
5674 limb_t l, a, q, r;
5675
5676 assert(shift >= 1 && shift < LIMB_DIGITS);
5677 l = low;
5678 for(i = 0; i < n; i++) {
5679 a = tab[i];
5680 fast_udivrem(q, r, a, &mp_pow_div[LIMB_DIGITS - shift]);
5681 tab_r[i] = r * mp_pow_dec[shift] + l;
5682 l = q;
5683 }
5684 return l;
5685 }
5686
5687 static limb_t mp_sqrtrem2_dec(limb_t *tabs, limb_t *taba)
5688 {
5689 int k;
5690 dlimb_t a, b, r;
5691 limb_t taba1[2], s, r0, r1;
5692
5693 /* convert to binary and normalize */
5694 a = (dlimb_t)taba[1] * BF_DEC_BASE + taba[0];
5695 k = clz(a >> LIMB_BITS) & ~1;
5696 b = a << k;
5697 taba1[0] = b;
5698 taba1[1] = b >> LIMB_BITS;
5699 mp_sqrtrem2(&s, taba1);
5700 s >>= (k >> 1);
5701 /* convert the remainder back to decimal */
5702 r = a - (dlimb_t)s * (dlimb_t)s;
5703 divdq_base(r1, r0, r >> LIMB_BITS, r);
5704 taba[0] = r0;
5705 tabs[0] = s;
5706 return r1;
5707 }
5708
5709 //#define DEBUG_SQRTREM_DEC
5710
5711 /* tmp_buf must contain (n / 2 + 1 limbs) */
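/* recursive ("Karatsuba-style") square root adapted to base
   10^LIMB_DIGITS: take the square root s1 of the upper half of
   'a', divide the remainder by s1 (standing for 2*s1, hence the
   halving of the quotient below) to get the lower limbs of s, then
   subtract q^2 to fix the remainder */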
5712 static limb_t mp_sqrtrem_rec_dec(limb_t *tabs, limb_t *taba, limb_t n,
5713 limb_t *tmp_buf)
5714 {
5715 limb_t l, h, rh, ql, qh, c, i;
5716
5717 if (n == 1)
5718 return mp_sqrtrem2_dec(tabs, taba);
5719 #ifdef DEBUG_SQRTREM_DEC
5720 mp_print_str_dec("a", taba, 2 * n);
5721 #endif
5722 l = n / 2;
5723 h = n - l;
5724 qh = mp_sqrtrem_rec_dec(tabs + l, taba + 2 * l, h, tmp_buf);
5725 #ifdef DEBUG_SQRTREM_DEC
5726 mp_print_str_dec("s1", tabs + l, h);
5727 mp_print_str_h_dec("r1", taba + 2 * l, h, qh);
5728 mp_print_str_h_dec("r2", taba + l, n, qh);
5729 #endif
5730
5731 /* the remainder is in taba + 2 * l. Its high bit is in qh */
5732 if (qh) {
5733 mp_sub_dec(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
5734 }
5735 /* instead of dividing by 2*s, divide by s (which is normalized)
5736 and update q and r */
5737 mp_div_dec(NULL, tmp_buf, taba + l, n, tabs + l, h);
5738 qh += tmp_buf[l];
5739 for(i = 0; i < l; i++)
5740 tabs[i] = tmp_buf[i];
5741 ql = mp_div1_dec(tabs, tabs, l, 2, qh & 1);
5742 qh = qh >> 1; /* 0 or 1 */
5743 if (ql)
5744 rh = mp_add_dec(taba + l, taba + l, tabs + l, h, 0);
5745 else
5746 rh = 0;
5747 #ifdef DEBUG_SQRTREM_DEC
5748 mp_print_str_h_dec("q", tabs, l, qh);
5749 mp_print_str_h_dec("u", taba + l, h, rh);
5750 #endif
5751
5752 mp_add_ui_dec(tabs + l, qh, h);
5753 #ifdef DEBUG_SQRTREM_DEC
5754 mp_print_str_dec("s2", tabs, n);
5755 #endif
5756
5757 /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
5758 /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
5759 if (qh) {
5760 c = qh;
5761 } else {
5762 mp_mul_basecase_dec(taba + n, tabs, l, tabs, l);
5763 c = mp_sub_dec(taba, taba, taba + n, 2 * l, 0);
5764 }
5765 rh -= mp_sub_ui_dec(taba + 2 * l, c, n - 2 * l);
5766 if ((slimb_t)rh < 0) {
5767 mp_sub_ui_dec(tabs, 1, n);
5768 rh += mp_add_mul1_dec(taba, tabs, n, 2);
5769 rh += mp_add_ui_dec(taba, 1, n);
5770 }
5771 return rh;
5772 }
5773
5774 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= B/4. Return (s,
5775    r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2 * s. tabs has n
5776    limbs. r is returned in the lower n limbs of taba, with its high
5777    limb stored in taba[n]. Return 0 if OK, -1 on memory alloc error. */
5778 int mp_sqrtrem_dec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
5779 {
5780 limb_t tmp_buf1[8];
5781 limb_t *tmp_buf;
5782 mp_size_t n2;
5783 n2 = n / 2 + 1;
5784 if (n2 <= countof(tmp_buf1)) {
5785 tmp_buf = tmp_buf1;
5786 } else {
5787 tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
5788 if (!tmp_buf)
5789 return -1;
5790 }
5791 taba[n] = mp_sqrtrem_rec_dec(tabs, taba, n, tmp_buf);
5792 if (tmp_buf != tmp_buf1)
5793 bf_free(s, tmp_buf);
5794 return 0;
5795 }
5796
5797 /* return the number of leading zero digits, from 0 to LIMB_DIGITS */
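/* the bit length LIMB_BITS - clz(a) determines the decimal length
   up to one digit; an explicit comparison with a power of 10 is
   only needed for the bit lengths whose value range contains one */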
5798 static int clz_dec(limb_t a)
5799 {
5800 if (a == 0)
5801 return LIMB_DIGITS;
5802 switch(LIMB_BITS - 1 - clz(a)) {
5803 case 0: /* 1-1 */
5804 return LIMB_DIGITS - 1;
5805 case 1: /* 2-3 */
5806 return LIMB_DIGITS - 1;
5807 case 2: /* 4-7 */
5808 return LIMB_DIGITS - 1;
5809 case 3: /* 8-15 */
5810 if (a < 10)
5811 return LIMB_DIGITS - 1;
5812 else
5813 return LIMB_DIGITS - 2;
5814 case 4: /* 16-31 */
5815 return LIMB_DIGITS - 2;
5816 case 5: /* 32-63 */
5817 return LIMB_DIGITS - 2;
5818 case 6: /* 64-127 */
5819 if (a < 100)
5820 return LIMB_DIGITS - 2;
5821 else
5822 return LIMB_DIGITS - 3;
5823 case 7: /* 128-255 */
5824 return LIMB_DIGITS - 3;
5825 case 8: /* 256-511 */
5826 return LIMB_DIGITS - 3;
5827 case 9: /* 512-1023 */
5828 if (a < 1000)
5829 return LIMB_DIGITS - 3;
5830 else
5831 return LIMB_DIGITS - 4;
5832 case 10: /* 1024-2047 */
5833 return LIMB_DIGITS - 4;
5834 case 11: /* 2048-4095 */
5835 return LIMB_DIGITS - 4;
5836 case 12: /* 4096-8191 */
5837 return LIMB_DIGITS - 4;
5838 case 13: /* 8192-16383 */
5839 if (a < 10000)
5840 return LIMB_DIGITS - 4;
5841 else
5842 return LIMB_DIGITS - 5;
5843 case 14: /* 16384-32767 */
5844 return LIMB_DIGITS - 5;
5845 case 15: /* 32768-65535 */
5846 return LIMB_DIGITS - 5;
5847 case 16: /* 65536-131071 */
5848 if (a < 100000)
5849 return LIMB_DIGITS - 5;
5850 else
5851 return LIMB_DIGITS - 6;
5852 case 17: /* 131072-262143 */
5853 return LIMB_DIGITS - 6;
5854 case 18: /* 262144-524287 */
5855 return LIMB_DIGITS - 6;
5856 case 19: /* 524288-1048575 */
5857 if (a < 1000000)
5858 return LIMB_DIGITS - 6;
5859 else
5860 return LIMB_DIGITS - 7;
5861 case 20: /* 1048576-2097151 */
5862 return LIMB_DIGITS - 7;
5863 case 21: /* 2097152-4194303 */
5864 return LIMB_DIGITS - 7;
5865 case 22: /* 4194304-8388607 */
5866 return LIMB_DIGITS - 7;
5867 case 23: /* 8388608-16777215 */
5868 if (a < 10000000)
5869 return LIMB_DIGITS - 7;
5870 else
5871 return LIMB_DIGITS - 8;
5872 case 24: /* 16777216-33554431 */
5873 return LIMB_DIGITS - 8;
5874 case 25: /* 33554432-67108863 */
5875 return LIMB_DIGITS - 8;
5876 case 26: /* 67108864-134217727 */
5877 if (a < 100000000)
5878 return LIMB_DIGITS - 8;
5879 else
5880 return LIMB_DIGITS - 9;
5881 #if LIMB_BITS == 64
5882 case 27: /* 134217728-268435455 */
5883 return LIMB_DIGITS - 9;
5884 case 28: /* 268435456-536870911 */
5885 return LIMB_DIGITS - 9;
5886 case 29: /* 536870912-1073741823 */
5887 if (a < 1000000000)
5888 return LIMB_DIGITS - 9;
5889 else
5890 return LIMB_DIGITS - 10;
5891 case 30: /* 1073741824-2147483647 */
5892 return LIMB_DIGITS - 10;
5893 case 31: /* 2147483648-4294967295 */
5894 return LIMB_DIGITS - 10;
5895 case 32: /* 4294967296-8589934591 */
5896 return LIMB_DIGITS - 10;
5897 case 33: /* 8589934592-17179869183 */
5898 if (a < 10000000000)
5899 return LIMB_DIGITS - 10;
5900 else
5901 return LIMB_DIGITS - 11;
5902 case 34: /* 17179869184-34359738367 */
5903 return LIMB_DIGITS - 11;
5904 case 35: /* 34359738368-68719476735 */
5905 return LIMB_DIGITS - 11;
5906 case 36: /* 68719476736-137438953471 */
5907 if (a < 100000000000)
5908 return LIMB_DIGITS - 11;
5909 else
5910 return LIMB_DIGITS - 12;
5911 case 37: /* 137438953472-274877906943 */
5912 return LIMB_DIGITS - 12;
5913 case 38: /* 274877906944-549755813887 */
5914 return LIMB_DIGITS - 12;
5915 case 39: /* 549755813888-1099511627775 */
5916 if (a < 1000000000000)
5917 return LIMB_DIGITS - 12;
5918 else
5919 return LIMB_DIGITS - 13;
5920 case 40: /* 1099511627776-2199023255551 */
5921 return LIMB_DIGITS - 13;
5922 case 41: /* 2199023255552-4398046511103 */
5923 return LIMB_DIGITS - 13;
5924 case 42: /* 4398046511104-8796093022207 */
5925 return LIMB_DIGITS - 13;
5926 case 43: /* 8796093022208-17592186044415 */
5927 if (a < 10000000000000)
5928 return LIMB_DIGITS - 13;
5929 else
5930 return LIMB_DIGITS - 14;
5931 case 44: /* 17592186044416-35184372088831 */
5932 return LIMB_DIGITS - 14;
5933 case 45: /* 35184372088832-70368744177663 */
5934 return LIMB_DIGITS - 14;
5935 case 46: /* 70368744177664-140737488355327 */
5936 if (a < 100000000000000)
5937 return LIMB_DIGITS - 14;
5938 else
5939 return LIMB_DIGITS - 15;
5940 case 47: /* 140737488355328-281474976710655 */
5941 return LIMB_DIGITS - 15;
5942 case 48: /* 281474976710656-562949953421311 */
5943 return LIMB_DIGITS - 15;
5944 case 49: /* 562949953421312-1125899906842623 */
5945 if (a < 1000000000000000)
5946 return LIMB_DIGITS - 15;
5947 else
5948 return LIMB_DIGITS - 16;
5949 case 50: /* 1125899906842624-2251799813685247 */
5950 return LIMB_DIGITS - 16;
5951 case 51: /* 2251799813685248-4503599627370495 */
5952 return LIMB_DIGITS - 16;
5953 case 52: /* 4503599627370496-9007199254740991 */
5954 return LIMB_DIGITS - 16;
5955 case 53: /* 9007199254740992-18014398509481983 */
5956 if (a < 10000000000000000)
5957 return LIMB_DIGITS - 16;
5958 else
5959 return LIMB_DIGITS - 17;
5960 case 54: /* 18014398509481984-36028797018963967 */
5961 return LIMB_DIGITS - 17;
5962 case 55: /* 36028797018963968-72057594037927935 */
5963 return LIMB_DIGITS - 17;
5964 case 56: /* 72057594037927936-144115188075855871 */
5965 if (a < 100000000000000000)
5966 return LIMB_DIGITS - 17;
5967 else
5968 return LIMB_DIGITS - 18;
5969 case 57: /* 144115188075855872-288230376151711743 */
5970 return LIMB_DIGITS - 18;
5971 case 58: /* 288230376151711744-576460752303423487 */
5972 return LIMB_DIGITS - 18;
5973 case 59: /* 576460752303423488-1152921504606846975 */
5974 if (a < 1000000000000000000)
5975 return LIMB_DIGITS - 18;
5976 else
5977 return LIMB_DIGITS - 19;
5978 #endif
5979 default:
5980 return 0;
5981 }
5982 }
5983
5984 /* for debugging */
5985 void bfdec_print_str(const char *str, const bfdec_t *a)
5986 {
5987 slimb_t i;
5988 printf("%s=", str);
5989
5990 if (a->expn == BF_EXP_NAN) {
5991 printf("NaN");
5992 } else {
5993 if (a->sign)
5994 putchar('-');
5995 if (a->expn == BF_EXP_ZERO) {
5996 putchar('0');
5997 } else if (a->expn == BF_EXP_INF) {
5998 printf("Inf");
5999 } else {
6000 printf("0.");
6001 for(i = a->len - 1; i >= 0; i--)
6002 printf("%0*" PRIu_LIMB, LIMB_DIGITS, a->tab[i]);
6003 printf("e%" PRId_LIMB, a->expn);
6004 }
6005 }
6006 printf("\n");
6007 }
6008
6009 /* return != 0 if at least one digit at positions 0 to bit_pos (inclusive) is not zero. */
6010 static inline limb_t scan_digit_nz(const bfdec_t *r, slimb_t bit_pos)
6011 {
6012 slimb_t pos;
6013 limb_t v, q;
6014 int shift;
6015
6016 if (bit_pos < 0)
6017 return 0;
6018 pos = (limb_t)bit_pos / LIMB_DIGITS;
6019 shift = (limb_t)bit_pos % LIMB_DIGITS;
6020 fast_udivrem(q, v, r->tab[pos], &mp_pow_div[shift + 1]);
6021 (void)q;
6022 if (v != 0)
6023 return 1;
6024 pos--;
6025 while (pos >= 0) {
6026 if (r->tab[pos] != 0)
6027 return 1;
6028 pos--;
6029 }
6030 return 0;
6031 }
6032
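/* return the decimal digit of weight 10^pos in tab (pos = 0 is the
   least significant digit of tab[0]); positions outside the table
   read as zero */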
6033 static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos)
6034 {
6035 slimb_t i;
6036 int shift;
6037 i = floor_div(pos, LIMB_DIGITS);
6038 if (i < 0 || i >= len)
6039 return 0;
6040 shift = pos - i * LIMB_DIGITS;
6041 return fast_udiv(tab[i], &mp_pow_div[shift]) % 10;
6042 }
6043
6044 #if 0
6045 static limb_t get_digits(const limb_t *tab, limb_t len, slimb_t pos)
6046 {
6047 limb_t a0, a1;
6048 int shift;
6049 slimb_t i;
6050
6051 i = floor_div(pos, LIMB_DIGITS);
6052 shift = pos - i * LIMB_DIGITS;
6053 if (i >= 0 && i < len)
6054 a0 = tab[i];
6055 else
6056 a0 = 0;
6057 if (shift == 0) {
6058 return a0;
6059 } else {
6060 i++;
6061 if (i >= 0 && i < len)
6062 a1 = tab[i];
6063 else
6064 a1 = 0;
6065 return fast_udiv(a0, &mp_pow_div[shift]) +
6066 fast_urem(a1, &mp_pow_div[LIMB_DIGITS - shift]) *
6067 mp_pow_dec[shift];
6068 }
6069 }
6070 #endif
6071
6072 /* return the addend for rounding. Note that prec can be <= 0 for bf_rint() */
6073 static int bfdec_get_rnd_add(int *pret, const bfdec_t *r, limb_t l,
6074 slimb_t prec, int rnd_mode)
6075 {
6076 int add_one, inexact;
6077 limb_t digit1, digit0;
6078
6079 // bfdec_print_str("get_rnd_add", r);
6080 if (rnd_mode == BF_RNDF) {
6081 digit0 = 1; /* faithful rounding does not honor the INEXACT flag */
6082 } else {
6083 /* starting limb for digit 'prec + 1' */
6084 digit0 = scan_digit_nz(r, l * LIMB_DIGITS - 1 - bf_max(0, prec + 1));
6085 }
6086
6087 /* get the digit at 'prec' */
6088 digit1 = get_digit(r->tab, l, l * LIMB_DIGITS - 1 - prec);
6089 inexact = (digit1 | digit0) != 0;
6090
6091 add_one = 0;
6092 switch(rnd_mode) {
6093 case BF_RNDZ:
6094 break;
6095 case BF_RNDN:
6096 if (digit1 == 5) {
6097 if (digit0) {
6098 add_one = 1;
6099 } else {
6100 /* round to even */
6101 add_one =
6102 get_digit(r->tab, l, l * LIMB_DIGITS - 1 - (prec - 1)) & 1;
6103 }
6104 } else if (digit1 > 5) {
6105 add_one = 1;
6106 }
6107 break;
6108 case BF_RNDD:
6109 case BF_RNDU:
6110 if (r->sign == (rnd_mode == BF_RNDD))
6111 add_one = inexact;
6112 break;
6113 case BF_RNDNA:
6114 case BF_RNDF:
6115 add_one = (digit1 >= 5);
6116 break;
6117 case BF_RNDNU:
6118 if (digit1 >= 5) {
6119 if (r->sign)
6120 add_one = (digit0 != 0);
6121 else
6122 add_one = 1;
6123 }
6124 break;
6125 default:
6126 abort();
6127 }
6128
6129 if (inexact)
6130 *pret |= BF_ST_INEXACT;
6131 return add_one;
6132 }
6133
6134 /* round to prec1 digits assuming 'r' is non zero and finite. 'r' is
6135 assumed to have length 'l' (1 <= l <= r->len). prec1 can be
6136 BF_PREC_INF. BF_FLAG_SUBNORMAL is not supported. Cannot fail with
6137 BF_ST_MEM_ERROR.
6138 */
6139 static int __bfdec_round(bfdec_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
6140 {
6141 int shift, add_one, rnd_mode, ret;
6142 slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;
6143
6144 /* e_min and e_max are computed to match the IEEE 754 conventions */
6145 /* XXX: does not matter for decimal numbers */
6146 e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
6147 e_min = -e_range + 3;
6148 e_max = e_range;
6149
6150 if (flags & BF_FLAG_RADPNT_PREC) {
6151 /* 'prec' is the precision after the decimal point */
6152 if (prec1 != BF_PREC_INF)
6153 prec = r->expn + prec1;
6154 else
6155 prec = prec1;
6156 } else {
6157 prec = prec1;
6158 }
6159
6160 /* round to prec digits */
6161 rnd_mode = flags & BF_RND_MASK;
6162 ret = 0;
6163 add_one = bfdec_get_rnd_add(&ret, r, l, prec, rnd_mode);
6164
6165 if (prec <= 0) {
6166 if (add_one) {
6167 bfdec_resize(r, 1); /* cannot fail because r is non zero */
6168 r->tab[0] = BF_DEC_BASE / 10;
6169 r->expn += 1 - prec;
6170 ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
6171 return ret;
6172 } else {
6173 goto underflow;
6174 }
6175 } else if (add_one) {
6176 limb_t carry;
6177
6178 /* add one starting at digit 'prec - 1' */
6179 bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
6180 pos = bit_pos / LIMB_DIGITS;
6181 carry = mp_pow_dec[bit_pos % LIMB_DIGITS];
6182 carry = mp_add_ui_dec(r->tab + pos, carry, l - pos);
6183 if (carry) {
6184 /* shift right by one digit */
6185 mp_shr_dec(r->tab + pos, r->tab + pos, l - pos, 1, 1);
6186 r->expn++;
6187 }
6188 }
6189
6190 /* check underflow */
6191 if (unlikely(r->expn < e_min)) {
6192 underflow:
6193 bfdec_set_zero(r, r->sign);
6194 ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
6195 return ret;
6196 }
6197
6198 /* check overflow */
6199 if (unlikely(r->expn > e_max)) {
6200 bfdec_set_inf(r, r->sign);
6201 ret |= BF_ST_OVERFLOW | BF_ST_INEXACT;
6202 return ret;
6203 }
6204
6205     /* keep the digits starting at 'prec - 1' */
6206 bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
6207 i = floor_div(bit_pos, LIMB_DIGITS);
6208 if (i >= 0) {
6209 shift = smod(bit_pos, LIMB_DIGITS);
6210 if (shift != 0) {
6211 r->tab[i] = fast_udiv(r->tab[i], &mp_pow_div[shift]) *
6212 mp_pow_dec[shift];
6213 }
6214 } else {
6215 i = 0;
6216 }
6217 /* remove trailing zeros */
6218 while (r->tab[i] == 0)
6219 i++;
6220 if (i > 0) {
6221 l -= i;
6222 memmove(r->tab, r->tab + i, l * sizeof(limb_t));
6223 }
6224 bfdec_resize(r, l); /* cannot fail */
6225 return ret;
6226 }
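
/* Note (illustrative): when the rounding increment in __bfdec_round()
   carries out of the most significant digit, e.g. 9.99 rounded up at
   prec = 2, mp_add_ui_dec() reports a carry; the mantissa is then
   shifted right by one digit and 'expn' is incremented, giving 1.0e1. */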
6227
6228 /* Cannot fail with BF_ST_MEM_ERROR. */
6229 int bfdec_round(bfdec_t *r, limb_t prec, bf_flags_t flags)
6230 {
6231 if (r->len == 0)
6232 return 0;
6233 return __bfdec_round(r, prec, flags, r->len);
6234 }
6235
6236 /* 'r' must be a finite number. Cannot fail with BF_ST_MEM_ERROR. */
6237 int bfdec_normalize_and_round(bfdec_t *r, limb_t prec1, bf_flags_t flags)
6238 {
6239 limb_t l, v;
6240 int shift, ret;
6241
6242 // bfdec_print_str("bf_renorm", r);
6243 l = r->len;
6244 while (l > 0 && r->tab[l - 1] == 0)
6245 l--;
6246 if (l == 0) {
6247 /* zero */
6248 r->expn = BF_EXP_ZERO;
6249 bfdec_resize(r, 0); /* cannot fail */
6250 ret = 0;
6251 } else {
6252 r->expn -= (r->len - l) * LIMB_DIGITS;
6253 /* shift to have the MSB set to '1' */
6254 v = r->tab[l - 1];
6255 shift = clz_dec(v);
6256 if (shift != 0) {
6257 mp_shl_dec(r->tab, r->tab, l, shift, 0);
6258 r->expn -= shift;
6259 }
6260 ret = __bfdec_round(r, prec1, flags, l);
6261 }
6262 // bf_print_str("r_final", r);
6263 return ret;
6264 }
6265
6266 int bfdec_set_ui(bfdec_t *r, uint64_t v)
6267 {
6268 #if LIMB_BITS == 32
6269 if (v >= BF_DEC_BASE * BF_DEC_BASE) {
6270 if (bfdec_resize(r, 3))
6271 goto fail;
6272 r->tab[0] = v % BF_DEC_BASE;
6273 v /= BF_DEC_BASE;
6274 r->tab[1] = v % BF_DEC_BASE;
6275 r->tab[2] = v / BF_DEC_BASE;
6276 r->expn = 3 * LIMB_DIGITS;
6277 } else
6278 #endif
6279 if (v >= BF_DEC_BASE) {
6280 if (bfdec_resize(r, 2))
6281 goto fail;
6282 r->tab[0] = v % BF_DEC_BASE;
6283 r->tab[1] = v / BF_DEC_BASE;
6284 r->expn = 2 * LIMB_DIGITS;
6285 } else {
6286 if (bfdec_resize(r, 1))
6287 goto fail;
6288 r->tab[0] = v;
6289 r->expn = LIMB_DIGITS;
6290 }
6291 r->sign = 0;
6292 return bfdec_normalize_and_round(r, BF_PREC_INF, 0);
6293 fail:
6294 bfdec_set_nan(r);
6295 return BF_ST_MEM_ERROR;
6296 }
6297
6298 int bfdec_set_si(bfdec_t *r, int64_t v)
6299 {
6300 int ret;
6301 if (v < 0) {
6302 ret = bfdec_set_ui(r, -v);
6303 r->sign = 1;
6304 } else {
6305 ret = bfdec_set_ui(r, v);
6306 }
6307 return ret;
6308 }
6309
6310 static int bfdec_add_internal(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec, bf_flags_t flags, int b_neg)
6311 {
6312 bf_context_t *s = r->ctx;
6313 int is_sub, cmp_res, a_sign, b_sign, ret;
6314
6315 a_sign = a->sign;
6316 b_sign = b->sign ^ b_neg;
6317 is_sub = a_sign ^ b_sign;
6318 cmp_res = bfdec_cmpu(a, b);
6319 if (cmp_res < 0) {
6320 const bfdec_t *tmp;
6321 tmp = a;
6322 a = b;
6323 b = tmp;
6324 a_sign = b_sign; /* b_sign is never used later */
6325 }
6326 /* abs(a) >= abs(b) */
6327 if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
6328 /* zero result */
6329 bfdec_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
6330 ret = 0;
6331 } else if (a->len == 0 || b->len == 0) {
6332 ret = 0;
6333 if (a->expn >= BF_EXP_INF) {
6334 if (a->expn == BF_EXP_NAN) {
6335 /* at least one operand is NaN */
6336 bfdec_set_nan(r);
6337 ret = 0;
6338 } else if (b->expn == BF_EXP_INF && is_sub) {
6339 /* infinities with different signs */
6340 bfdec_set_nan(r);
6341 ret = BF_ST_INVALID_OP;
6342 } else {
6343 bfdec_set_inf(r, a_sign);
6344 }
6345 } else {
6346             /* at least one operand is zero and there is no cancellation */
6347 if (bfdec_set(r, a))
6348 return BF_ST_MEM_ERROR;
6349 r->sign = a_sign;
6350 goto renorm;
6351 }
6352 } else {
6353 slimb_t d, a_offset, b_offset, i, r_len;
6354 limb_t carry;
6355 limb_t *b1_tab;
6356 int b_shift;
6357 mp_size_t b1_len;
6358
6359 d = a->expn - b->expn;
6360
6361 /* XXX: not efficient in time and memory if the precision is
6362 not infinite */
6363 r_len = bf_max(a->len, b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
6364 if (bfdec_resize(r, r_len))
6365 goto fail;
6366 r->sign = a_sign;
6367 r->expn = a->expn;
6368
6369 a_offset = r_len - a->len;
6370 for(i = 0; i < a_offset; i++)
6371 r->tab[i] = 0;
6372 for(i = 0; i < a->len; i++)
6373 r->tab[a_offset + i] = a->tab[i];
6374
6375 b_shift = d % LIMB_DIGITS;
6376 if (b_shift == 0) {
6377 b1_len = b->len;
6378 b1_tab = (limb_t *)b->tab;
6379 } else {
6380 b1_len = b->len + 1;
6381 b1_tab = bf_malloc(s, sizeof(limb_t) * b1_len);
6382 if (!b1_tab)
6383 goto fail;
6384 b1_tab[0] = mp_shr_dec(b1_tab + 1, b->tab, b->len, b_shift, 0) *
6385 mp_pow_dec[LIMB_DIGITS - b_shift];
6386 }
6387 b_offset = r_len - (b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
6388
6389 if (is_sub) {
6390 carry = mp_sub_dec(r->tab + b_offset, r->tab + b_offset,
6391 b1_tab, b1_len, 0);
6392 if (carry != 0) {
6393 carry = mp_sub_ui_dec(r->tab + b_offset + b1_len, carry,
6394 r_len - (b_offset + b1_len));
6395 assert(carry == 0);
6396 }
6397 } else {
6398 carry = mp_add_dec(r->tab + b_offset, r->tab + b_offset,
6399 b1_tab, b1_len, 0);
6400 if (carry != 0) {
6401 carry = mp_add_ui_dec(r->tab + b_offset + b1_len, carry,
6402 r_len - (b_offset + b1_len));
6403 }
6404 if (carry != 0) {
6405 if (bfdec_resize(r, r_len + 1)) {
6406 if (b_shift != 0)
6407 bf_free(s, b1_tab);
6408 goto fail;
6409 }
6410 r->tab[r_len] = 1;
6411 r->expn += LIMB_DIGITS;
6412 }
6413 }
6414 if (b_shift != 0)
6415 bf_free(s, b1_tab);
6416 renorm:
6417 ret = bfdec_normalize_and_round(r, prec, flags);
6418 }
6419 return ret;
6420 fail:
6421 bfdec_set_nan(r);
6422 return BF_ST_MEM_ERROR;
6423 }
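
/* Alignment example (illustrative): adding a = 123 (expn = 3) and
   b = 4.5 (expn = 1) gives d = a->expn - b->expn = 2, so the digits
   of 'b' are placed two positions further down in the result:
   123.0 + 004.5 = 127.5. When d is not a multiple of LIMB_DIGITS,
   'b' is first shifted by d % LIMB_DIGITS digits into the temporary
   b1_tab. */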
6424
6425 static int __bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6426 bf_flags_t flags)
6427 {
6428 return bfdec_add_internal(r, a, b, prec, flags, 0);
6429 }
6430
6431 static int __bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6432 bf_flags_t flags)
6433 {
6434 return bfdec_add_internal(r, a, b, prec, flags, 1);
6435 }
6436
6437 int bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6438 bf_flags_t flags)
6439 {
6440 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6441 (bf_op2_func_t *)__bfdec_add);
6442 }
6443
6444 int bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6445 bf_flags_t flags)
6446 {
6447 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6448 (bf_op2_func_t *)__bfdec_sub);
6449 }
6450
6451 int bfdec_mul(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6452 bf_flags_t flags)
6453 {
6454 int ret, r_sign;
6455
6456 if (a->len < b->len) {
6457 const bfdec_t *tmp = a;
6458 a = b;
6459 b = tmp;
6460 }
6461 r_sign = a->sign ^ b->sign;
6462 /* here b->len <= a->len */
6463 if (b->len == 0) {
6464 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6465 bfdec_set_nan(r);
6466 ret = 0;
6467 } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
6468 if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
6469 (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
6470 bfdec_set_nan(r);
6471 ret = BF_ST_INVALID_OP;
6472 } else {
6473 bfdec_set_inf(r, r_sign);
6474 ret = 0;
6475 }
6476 } else {
6477 bfdec_set_zero(r, r_sign);
6478 ret = 0;
6479 }
6480 } else {
6481 bfdec_t tmp, *r1 = NULL;
6482 limb_t a_len, b_len;
6483 limb_t *a_tab, *b_tab;
6484
6485 a_len = a->len;
6486 b_len = b->len;
6487 a_tab = a->tab;
6488 b_tab = b->tab;
6489
6490 if (r == a || r == b) {
6491 bfdec_init(r->ctx, &tmp);
6492 r1 = r;
6493 r = &tmp;
6494 }
6495 if (bfdec_resize(r, a_len + b_len)) {
6496 bfdec_set_nan(r);
6497 ret = BF_ST_MEM_ERROR;
6498 goto done;
6499 }
6500 mp_mul_basecase_dec(r->tab, a_tab, a_len, b_tab, b_len);
6501 r->sign = r_sign;
6502 r->expn = a->expn + b->expn;
6503 ret = bfdec_normalize_and_round(r, prec, flags);
6504 done:
6505 if (r == &tmp)
6506 bfdec_move(r1, &tmp);
6507 }
6508 return ret;
6509 }
6510
6511 int bfdec_mul_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
6512 bf_flags_t flags)
6513 {
6514 bfdec_t b;
6515 int ret;
6516 bfdec_init(r->ctx, &b);
6517 ret = bfdec_set_si(&b, b1);
6518 ret |= bfdec_mul(r, a, &b, prec, flags);
6519 bfdec_delete(&b);
6520 return ret;
6521 }
6522
6523 int bfdec_add_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
6524 bf_flags_t flags)
6525 {
6526 bfdec_t b;
6527 int ret;
6528
6529 bfdec_init(r->ctx, &b);
6530 ret = bfdec_set_si(&b, b1);
6531 ret |= bfdec_add(r, a, &b, prec, flags);
6532 bfdec_delete(&b);
6533 return ret;
6534 }
6535
6536 static int __bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
6537 limb_t prec, bf_flags_t flags)
6538 {
6539 int ret, r_sign;
6540 limb_t n, nb, precl;
6541
6542 r_sign = a->sign ^ b->sign;
6543 if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
6544 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6545 bfdec_set_nan(r);
6546 return 0;
6547 } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
6548 bfdec_set_nan(r);
6549 return BF_ST_INVALID_OP;
6550 } else if (a->expn == BF_EXP_INF) {
6551 bfdec_set_inf(r, r_sign);
6552 return 0;
6553 } else {
6554 bfdec_set_zero(r, r_sign);
6555 return 0;
6556 }
6557 } else if (a->expn == BF_EXP_ZERO) {
6558 if (b->expn == BF_EXP_ZERO) {
6559 bfdec_set_nan(r);
6560 return BF_ST_INVALID_OP;
6561 } else {
6562 bfdec_set_zero(r, r_sign);
6563 return 0;
6564 }
6565 } else if (b->expn == BF_EXP_ZERO) {
6566 bfdec_set_inf(r, r_sign);
6567 return BF_ST_DIVIDE_ZERO;
6568 }
6569
6570 nb = b->len;
6571 if (prec == BF_PREC_INF) {
6572 /* infinite precision: return BF_ST_INVALID_OP if not an exact
6573 result */
6574 /* XXX: check */
6575 precl = nb + 1;
6576 } else if (flags & BF_FLAG_RADPNT_PREC) {
6577 /* number of digits after the decimal point */
6578 /* XXX: check (2 extra digits for rounding + 2 digits) */
6579 precl = (bf_max(a->expn - b->expn, 0) + 2 +
6580 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
6581 } else {
6582 /* number of limbs of the quotient (2 extra digits for rounding) */
6583 precl = (prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
6584 }
6585 n = bf_max(a->len, precl);
6586
6587 {
6588 limb_t *taba, na, i;
6589 slimb_t d;
6590
6591 na = n + nb;
6592 taba = bf_malloc(r->ctx, (na + 1) * sizeof(limb_t));
6593 if (!taba)
6594 goto fail;
6595 d = na - a->len;
6596 memset(taba, 0, d * sizeof(limb_t));
6597 memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
6598 if (bfdec_resize(r, n + 1))
6599 goto fail1;
6600 if (mp_div_dec(r->ctx, r->tab, taba, na, b->tab, nb)) {
6601 fail1:
6602 bf_free(r->ctx, taba);
6603 goto fail;
6604 }
6605         /* check for a non-zero remainder */
6606 for(i = 0; i < nb; i++) {
6607 if (taba[i] != 0)
6608 break;
6609 }
6610 bf_free(r->ctx, taba);
6611 if (i != nb) {
6612 if (prec == BF_PREC_INF) {
6613 bfdec_set_nan(r);
6614 return BF_ST_INVALID_OP;
6615 } else {
6616 r->tab[0] |= 1;
6617 }
6618 }
6619 r->expn = a->expn - b->expn + LIMB_DIGITS;
6620 r->sign = r_sign;
6621 ret = bfdec_normalize_and_round(r, prec, flags);
6622 }
6623 return ret;
6624 fail:
6625 bfdec_set_nan(r);
6626 return BF_ST_MEM_ERROR;
6627 }
6628
6629 int bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6630 bf_flags_t flags)
6631 {
6632 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6633 (bf_op2_func_t *)__bfdec_div);
6634 }
6635
6636 /* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
6637 integer defined as floor(a/b) and r = a - q * b. */
6638 static void bfdec_tdivremu(bf_context_t *s, bfdec_t *q, bfdec_t *r,
6639 const bfdec_t *a, const bfdec_t *b)
6640 {
6641 if (bfdec_cmpu(a, b) < 0) {
6642 bfdec_set_ui(q, 0);
6643 bfdec_set(r, a);
6644 } else {
6645 bfdec_div(q, a, b, 0, BF_RNDZ | BF_FLAG_RADPNT_PREC);
6646 bfdec_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
6647 bfdec_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
6648 }
6649 }
6650
6651 /* division and remainder.
6652
6653 rnd_mode is the rounding mode for the quotient. The additional
6654    rounding mode BF_DIVREM_EUCLIDIAN is supported.
6655
6656 'q' is an integer. 'r' is rounded with prec and flags (prec can be
6657 BF_PREC_INF).
6658 */
6659 int bfdec_divrem(bfdec_t *q, bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
6660 limb_t prec, bf_flags_t flags, int rnd_mode)
6661 {
6662 bf_context_t *s = q->ctx;
6663 bfdec_t a1_s, *a1 = &a1_s;
6664 bfdec_t b1_s, *b1 = &b1_s;
6665 bfdec_t r1_s, *r1 = &r1_s;
6666 int q_sign, res;
6667 BOOL is_ceil, is_rndn;
6668
6669 assert(q != a && q != b);
6670 assert(r != a && r != b);
6671 assert(q != r);
6672
6673 if (a->len == 0 || b->len == 0) {
6674 bfdec_set_zero(q, 0);
6675 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6676 bfdec_set_nan(r);
6677 return 0;
6678 } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
6679 bfdec_set_nan(r);
6680 return BF_ST_INVALID_OP;
6681 } else {
6682 bfdec_set(r, a);
6683 return bfdec_round(r, prec, flags);
6684 }
6685 }
6686
6687 q_sign = a->sign ^ b->sign;
6688 is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
6689 rnd_mode == BF_RNDNU);
6690 switch(rnd_mode) {
6691 default:
6692 case BF_RNDZ:
6693 case BF_RNDN:
6694 case BF_RNDNA:
6695 is_ceil = FALSE;
6696 break;
6697 case BF_RNDD:
6698 is_ceil = q_sign;
6699 break;
6700 case BF_RNDU:
6701 is_ceil = q_sign ^ 1;
6702 break;
6703 case BF_DIVREM_EUCLIDIAN:
6704 is_ceil = a->sign;
6705 break;
6706 case BF_RNDNU:
6707 /* XXX: unsupported yet */
6708 abort();
6709 }
6710
6711 a1->expn = a->expn;
6712 a1->tab = a->tab;
6713 a1->len = a->len;
6714 a1->sign = 0;
6715
6716 b1->expn = b->expn;
6717 b1->tab = b->tab;
6718 b1->len = b->len;
6719 b1->sign = 0;
6720
6721 // bfdec_print_str("a1", a1);
6722 // bfdec_print_str("b1", b1);
6723 /* XXX: could improve to avoid having a large 'q' */
6724 bfdec_tdivremu(s, q, r, a1, b1);
6725 if (bfdec_is_nan(q) || bfdec_is_nan(r))
6726 goto fail;
6727 // bfdec_print_str("q", q);
6728 // bfdec_print_str("r", r);
6729
6730 if (r->len != 0) {
6731 if (is_rndn) {
6732 bfdec_init(s, r1);
6733 if (bfdec_set(r1, r))
6734 goto fail;
6735 if (bfdec_mul_si(r1, r1, 2, BF_PREC_INF, BF_RNDZ)) {
6736 bfdec_delete(r1);
6737 goto fail;
6738 }
6739 res = bfdec_cmpu(r1, b);
6740 bfdec_delete(r1);
6741 if (res > 0 ||
6742 (res == 0 &&
6743 (rnd_mode == BF_RNDNA ||
6744 (get_digit(q->tab, q->len, q->len * LIMB_DIGITS - q->expn) & 1) != 0))) {
6745 goto do_sub_r;
6746 }
6747 } else if (is_ceil) {
6748 do_sub_r:
6749 res = bfdec_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
6750 res |= bfdec_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
6751 if (res & BF_ST_MEM_ERROR)
6752 goto fail;
6753 }
6754 }
6755
6756 r->sign ^= a->sign;
6757 q->sign = q_sign;
6758 return bfdec_round(r, prec, flags);
6759 fail:
6760 bfdec_set_nan(q);
6761 bfdec_set_nan(r);
6762 return BF_ST_MEM_ERROR;
6763 }
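
/* Usage sketch (illustrative only; assumes a bf_context_t 'ctx' that
   was initialized elsewhere, e.g. with bf_context_init()): */
#if 0
static void divrem_example(bf_context_t *ctx)
{
    bfdec_t a, b, q, r;

    bfdec_init(ctx, &a);
    bfdec_init(ctx, &b);
    bfdec_init(ctx, &q);
    bfdec_init(ctx, &r);
    bfdec_set_si(&a, 7);
    bfdec_set_si(&b, 2);
    /* truncated quotient: q = 3 and r = 1 since 7 = 3 * 2 + 1 */
    bfdec_divrem(&q, &r, &a, &b, BF_PREC_INF, BF_RNDZ, BF_RNDZ);
    bfdec_delete(&a);
    bfdec_delete(&b);
    bfdec_delete(&q);
    bfdec_delete(&r);
}
#endif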
6764
6765 int bfdec_fmod(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6766 bf_flags_t flags)
6767 {
6768 bfdec_t q_s, *q = &q_s;
6769 int ret;
6770
6771 bfdec_init(r->ctx, q);
6772 ret = bfdec_divrem(q, r, a, b, prec, flags, BF_RNDZ);
6773 bfdec_delete(q);
6774 return ret;
6775 }
6776
6777 /* convert to integer (infinite precision) */
6778 int bfdec_rint(bfdec_t *r, int rnd_mode)
6779 {
6780 return bfdec_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
6781 }
6782
6783 int bfdec_sqrt(bfdec_t *r, const bfdec_t *a, limb_t prec, bf_flags_t flags)
6784 {
6785 bf_context_t *s = a->ctx;
6786 int ret, k;
6787 limb_t *a1, v;
6788 slimb_t n, n1;
6789 limb_t res;
6790
6791 assert(r != a);
6792
6793 if (a->len == 0) {
6794 if (a->expn == BF_EXP_NAN) {
6795 bfdec_set_nan(r);
6796 } else if (a->expn == BF_EXP_INF && a->sign) {
6797 goto invalid_op;
6798 } else {
6799 bfdec_set(r, a);
6800 }
6801 ret = 0;
6802 } else if (a->sign || prec == BF_PREC_INF) {
6803 invalid_op:
6804 bfdec_set_nan(r);
6805 ret = BF_ST_INVALID_OP;
6806 } else {
6807 /* convert the mantissa to an integer with at least 2 *
6808 prec + 4 digits */
6809 n = (2 * (prec + 2) + 2 * LIMB_DIGITS - 1) / (2 * LIMB_DIGITS);
6810 if (bfdec_resize(r, n))
6811 goto fail;
6812 a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
6813 if (!a1)
6814 goto fail;
6815 n1 = bf_min(2 * n, a->len);
6816 memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
6817 memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
6818 if (a->expn & 1) {
6819 res = mp_shr_dec(a1, a1, 2 * n, 1, 0);
6820 } else {
6821 res = 0;
6822 }
6823         /* normalize so that a1 >= B^(2*n)/4. No need for n = 1
6824 because mp_sqrtrem2_dec already does it */
6825 k = 0;
6826 if (n > 1) {
6827 v = a1[2 * n - 1];
6828 while (v < BF_DEC_BASE / 4) {
6829 k++;
6830 v *= 4;
6831 }
6832 if (k != 0)
6833 mp_mul1_dec(a1, a1, 2 * n, 1 << (2 * k), 0);
6834 }
6835 if (mp_sqrtrem_dec(s, r->tab, a1, n)) {
6836 bf_free(s, a1);
6837 goto fail;
6838 }
6839 if (k != 0)
6840 mp_div1_dec(r->tab, r->tab, n, 1 << k, 0);
6841 if (!res) {
6842 res = mp_scan_nz(a1, n + 1);
6843 }
6844 bf_free(s, a1);
6845 if (!res) {
6846 res = mp_scan_nz(a->tab, a->len - n1);
6847 }
6848 if (res != 0)
6849 r->tab[0] |= 1;
6850 r->sign = 0;
6851 r->expn = (a->expn + 1) >> 1;
6852 ret = bfdec_round(r, prec, flags);
6853 }
6854 return ret;
6855 fail:
6856 bfdec_set_nan(r);
6857 return BF_ST_MEM_ERROR;
6858 }
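
/* Note on the normalization above (illustrative): multiplying the
   radicand by 4^k does not change the digits of the root since
   sqrt(a1 * 4^k) = 2^k * sqrt(a1); the final mp_div1_dec() by 2^k
   undoes the scaling. */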
6859
6860 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
6861 is an overflow and 0 otherwise. No memory error is possible. */
6862 int bfdec_get_int32(int *pres, const bfdec_t *a)
6863 {
6864 uint32_t v;
6865 int ret;
6866 if (a->expn >= BF_EXP_INF) {
6867 ret = 0;
6868 if (a->expn == BF_EXP_INF) {
6869 v = (uint32_t)INT32_MAX + a->sign;
6870 /* XXX: return overflow ? */
6871 } else {
6872 v = INT32_MAX;
6873 }
6874 } else if (a->expn <= 0) {
6875 v = 0;
6876 ret = 0;
6877 } else if (a->expn <= 9) {
6878 v = fast_udiv(a->tab[a->len - 1], &mp_pow_div[LIMB_DIGITS - a->expn]);
6879 if (a->sign)
6880 v = -v;
6881 ret = 0;
6882 } else if (a->expn == 10) {
6883 uint64_t v1;
6884 uint32_t v_max;
6885 #if LIMB_BITS == 64
6886 v1 = fast_udiv(a->tab[a->len - 1], &mp_pow_div[LIMB_DIGITS - a->expn]);
6887 #else
6888 v1 = (uint64_t)a->tab[a->len - 1] * 10 +
6889 get_digit(a->tab, a->len, (a->len - 1) * LIMB_DIGITS - 1);
6890 #endif
6891 v_max = (uint32_t)INT32_MAX + a->sign;
6892 if (v1 > v_max) {
6893 v = v_max;
6894 ret = BF_ST_OVERFLOW;
6895 } else {
6896 v = v1;
6897 if (a->sign)
6898 v = -v;
6899 ret = 0;
6900 }
6901 } else {
6902 v = (uint32_t)INT32_MAX + a->sign;
6903 ret = BF_ST_OVERFLOW;
6904 }
6905 *pres = v;
6906 return ret;
6907 }
6908
6909 /* raise 'a' to the integer power 'b' with infinite precision */
6910 int bfdec_pow_ui(bfdec_t *r, const bfdec_t *a, limb_t b)
6911 {
6912 int ret, n_bits, i;
6913
6914 assert(r != a);
6915 if (b == 0)
6916 return bfdec_set_ui(r, 1);
6917 ret = bfdec_set(r, a);
6918 n_bits = LIMB_BITS - clz(b);
6919 for(i = n_bits - 2; i >= 0; i--) {
6920 ret |= bfdec_mul(r, r, r, BF_PREC_INF, BF_RNDZ);
6921 if ((b >> i) & 1)
6922 ret |= bfdec_mul(r, r, a, BF_PREC_INF, BF_RNDZ);
6923 }
6924 return ret;
6925 }
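
/* Illustrative trace of the square-and-multiply loop above for
   b = 13 = 0b1101 (n_bits = 4), starting from r = a:
   i = 2: r = a^2, bit set   -> r = a^3
   i = 1: r = a^6, bit clear -> r = a^6
   i = 0: r = a^12, bit set  -> r = a^13 */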
6926
6927 char *bfdec_ftoa(size_t *plen, const bfdec_t *a, limb_t prec, bf_flags_t flags)
6928 {
6929 return bf_ftoa_internal(plen, (const bf_t *)a, 10, prec, flags, TRUE);
6930 }
6931
6932 int bfdec_atof(bfdec_t *r, const char *str, const char **pnext,
6933 limb_t prec, bf_flags_t flags)
6934 {
6935 slimb_t dummy_exp;
6936 return bf_atof_internal((bf_t *)r, &dummy_exp, str, pnext, 10, prec,
6937 flags, TRUE);
6938 }
6939
6940 #endif /* USE_BF_DEC */
6941
6942 #ifdef USE_FFT_MUL
6943 /***************************************************************/
6944 /* Integer multiplication with FFT */
6945
6946 /* OR 'val' (LIMB_BITS bits) at bit position 'pos' in tab */
6947 static inline void put_bits(limb_t *tab, limb_t len, slimb_t pos, limb_t val)
6948 {
6949 limb_t i;
6950 int p;
6951
6952 i = pos >> LIMB_LOG2_BITS;
6953 p = pos & (LIMB_BITS - 1);
6954 if (i < len)
6955 tab[i] |= val << p;
6956 if (p != 0) {
6957 i++;
6958 if (i < len) {
6959 tab[i] |= val >> (LIMB_BITS - p);
6960 }
6961 }
6962 }
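
/* Example (illustrative): with LIMB_BITS = 64, put_bits(tab, len, 60, v)
   ORs the low 4 bits of 'v' into the top of tab[0] and the remaining
   60 bits into the bottom of tab[1]. */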
6963
6964 #if defined(__AVX2__)
6965
6966 typedef double NTTLimb;
6967
6968 /* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
6969 #define NTT_MOD_LOG2_MIN 50
6970 #define NTT_MOD_LOG2_MAX 51
6971 #define NB_MODS 5
6972 #define NTT_PROOT_2EXP 39
6973 static const int ntt_int_bits[NB_MODS] = { 254, 203, 152, 101, 50, };
6974
6975 static const limb_t ntt_mods[NB_MODS] = { 0x00073a8000000001, 0x0007858000000001, 0x0007a38000000001, 0x0007a68000000001, 0x0007fd8000000001,
6976 };
6977
6978 static const limb_t ntt_proot[2][NB_MODS] = {
6979 { 0x00056198d44332c8, 0x0002eb5d640aad39, 0x00047e31eaa35fd0, 0x0005271ac118a150, 0x00075e0ce8442bd5, },
6980 { 0x000461169761bcc5, 0x0002dac3cb2da688, 0x0004abc97751e3bf, 0x000656778fc8c485, 0x0000dc6469c269fa, },
6981 };
6982
6983 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
6984 0x00020e4da740da8e, 0x0004c3dc09c09c1d, 0x000063bd097b4271, 0x000799d8f18f18fd,
6985 0x0005384222222264, 0x000572b07c1f07fe, 0x00035cd08888889a,
6986 0x00066015555557e3, 0x000725960b60b623,
6987 0x0002fc1fa1d6ce12,
6988 };
6989
6990 #else
6991
6992 typedef limb_t NTTLimb;
6993
6994 #if LIMB_BITS == 64
6995
6996 #define NTT_MOD_LOG2_MIN 61
6997 #define NTT_MOD_LOG2_MAX 62
6998 #define NB_MODS 5
6999 #define NTT_PROOT_2EXP 51
7000 static const int ntt_int_bits[NB_MODS] = { 307, 246, 185, 123, 61, };
7001
7002 static const limb_t ntt_mods[NB_MODS] = { 0x28d8000000000001, 0x2a88000000000001, 0x2ed8000000000001, 0x3508000000000001, 0x3aa8000000000001,
7003 };
7004
7005 static const limb_t ntt_proot[2][NB_MODS] = {
7006 { 0x1b8ea61034a2bea7, 0x21a9762de58206fb, 0x02ca782f0756a8ea, 0x278384537a3e50a1, 0x106e13fee74ce0ab, },
7007 { 0x233513af133e13b8, 0x1d13140d1c6f75f1, 0x12cde57f97e3eeda, 0x0d6149e23cbe654f, 0x36cd204f522a1379, },
7008 };
7009
7010 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
7011 0x08a9ed097b425eea, 0x18a44aaaaaaaaab3, 0x2493f57f57f57f5d, 0x126b8d0649a7f8d4,
7012 0x09d80ed7303b5ccc, 0x25b8bcf3cf3cf3d5, 0x2ce6ce63398ce638,
7013 0x0e31fad40a57eb59, 0x02a3529fd4a7f52f,
7014 0x3a5493e93e93e94a,
7015 };
7016
7017 #elif LIMB_BITS == 32
7018
7019 /* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
7020 #define NTT_MOD_LOG2_MIN 29
7021 #define NTT_MOD_LOG2_MAX 30
7022 #define NB_MODS 5
7023 #define NTT_PROOT_2EXP 20
7024 static const int ntt_int_bits[NB_MODS] = { 148, 119, 89, 59, 29, };
7025
7026 static const limb_t ntt_mods[NB_MODS] = { 0x0000000032b00001, 0x0000000033700001, 0x0000000036d00001, 0x0000000037300001, 0x000000003e500001,
7027 };
7028
7029 static const limb_t ntt_proot[2][NB_MODS] = {
7030 { 0x0000000032525f31, 0x0000000005eb3b37, 0x00000000246eda9f, 0x0000000035f25901, 0x00000000022f5768, },
7031 { 0x00000000051eba1a, 0x00000000107be10e, 0x000000001cd574e0, 0x00000000053806e6, 0x000000002cd6bf98, },
7032 };
7033
7034 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
7035 0x000000000449559a, 0x000000001eba6ca9, 0x000000002ec18e46, 0x000000000860160b,
7036 0x000000000d321307, 0x000000000bf51120, 0x000000000f662938,
7037 0x000000000932ab3e, 0x000000002f40eef8,
7038 0x000000002e760905,
7039 };
7040
7041 #endif /* LIMB_BITS */
7042
7043 #endif /* !AVX2 */
7044
7045 #if defined(__AVX2__)
7046 #define NTT_TRIG_K_MAX 18
7047 #else
7048 #define NTT_TRIG_K_MAX 19
7049 #endif
7050
7051 typedef struct BFNTTState {
7052 bf_context_t *ctx;
7053
7054 /* used for mul_mod_fast() */
7055 limb_t ntt_mods_div[NB_MODS];
7056
7057 limb_t ntt_proot_pow[NB_MODS][2][NTT_PROOT_2EXP + 1];
7058 limb_t ntt_proot_pow_inv[NB_MODS][2][NTT_PROOT_2EXP + 1];
7059 NTTLimb *ntt_trig[NB_MODS][2][NTT_TRIG_K_MAX + 1];
7060 /* 1/2^n mod m */
7061 limb_t ntt_len_inv[NB_MODS][NTT_PROOT_2EXP + 1][2];
7062 #if defined(__AVX2__)
7063 __m256d ntt_mods_cr_vec[NB_MODS * (NB_MODS - 1) / 2];
7064 __m256d ntt_mods_vec[NB_MODS];
7065 __m256d ntt_mods_inv_vec[NB_MODS];
7066 #else
7067 limb_t ntt_mods_cr_inv[NB_MODS * (NB_MODS - 1) / 2];
7068 #endif
7069 } BFNTTState;
7070
7071 static NTTLimb *get_trig(BFNTTState *s, int k, int inverse, int m_idx);
7072
7073 /* modular addition with a modulus of up to (LIMB_BITS - 1) bits */
7074 static inline limb_t add_mod(limb_t a, limb_t b, limb_t m)
7075 {
7076 limb_t r;
7077 r = a + b;
7078 if (r >= m)
7079 r -= m;
7080 return r;
7081 }
7082
7083 /* modular subtraction with a modulus of up to LIMB_BITS bits */
7084 static inline limb_t sub_mod(limb_t a, limb_t b, limb_t m)
7085 {
7086 limb_t r;
7087 r = a - b;
7088 if (r > a)
7089 r += m;
7090 return r;
7091 }
7092
7093 /* return (r0+r1*B) mod m
7094 precondition: 0 <= r0+r1*B < 2^(64+NTT_MOD_LOG2_MIN)
7095 */
7096 static inline limb_t mod_fast(dlimb_t r,
7097 limb_t m, limb_t m_inv)
7098 {
7099 limb_t a1, q, t0, r1, r0;
7100
7101 a1 = r >> NTT_MOD_LOG2_MIN;
7102
7103 q = ((dlimb_t)a1 * m_inv) >> LIMB_BITS;
7104 r = r - (dlimb_t)q * m - m * 2;
7105 r1 = r >> LIMB_BITS;
7106 t0 = (slimb_t)r1 >> 1;
7107 r += m & t0;
7108 r0 = r;
7109 r1 = r >> LIMB_BITS;
7110 r0 += m & r1;
7111 return r0;
7112 }
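
/* The reduction above is a Barrett-style step: the quotient estimate
   q is computed from the high bits of r and the precomputed
   m_inv = floor(2^(LIMB_BITS + NTT_MOD_LOG2_MIN) / m); after
   subtracting q * m + 2 * m the result may undershoot by a couple of
   multiples of m, which the two conditional additions fix up. */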
7113
7114 /* faster version using a precomputed reciprocal of the modulus.
7115 precondition: 0 <= a * b < 2^(64+NTT_MOD_LOG2_MIN) */
7116 static inline limb_t mul_mod_fast(limb_t a, limb_t b,
7117 limb_t m, limb_t m_inv)
7118 {
7119 dlimb_t r;
7120 r = (dlimb_t)a * (dlimb_t)b;
7121 return mod_fast(r, m, m_inv);
7122 }
7123
7124 static inline limb_t init_mul_mod_fast(limb_t m)
7125 {
7126 dlimb_t t;
7127 assert(m < (limb_t)1 << NTT_MOD_LOG2_MAX);
7128 assert(m >= (limb_t)1 << NTT_MOD_LOG2_MIN);
7129 t = (dlimb_t)1 << (LIMB_BITS + NTT_MOD_LOG2_MIN);
7130 return t / m;
7131 }
7132
7133 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7134 0 <= b < m. */
7135 static inline limb_t mul_mod_fast2(limb_t a, limb_t b,
7136 limb_t m, limb_t b_inv)
7137 {
7138 limb_t r, q;
7139
7140 q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7141 r = a * b - q * m;
7142 if (r >= m)
7143 r -= m;
7144 return r;
7145 }
7146
7147 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7148 0 <= b < m. Let r = a * b mod m. The return value is 'r' or 'r +
7149 m'. */
7150 static inline limb_t mul_mod_fast3(limb_t a, limb_t b,
7151 limb_t m, limb_t b_inv)
7152 {
7153 limb_t r, q;
7154
7155 q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7156 r = a * b - q * m;
7157 return r;
7158 }
7159
7160 static inline limb_t init_mul_mod_fast2(limb_t b, limb_t m)
7161 {
7162 return ((dlimb_t)b << LIMB_BITS) / m;
7163 }
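
/* Worked example for mul_mod_fast2() (illustrative, tiny numbers):
   take m = 7, b = 3 and LIMB_BITS = 64, so b_inv = floor(3 * 2^64 / 7).
   For a = 10: q = (a * b_inv) >> 64 = floor(30 / 7) = 4 and
   r = 10 * 3 - 4 * 7 = 2 = (10 * 3) mod 7, already < m. */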
7164
7165 #ifdef __AVX2__
7166
7167 static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
7168 {
7169 slimb_t v;
7170 v = a;
7171 if (v < 0)
7172 v += m;
7173 if (v >= m)
7174 v -= m;
7175 return v;
7176 }
7177
7178 static inline NTTLimb int_to_ntt_limb(limb_t a, limb_t m)
7179 {
7180 return (slimb_t)a;
7181 }
7182
7183 static inline NTTLimb int_to_ntt_limb2(limb_t a, limb_t m)
7184 {
7185 if (a >= (m / 2))
7186 a -= m;
7187 return (slimb_t)a;
7188 }
7189
7190 /* return r + m if r < 0 otherwise r. */
7191 static inline __m256d ntt_mod1(__m256d r, __m256d m)
7192 {
7193 return _mm256_blendv_pd(r, r + m, r);
7194 }
7195
7196 /* input: abs(r) < 2 * m. Output: abs(r) < m */
7197 static inline __m256d ntt_mod(__m256d r, __m256d mf, __m256d m2f)
7198 {
7199 return _mm256_blendv_pd(r, r + m2f, r) - mf;
7200 }
7201
7202 /* input: abs(a*b) < 2 * m^2, output: abs(r) < m */
7203 static inline __m256d ntt_mul_mod(__m256d a, __m256d b, __m256d mf,
7204 __m256d m_inv)
7205 {
7206 __m256d r, q, ab1, ab0, qm0, qm1;
7207 ab1 = a * b;
7208 q = _mm256_round_pd(ab1 * m_inv, 0); /* round to nearest */
7209 qm1 = q * mf;
7210 qm0 = _mm256_fmsub_pd(q, mf, qm1); /* low part */
7211 ab0 = _mm256_fmsub_pd(a, b, ab1); /* low part */
7212 r = (ab1 - qm1) + (ab0 - qm0);
7213 return r;
7214 }
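
/* The two _mm256_fmsub_pd() calls above recover the low halves of the
   products q * mf and a * b that were rounded away in qm1 and ab1, so
   the final r = (ab1 - qm1) + (ab0 - qm0) equals a * b - q * m
   exactly: a standard double-double trick for doing modular
   arithmetic in floating point. */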
7215
7216 static void *bf_aligned_malloc(bf_context_t *s, size_t size, size_t align)
7217 {
7218 void *ptr;
7219 void **ptr1;
7220 ptr = bf_malloc(s, size + sizeof(void *) + align - 1);
7221 if (!ptr)
7222 return NULL;
7223 ptr1 = (void **)(((uintptr_t)ptr + sizeof(void *) + align - 1) &
7224 ~(align - 1));
7225 ptr1[-1] = ptr;
7226 return ptr1;
7227 }
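
/* The allocated block is laid out as [padding][original pointer]
   [aligned data]; bf_aligned_free() below retrieves the original
   pointer stored just before the aligned address. */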
7228
7229 static void bf_aligned_free(bf_context_t *s, void *ptr)
7230 {
7231 if (!ptr)
7232 return;
7233 bf_free(s, ((void **)ptr)[-1]);
7234 }
7235
7236 static void *ntt_malloc(BFNTTState *s, size_t size)
7237 {
7238 return bf_aligned_malloc(s->ctx, size, 64);
7239 }
7240
7241 static void ntt_free(BFNTTState *s, void *ptr)
7242 {
7243 bf_aligned_free(s->ctx, ptr);
7244 }
7245
7246 static no_inline int ntt_fft(BFNTTState *s,
7247 NTTLimb *out_buf, NTTLimb *in_buf,
7248 NTTLimb *tmp_buf, int fft_len_log2,
7249 int inverse, int m_idx)
7250 {
7251 limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j;
7252 NTTLimb *tab_in, *tab_out, *tmp, *trig;
7253 __m256d m_inv, mf, m2f, c, a0, a1, b0, b1;
7254 limb_t m;
7255 int l;
7256
7257 m = ntt_mods[m_idx];
7258
7259 m_inv = _mm256_set1_pd(1.0 / (double)m);
7260 mf = _mm256_set1_pd(m);
7261 m2f = _mm256_set1_pd(m * 2);
7262
7263 n = (limb_t)1 << fft_len_log2;
7264 assert(n >= 8);
7265 stride_in = n / 2;
7266
7267 tab_in = in_buf;
7268 tab_out = tmp_buf;
7269 trig = get_trig(s, fft_len_log2, inverse, m_idx);
7270 if (!trig)
7271 return -1;
7272 p = 0;
7273 for(k = 0; k < stride_in; k += 4) {
7274 a0 = _mm256_load_pd(&tab_in[k]);
7275 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7276 c = _mm256_load_pd(trig);
7277 trig += 4;
7278 b0 = ntt_mod(a0 + a1, mf, m2f);
7279 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7280 a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
7281 a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
7282 a0 = _mm256_permute4x64_pd(a0, 0xd8);
7283 a1 = _mm256_permute4x64_pd(a1, 0xd8);
7284 _mm256_store_pd(&tab_out[p], a0);
7285 _mm256_store_pd(&tab_out[p + 4], a1);
7286 p += 2 * 4;
7287 }
7288 tmp = tab_in;
7289 tab_in = tab_out;
7290 tab_out = tmp;
7291
7292 trig = get_trig(s, fft_len_log2 - 1, inverse, m_idx);
7293 if (!trig)
7294 return -1;
7295 p = 0;
7296 for(k = 0; k < stride_in; k += 4) {
7297 a0 = _mm256_load_pd(&tab_in[k]);
7298 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7299 c = _mm256_setr_pd(trig[0], trig[0], trig[1], trig[1]);
7300 trig += 2;
7301 b0 = ntt_mod(a0 + a1, mf, m2f);
7302 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7303 a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
7304 a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
7305 _mm256_store_pd(&tab_out[p], a0);
7306 _mm256_store_pd(&tab_out[p + 4], a1);
7307 p += 2 * 4;
7308 }
7309 tmp = tab_in;
7310 tab_in = tab_out;
7311 tab_out = tmp;
7312
7313 nb_blocks = n / 4;
7314 fft_per_block = 4;
7315
7316 l = fft_len_log2 - 2;
7317 while (nb_blocks != 2) {
7318 nb_blocks >>= 1;
7319 p = 0;
7320 k = 0;
7321 trig = get_trig(s, l, inverse, m_idx);
7322 if (!trig)
7323 return -1;
7324 for(i = 0; i < nb_blocks; i++) {
7325 c = _mm256_set1_pd(trig[0]);
7326 trig++;
7327 for(j = 0; j < fft_per_block; j += 4) {
7328 a0 = _mm256_load_pd(&tab_in[k + j]);
7329 a1 = _mm256_load_pd(&tab_in[k + j + stride_in]);
7330 b0 = ntt_mod(a0 + a1, mf, m2f);
7331 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7332 _mm256_store_pd(&tab_out[p + j], b0);
7333 _mm256_store_pd(&tab_out[p + j + fft_per_block], b1);
7334 }
7335 k += fft_per_block;
7336 p += 2 * fft_per_block;
7337 }
7338 fft_per_block <<= 1;
7339 l--;
7340 tmp = tab_in;
7341 tab_in = tab_out;
7342 tab_out = tmp;
7343 }
7344
7345 tab_out = out_buf;
7346 for(k = 0; k < stride_in; k += 4) {
7347 a0 = _mm256_load_pd(&tab_in[k]);
7348 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7349 b0 = ntt_mod(a0 + a1, mf, m2f);
7350 b1 = ntt_mod(a0 - a1, mf, m2f);
7351 _mm256_store_pd(&tab_out[k], b0);
7352 _mm256_store_pd(&tab_out[k + stride_in], b1);
7353 }
7354 return 0;
7355 }
7356
7357 static void ntt_vec_mul(BFNTTState *s,
7358 NTTLimb *tab1, NTTLimb *tab2, limb_t fft_len_log2,
7359 int k_tot, int m_idx)
7360 {
7361 limb_t i, c_inv, n, m;
7362 __m256d m_inv, mf, a, b, c;
7363
7364 m = ntt_mods[m_idx];
7365 c_inv = s->ntt_len_inv[m_idx][k_tot][0];
7366 m_inv = _mm256_set1_pd(1.0 / (double)m);
7367 mf = _mm256_set1_pd(m);
7368 c = _mm256_set1_pd(int_to_ntt_limb(c_inv, m));
7369 n = (limb_t)1 << fft_len_log2;
7370 for(i = 0; i < n; i += 4) {
7371 a = _mm256_load_pd(&tab1[i]);
7372 b = _mm256_load_pd(&tab2[i]);
7373 a = ntt_mul_mod(a, b, mf, m_inv);
7374 a = ntt_mul_mod(a, c, mf, m_inv);
7375 _mm256_store_pd(&tab1[i], a);
7376 }
7377 }
7378
7379 static no_inline void mul_trig(NTTLimb *buf,
7380 limb_t n, limb_t c1, limb_t m, limb_t m_inv1)
7381 {
7382 limb_t i, c2, c3, c4;
7383 __m256d c, c_mul, a0, mf, m_inv;
7384 assert(n >= 2);
7385
7386 mf = _mm256_set1_pd(m);
7387 m_inv = _mm256_set1_pd(1.0 / (double)m);
7388
7389 c2 = mul_mod_fast(c1, c1, m, m_inv1);
7390 c3 = mul_mod_fast(c2, c1, m, m_inv1);
7391 c4 = mul_mod_fast(c2, c2, m, m_inv1);
7392 c = _mm256_setr_pd(1, int_to_ntt_limb(c1, m),
7393 int_to_ntt_limb(c2, m), int_to_ntt_limb(c3, m));
7394 c_mul = _mm256_set1_pd(int_to_ntt_limb(c4, m));
7395 for(i = 0; i < n; i += 4) {
7396 a0 = _mm256_load_pd(&buf[i]);
7397 a0 = ntt_mul_mod(a0, c, mf, m_inv);
7398 _mm256_store_pd(&buf[i], a0);
7399 c = ntt_mul_mod(c, c_mul, mf, m_inv);
7400 }
7401 }
7402
7403 #else
7404
7405 static void *ntt_malloc(BFNTTState *s, size_t size)
7406 {
7407 return bf_malloc(s->ctx, size);
7408 }
7409
7410 static void ntt_free(BFNTTState *s, void *ptr)
7411 {
7412 bf_free(s->ctx, ptr);
7413 }
7414
7415 static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
7416 {
7417 if (a >= m)
7418 a -= m;
7419 return a;
7420 }
7421
7422 static inline NTTLimb int_to_ntt_limb(slimb_t a, limb_t m)
7423 {
7424 return a;
7425 }
7426
7427 static no_inline int ntt_fft(BFNTTState *s, NTTLimb *out_buf, NTTLimb *in_buf,
7428 NTTLimb *tmp_buf, int fft_len_log2,
7429 int inverse, int m_idx)
7430 {
7431 limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j, m, m2;
7432 NTTLimb *tab_in, *tab_out, *tmp, a0, a1, b0, b1, c, *trig, c_inv;
7433 int l;
7434
7435 m = ntt_mods[m_idx];
7436 m2 = 2 * m;
7437 n = (limb_t)1 << fft_len_log2;
7438 nb_blocks = n;
7439 fft_per_block = 1;
7440 stride_in = n / 2;
7441 tab_in = in_buf;
7442 tab_out = tmp_buf;
7443 l = fft_len_log2;
7444 while (nb_blocks != 2) {
7445 nb_blocks >>= 1;
7446 p = 0;
7447 k = 0;
7448 trig = get_trig(s, l, inverse, m_idx);
7449 if (!trig)
7450 return -1;
7451 for(i = 0; i < nb_blocks; i++) {
7452 c = trig[0];
7453 c_inv = trig[1];
7454 trig += 2;
7455 for(j = 0; j < fft_per_block; j++) {
7456 a0 = tab_in[k + j];
7457 a1 = tab_in[k + j + stride_in];
7458 b0 = add_mod(a0, a1, m2);
7459 b1 = a0 - a1 + m2;
7460 b1 = mul_mod_fast3(b1, c, m, c_inv);
7461 tab_out[p + j] = b0;
7462 tab_out[p + j + fft_per_block] = b1;
7463 }
7464 k += fft_per_block;
7465 p += 2 * fft_per_block;
7466 }
7467 fft_per_block <<= 1;
7468 l--;
7469 tmp = tab_in;
7470 tab_in = tab_out;
7471 tab_out = tmp;
7472 }
7473 /* no twiddle in last step */
7474 tab_out = out_buf;
7475 for(k = 0; k < stride_in; k++) {
7476 a0 = tab_in[k];
7477 a1 = tab_in[k + stride_in];
7478 b0 = add_mod(a0, a1, m2);
7479 b1 = sub_mod(a0, a1, m2);
7480 tab_out[k] = b0;
7481 tab_out[k + stride_in] = b1;
7482 }
7483 return 0;
7484 }
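
/* Each pass above applies a decimation-in-frequency butterfly:
   b0 = a0 + a1 and b1 = (a0 - a1) * w for a twiddle factor w, with
   intermediate values kept in the redundant range [0, 2*m) so that
   only one cheap reduction per operation is needed. */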
7485
7486 static void ntt_vec_mul(BFNTTState *s,
7487 NTTLimb *tab1, NTTLimb *tab2, int fft_len_log2,
7488 int k_tot, int m_idx)
7489 {
7490 limb_t i, norm, norm_inv, a, n, m, m_inv;
7491
7492 m = ntt_mods[m_idx];
7493 m_inv = s->ntt_mods_div[m_idx];
7494 norm = s->ntt_len_inv[m_idx][k_tot][0];
7495 norm_inv = s->ntt_len_inv[m_idx][k_tot][1];
7496 n = (limb_t)1 << fft_len_log2;
7497 for(i = 0; i < n; i++) {
7498 a = tab1[i];
7499 /* need to reduce the range so that the product is <
7500 2^(LIMB_BITS+NTT_MOD_LOG2_MIN) */
7501 if (a >= m)
7502 a -= m;
7503 a = mul_mod_fast(a, tab2[i], m, m_inv);
7504 a = mul_mod_fast3(a, norm, m, norm_inv);
7505 tab1[i] = a;
7506 }
7507 }
7508
7509 static no_inline void mul_trig(NTTLimb *buf,
7510 limb_t n, limb_t c_mul, limb_t m, limb_t m_inv)
7511 {
7512 limb_t i, c0, c_mul_inv;
7513
7514 c0 = 1;
7515 c_mul_inv = init_mul_mod_fast2(c_mul, m);
7516 for(i = 0; i < n; i++) {
7517 buf[i] = mul_mod_fast(buf[i], c0, m, m_inv);
7518 c0 = mul_mod_fast2(c0, c_mul, m, c_mul_inv);
7519 }
7520 }
7521
7522 #endif /* !AVX2 */
7523
7524 static no_inline NTTLimb *get_trig(BFNTTState *s,
7525 int k, int inverse, int m_idx)
7526 {
7527 NTTLimb *tab;
7528 limb_t i, n2, c, c_mul, m, c_mul_inv;
7529
7530 if (k > NTT_TRIG_K_MAX)
7531 return NULL;
7532
7533 tab = s->ntt_trig[m_idx][inverse][k];
7534 if (tab)
7535 return tab;
7536 n2 = (limb_t)1 << (k - 1);
7537 m = ntt_mods[m_idx];
7538 #ifdef __AVX2__
7539 tab = ntt_malloc(s, sizeof(NTTLimb) * n2);
7540 #else
7541 tab = ntt_malloc(s, sizeof(NTTLimb) * n2 * 2);
7542 #endif
7543 if (!tab)
7544 return NULL;
7545 c = 1;
7546 c_mul = s->ntt_proot_pow[m_idx][inverse][k];
7547 c_mul_inv = s->ntt_proot_pow_inv[m_idx][inverse][k];
7548 for(i = 0; i < n2; i++) {
7549 #ifdef __AVX2__
7550 tab[i] = int_to_ntt_limb2(c, m);
7551 #else
7552 tab[2 * i] = int_to_ntt_limb(c, m);
7553 tab[2 * i + 1] = init_mul_mod_fast2(c, m);
7554 #endif
7555 c = mul_mod_fast2(c, c_mul, m, c_mul_inv);
7556 }
7557 s->ntt_trig[m_idx][inverse][k] = tab;
7558 return tab;
7559 }
7560
7561 static void fft_clear_cache(bf_context_t *s1)
7562 {
7563 int m_idx, inverse, k;
7564 BFNTTState *s = s1->ntt_state;
7565 if (s) {
7566 for(m_idx = 0; m_idx < NB_MODS; m_idx++) {
7567 for(inverse = 0; inverse < 2; inverse++) {
7568 for(k = 0; k < NTT_TRIG_K_MAX + 1; k++) {
7569 if (s->ntt_trig[m_idx][inverse][k]) {
7570 ntt_free(s, s->ntt_trig[m_idx][inverse][k]);
7571 s->ntt_trig[m_idx][inverse][k] = NULL;
7572 }
7573 }
7574 }
7575 }
7576 #if defined(__AVX2__)
7577 bf_aligned_free(s1, s);
7578 #else
7579 bf_free(s1, s);
7580 #endif
7581 s1->ntt_state = NULL;
7582 }
7583 }
7584
7585 #define STRIP_LEN 16
7586
7587 /* in-place partial FFT: buf1 is both source and destination */
7588 static int ntt_fft_partial(BFNTTState *s, NTTLimb *buf1,
7589 int k1, int k2, limb_t n1, limb_t n2, int inverse,
7590 limb_t m_idx)
7591 {
7592 limb_t i, j, c_mul, c0, m, m_inv, strip_len, l;
7593 NTTLimb *buf2, *buf3;
7594
7595 buf2 = NULL;
7596 buf3 = ntt_malloc(s, sizeof(NTTLimb) * n1);
7597 if (!buf3)
7598 goto fail;
7599 if (k2 == 0) {
7600 if (ntt_fft(s, buf1, buf1, buf3, k1, inverse, m_idx))
7601 goto fail;
7602 } else {
7603 strip_len = STRIP_LEN;
7604 buf2 = ntt_malloc(s, sizeof(NTTLimb) * n1 * strip_len);
7605 if (!buf2)
7606 goto fail;
7607 m = ntt_mods[m_idx];
7608 m_inv = s->ntt_mods_div[m_idx];
7609 c0 = s->ntt_proot_pow[m_idx][inverse][k1 + k2];
7610 c_mul = 1;
7611 assert((n2 % strip_len) == 0);
7612 for(j = 0; j < n2; j += strip_len) {
7613 for(i = 0; i < n1; i++) {
7614 for(l = 0; l < strip_len; l++) {
7615 buf2[i + l * n1] = buf1[i * n2 + (j + l)];
7616 }
7617 }
7618 for(l = 0; l < strip_len; l++) {
7619 if (inverse)
7620 mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
7621 if (ntt_fft(s, buf2 + l * n1, buf2 + l * n1, buf3, k1, inverse, m_idx))
7622 goto fail;
7623 if (!inverse)
7624 mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
7625 c_mul = mul_mod_fast(c_mul, c0, m, m_inv);
7626 }
7627
7628 for(i = 0; i < n1; i++) {
7629 for(l = 0; l < strip_len; l++) {
7630                     buf1[i * n2 + (j + l)] = buf2[i + l * n1];
7631 }
7632 }
7633 }
7634 ntt_free(s, buf2);
7635 }
7636 ntt_free(s, buf3);
7637 return 0;
7638 fail:
7639 ntt_free(s, buf2);
7640 ntt_free(s, buf3);
7641 return -1;
7642 }
7643
7644
7645 /* dst = buf1, src = buf2 */
7646 static int ntt_conv(BFNTTState *s, NTTLimb *buf1, NTTLimb *buf2,
7647 int k, int k_tot, limb_t m_idx)
7648 {
7649 limb_t n1, n2, i;
7650 int k1, k2;
7651
7652 if (k <= NTT_TRIG_K_MAX) {
7653 k1 = k;
7654 } else {
7655 /* recursive split of the FFT */
7656 k1 = bf_min(k / 2, NTT_TRIG_K_MAX);
7657 }
7658 k2 = k - k1;
7659 n1 = (limb_t)1 << k1;
7660 n2 = (limb_t)1 << k2;
7661
7662 if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 0, m_idx))
7663 return -1;
7664 if (ntt_fft_partial(s, buf2, k1, k2, n1, n2, 0, m_idx))
7665 return -1;
7666 if (k2 == 0) {
7667 ntt_vec_mul(s, buf1, buf2, k, k_tot, m_idx);
7668 } else {
7669 for(i = 0; i < n1; i++) {
7670 ntt_conv(s, buf1 + i * n2, buf2 + i * n2, k2, k_tot, m_idx);
7671 }
7672 }
7673 if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 1, m_idx))
7674 return -1;
7675 return 0;
7676 }
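
/* ntt_conv() evaluates a convolution of length 2^k as a 2^k1 x 2^k2
   matrix (a four-step / Bailey-style decomposition): partial FFTs of
   size 2^k1 are applied with twiddle corrections in ntt_fft_partial(),
   and the remaining size-2^k2 convolutions are handled by the
   recursive call, so that each individual FFT stays within the cached
   table limit NTT_TRIG_K_MAX. */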
7677
7678
7679 static no_inline void limb_to_ntt(BFNTTState *s,
7680 NTTLimb *tabr, limb_t fft_len,
7681 const limb_t *taba, limb_t a_len, int dpl,
7682 int first_m_idx, int nb_mods)
7683 {
7684 slimb_t i, n;
7685 dlimb_t a, b;
7686 int j, shift;
7687 limb_t base_mask1, a0, a1, a2, r, m, m_inv;
7688
7689 #if 0
7690 for(i = 0; i < a_len; i++) {
7691 printf("%" PRId64 ": " FMT_LIMB "\n",
7692 (int64_t)i, taba[i]);
7693 }
7694 #endif
7695 memset(tabr, 0, sizeof(NTTLimb) * fft_len * nb_mods);
7696 shift = dpl & (LIMB_BITS - 1);
7697 if (shift == 0)
7698 base_mask1 = -1;
7699 else
7700 base_mask1 = ((limb_t)1 << shift) - 1;
7701 n = bf_min(fft_len, (a_len * LIMB_BITS + dpl - 1) / dpl);
7702 for(i = 0; i < n; i++) {
7703 a0 = get_bits(taba, a_len, i * dpl);
7704 if (dpl <= LIMB_BITS) {
7705 a0 &= base_mask1;
7706 a = a0;
7707 } else {
7708 a1 = get_bits(taba, a_len, i * dpl + LIMB_BITS);
7709 if (dpl <= (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
7710 a = a0 | ((dlimb_t)(a1 & base_mask1) << LIMB_BITS);
7711 } else {
7712 if (dpl > 2 * LIMB_BITS) {
7713 a2 = get_bits(taba, a_len, i * dpl + LIMB_BITS * 2) &
7714 base_mask1;
7715 } else {
7716 a1 &= base_mask1;
7717 a2 = 0;
7718 }
7719 // printf("a=0x%016lx%016lx%016lx\n", a2, a1, a0);
7720 a = (a0 >> (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) |
7721 ((dlimb_t)a1 << (NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN)) |
7722 ((dlimb_t)a2 << (LIMB_BITS + NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN));
7723 a0 &= ((limb_t)1 << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) - 1;
7724 }
7725 }
7726 for(j = 0; j < nb_mods; j++) {
7727 m = ntt_mods[first_m_idx + j];
7728 m_inv = s->ntt_mods_div[first_m_idx + j];
7729 r = mod_fast(a, m, m_inv);
7730 if (dpl > (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
7731 b = ((dlimb_t)r << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) | a0;
7732 r = mod_fast(b, m, m_inv);
7733 }
7734 tabr[i + j * fft_len] = int_to_ntt_limb(r, m);
7735 }
7736 }
7737 }
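
/* limb_to_ntt() slices the operand into chunks of 'dpl' bits and
   reduces each chunk modulo every selected NTT modulus; the inverse
   mapping in ntt_to_limb() rebuilds the product digits from the
   residues via the Chinese remainder theorem. */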
7738
7739 #if defined(__AVX2__)
7740
7741 #define VEC_LEN 4
7742
7743 typedef union {
7744 __m256d v;
7745 double d[4];
7746 } VecUnion;
7747
7748 static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
7749 const NTTLimb *buf, int fft_len_log2, int dpl,
7750 int nb_mods)
7751 {
7752 const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
7753 const __m256d *mods_cr_vec, *mf, *m_inv;
7754 VecUnion y[NB_MODS];
7755 limb_t u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
7756 slimb_t i, len, pos;
7757 int j, k, l, shift, n_limb1, p;
7758 dlimb_t t;
7759
7760 j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
7761 mods_cr_vec = s->ntt_mods_cr_vec + j;
7762 mf = s->ntt_mods_vec + NB_MODS - nb_mods;
7763 m_inv = s->ntt_mods_inv_vec + NB_MODS - nb_mods;
7764
7765 shift = dpl & (LIMB_BITS - 1);
7766 if (shift == 0)
7767 base_mask1 = -1;
7768 else
7769 base_mask1 = ((limb_t)1 << shift) - 1;
7770 n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
7771 for(j = 0; j < NB_MODS; j++)
7772 carry[j] = 0;
7773 for(j = 0; j < NB_MODS; j++)
7774 u[j] = 0; /* avoid warnings */
7775 memset(tabr, 0, sizeof(limb_t) * r_len);
7776 fft_len = (limb_t)1 << fft_len_log2;
7777 len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
7778 len = (len + VEC_LEN - 1) & ~(VEC_LEN - 1);
7779 i = 0;
7780 while (i < len) {
7781 for(j = 0; j < nb_mods; j++)
7782 y[j].v = *(__m256d *)&buf[i + fft_len * j];
7783
7784 /* Chinese remainder to get mixed radix representation */
7785 l = 0;
7786 for(j = 0; j < nb_mods - 1; j++) {
7787 y[j].v = ntt_mod1(y[j].v, mf[j]);
7788 for(k = j + 1; k < nb_mods; k++) {
7789 y[k].v = ntt_mul_mod(y[k].v - y[j].v,
7790 mods_cr_vec[l], mf[k], m_inv[k]);
7791 l++;
7792 }
7793 }
7794 y[j].v = ntt_mod1(y[j].v, mf[j]);
7795
7796 for(p = 0; p < VEC_LEN; p++) {
7797 /* back to normal representation */
7798 u[0] = (int64_t)y[nb_mods - 1].d[p];
7799 l = 1;
7800 for(j = nb_mods - 2; j >= 1; j--) {
7801 r = (int64_t)y[j].d[p];
7802 for(k = 0; k < l; k++) {
7803 t = (dlimb_t)u[k] * mods[j] + r;
7804 r = t >> LIMB_BITS;
7805 u[k] = t;
7806 }
7807 u[l] = r;
7808 l++;
7809 }
7810 /* XXX: for nb_mods = 5, l should be 4 */
7811
7812 /* last step adds the carry */
7813 r = (int64_t)y[0].d[p];
7814 for(k = 0; k < l; k++) {
7815 t = (dlimb_t)u[k] * mods[j] + r + carry[k];
7816 r = t >> LIMB_BITS;
7817 u[k] = t;
7818 }
7819 u[l] = r + carry[l];
7820
7821 #if 0
7822 printf("%" PRId64 ": ", i);
7823 for(j = nb_mods - 1; j >= 0; j--) {
7824 printf(" %019" PRIu64, u[j]);
7825 }
7826 printf("\n");
7827 #endif
7828
7829 /* write the digits */
7830 pos = i * dpl;
7831 for(j = 0; j < n_limb1; j++) {
7832 put_bits(tabr, r_len, pos, u[j]);
7833 pos += LIMB_BITS;
7834 }
7835 put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
7836 /* shift by dpl digits and set the carry */
7837 if (shift == 0) {
7838 for(j = n_limb1 + 1; j < nb_mods; j++)
7839 carry[j - (n_limb1 + 1)] = u[j];
7840 } else {
7841 for(j = n_limb1; j < nb_mods - 1; j++) {
7842 carry[j - n_limb1] = (u[j] >> shift) |
7843 (u[j + 1] << (LIMB_BITS - shift));
7844 }
7845 carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
7846 }
7847 i++;
7848 }
7849 }
7850 }
7851 #else
7852 static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
7853 const NTTLimb *buf, int fft_len_log2, int dpl,
7854 int nb_mods)
7855 {
7856 const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
7857 const limb_t *mods_cr, *mods_cr_inv;
7858 limb_t y[NB_MODS], u[NB_MODS+2], carry[NB_MODS], fft_len, base_mask1, r;
7859 slimb_t i, len, pos;
7860 int j, k, l, shift, n_limb1;
7861 dlimb_t t;
7862
7863 j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
7864 mods_cr = ntt_mods_cr + j;
7865 mods_cr_inv = s->ntt_mods_cr_inv + j;
7866
7867 shift = dpl & (LIMB_BITS - 1);
7868 if (shift == 0)
7869 base_mask1 = -1;
7870 else
7871 base_mask1 = ((limb_t)1 << shift) - 1;
7872 n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
7873 for(j = 0; j < NB_MODS; j++)
7874 carry[j] = 0;
7875 for(j = 0; j < NB_MODS; j++)
7876 u[j] = 0; /* avoid warnings */
7877 memset(tabr, 0, sizeof(limb_t) * r_len);
7878 fft_len = (limb_t)1 << fft_len_log2;
7879 len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
7880 for(i = 0; i < len; i++) {
7881 for(j = 0; j < nb_mods; j++) {
7882 y[j] = ntt_limb_to_int(buf[i + fft_len * j], mods[j]);
7883 }
7884
7885 /* Chinese remainder to get mixed radix representation */
7886 l = 0;
7887 for(j = 0; j < nb_mods - 1; j++) {
7888 for(k = j + 1; k < nb_mods; k++) {
7889 limb_t m;
7890 m = mods[k];
7891                 /* Note: there is no overflow in the subtraction because
7892                    the moduli are sorted in increasing order */
7893 y[k] = mul_mod_fast2(y[k] - y[j] + m,
7894 mods_cr[l], m, mods_cr_inv[l]);
7895 l++;
7896 }
7897 }
7898
7899 /* back to normal representation */
7900 u[0] = y[nb_mods - 1];
7901 l = 1;
7902 for(j = nb_mods - 2; j >= 1; j--) {
7903 r = y[j];
7904 for(k = 0; k < l; k++) {
7905 t = (dlimb_t)u[k] * mods[j] + r;
7906 r = t >> LIMB_BITS;
7907 u[k] = t;
7908 }
7909 u[l] = r;
7910 l++;
7911 }
7912
7913 /* last step adds the carry */
7914 r = y[0];
7915 for(k = 0; k < l; k++) {
7916 t = (dlimb_t)u[k] * mods[j] + r + carry[k];
7917 r = t >> LIMB_BITS;
7918 u[k] = t;
7919 }
7920 u[l] = r + carry[l];
7921
7922 #if 0
7923 printf("%" PRId64 ": ", (int64_t)i);
7924 for(j = nb_mods - 1; j >= 0; j--) {
7925 printf(" " FMT_LIMB, u[j]);
7926 }
7927 printf("\n");
7928 #endif
7929
7930 /* write the digits */
7931 pos = i * dpl;
7932 for(j = 0; j < n_limb1; j++) {
7933 put_bits(tabr, r_len, pos, u[j]);
7934 pos += LIMB_BITS;
7935 }
7936 put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
7937 /* shift by dpl digits and set the carry */
7938 if (shift == 0) {
7939 for(j = n_limb1 + 1; j < nb_mods; j++)
7940 carry[j - (n_limb1 + 1)] = u[j];
7941 } else {
7942 for(j = n_limb1; j < nb_mods - 1; j++) {
7943 carry[j - n_limb1] = (u[j] >> shift) |
7944 (u[j + 1] << (LIMB_BITS - shift));
7945 }
7946 carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
7947 }
7948 }
7949 }
7950 #endif
7951
7952 static int ntt_static_init(bf_context_t *s1)
7953 {
7954 BFNTTState *s;
7955 int inverse, i, j, k, l;
7956 limb_t c, c_inv, c_inv2, m, m_inv;
7957
7958 if (s1->ntt_state)
7959 return 0;
7960 #if defined(__AVX2__)
7961 s = bf_aligned_malloc(s1, sizeof(*s), 64);
7962 #else
7963 s = bf_malloc(s1, sizeof(*s));
7964 #endif
7965 if (!s)
7966 return -1;
7967 memset(s, 0, sizeof(*s));
7968 s1->ntt_state = s;
7969 s->ctx = s1;
7970
7971 for(j = 0; j < NB_MODS; j++) {
7972 m = ntt_mods[j];
7973 m_inv = init_mul_mod_fast(m);
7974 s->ntt_mods_div[j] = m_inv;
7975 #if defined(__AVX2__)
7976 s->ntt_mods_vec[j] = _mm256_set1_pd(m);
7977 s->ntt_mods_inv_vec[j] = _mm256_set1_pd(1.0 / (double)m);
7978 #endif
7979 c_inv2 = (m + 1) / 2; /* 1/2 */
7980 c_inv = 1;
7981 for(i = 0; i <= NTT_PROOT_2EXP; i++) {
7982 s->ntt_len_inv[j][i][0] = c_inv;
7983 s->ntt_len_inv[j][i][1] = init_mul_mod_fast2(c_inv, m);
7984 c_inv = mul_mod_fast(c_inv, c_inv2, m, m_inv);
7985 }
7986
7987 for(inverse = 0; inverse < 2; inverse++) {
7988 c = ntt_proot[inverse][j];
7989 for(i = 0; i < NTT_PROOT_2EXP; i++) {
7990 s->ntt_proot_pow[j][inverse][NTT_PROOT_2EXP - i] = c;
7991 s->ntt_proot_pow_inv[j][inverse][NTT_PROOT_2EXP - i] =
7992 init_mul_mod_fast2(c, m);
7993 c = mul_mod_fast(c, c, m, m_inv);
7994 }
7995 }
7996 }
7997
7998 l = 0;
7999 for(j = 0; j < NB_MODS - 1; j++) {
8000 for(k = j + 1; k < NB_MODS; k++) {
8001 #if defined(__AVX2__)
8002 s->ntt_mods_cr_vec[l] = _mm256_set1_pd(int_to_ntt_limb2(ntt_mods_cr[l],
8003 ntt_mods[k]));
8004 #else
8005 s->ntt_mods_cr_inv[l] = init_mul_mod_fast2(ntt_mods_cr[l],
8006 ntt_mods[k]);
8007 #endif
8008 l++;
8009 }
8010 }
8011 return 0;
8012 }
8013
8014 int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
8015 {
8016 int dpl, fft_len_log2, n_bits, nb_mods, dpl_found, fft_len_log2_found;
8017 int int_bits, nb_mods_found;
8018 limb_t cost, min_cost;
8019
8020 min_cost = -1;
8021 dpl_found = 0;
8022 nb_mods_found = 4;
8023 fft_len_log2_found = 0;
8024 for(nb_mods = 3; nb_mods <= NB_MODS; nb_mods++) {
8025 int_bits = ntt_int_bits[NB_MODS - nb_mods];
8026 dpl = bf_min((int_bits - 4) / 2,
8027 2 * LIMB_BITS + 2 * NTT_MOD_LOG2_MIN - NTT_MOD_LOG2_MAX);
8028 for(;;) {
8029 fft_len_log2 = ceil_log2((len * LIMB_BITS + dpl - 1) / dpl);
8030 if (fft_len_log2 > NTT_PROOT_2EXP)
8031 goto next;
8032 n_bits = fft_len_log2 + 2 * dpl;
8033 if (n_bits <= int_bits) {
8034 cost = ((limb_t)(fft_len_log2 + 1) << fft_len_log2) * nb_mods;
8035 // printf("n=%d dpl=%d: cost=%" PRId64 "\n", nb_mods, dpl, (int64_t)cost);
8036 if (cost < min_cost) {
8037 min_cost = cost;
8038 dpl_found = dpl;
8039 nb_mods_found = nb_mods;
8040 fft_len_log2_found = fft_len_log2;
8041 }
8042 break;
8043 }
8044 dpl--;
8045 if (dpl == 0)
8046 break;
8047 }
8048 next: ;
8049 }
8050 if (!dpl_found)
8051 abort();
8052 /* limit dpl if possible to reduce fixed cost of limb/NTT conversion */
8053 if (dpl_found > (LIMB_BITS + NTT_MOD_LOG2_MIN) &&
8054 ((limb_t)(LIMB_BITS + NTT_MOD_LOG2_MIN) << fft_len_log2_found) >=
8055 len * LIMB_BITS) {
8056 dpl_found = LIMB_BITS + NTT_MOD_LOG2_MIN;
8057 }
8058 *pnb_mods = nb_mods_found;
8059 *pdpl = dpl_found;
8060 return fft_len_log2_found;
8061 }
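
/* The search above relies on the CRT capacity constraint: with 'dpl'
   bits per FFT coefficient, a product coefficient needs up to
   fft_len_log2 + 2 * dpl bits, which must fit within the 'int_bits'
   provided by the selected set of moduli; the cost expression
   ((fft_len_log2 + 1) << fft_len_log2) * nb_mods approximates the
   total butterfly count. */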
8062
8063 /* return 0 if OK, -1 if memory error */
8064 static no_inline int fft_mul(bf_context_t *s1,
8065 bf_t *res, limb_t *a_tab, limb_t a_len,
8066 limb_t *b_tab, limb_t b_len, int mul_flags)
8067 {
8068 BFNTTState *s;
8069 int dpl, fft_len_log2, j, nb_mods, reduced_mem;
8070 slimb_t len, fft_len;
8071 NTTLimb *buf1, *buf2, *ptr;
8072 #if defined(USE_MUL_CHECK)
8073 limb_t ha, hb, hr, h_ref;
8074 #endif
8075
8076 if (ntt_static_init(s1))
8077 return -1;
8078 s = s1->ntt_state;
8079
8080 /* find the optimal number of digits per limb (dpl) */
8081 len = a_len + b_len;
8082 fft_len_log2 = bf_get_fft_size(&dpl, &nb_mods, len);
8083 fft_len = (uint64_t)1 << fft_len_log2;
8084 // printf("len=%" PRId64 " fft_len_log2=%d dpl=%d\n", len, fft_len_log2, dpl);
8085 #if defined(USE_MUL_CHECK)
8086 ha = mp_mod1(a_tab, a_len, BF_CHKSUM_MOD, 0);
8087 hb = mp_mod1(b_tab, b_len, BF_CHKSUM_MOD, 0);
8088 #endif
8089 if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) == 0) {
8090 if (!(mul_flags & FFT_MUL_R_NORESIZE))
8091 bf_resize(res, 0);
8092 } else if (mul_flags & FFT_MUL_R_OVERLAP_B) {
8093 limb_t *tmp_tab, tmp_len;
8094 /* it is better to free 'b' first */
8095 tmp_tab = a_tab;
8096 a_tab = b_tab;
8097 b_tab = tmp_tab;
8098 tmp_len = a_len;
8099 a_len = b_len;
8100 b_len = tmp_len;
8101 }
8102 buf1 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
8103 if (!buf1)
8104 return -1;
8105 limb_to_ntt(s, buf1, fft_len, a_tab, a_len, dpl,
8106 NB_MODS - nb_mods, nb_mods);
8107 if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) ==
8108 FFT_MUL_R_OVERLAP_A) {
8109 if (!(mul_flags & FFT_MUL_R_NORESIZE))
8110 bf_resize(res, 0);
8111 }
8112 reduced_mem = (fft_len_log2 >= 14);
8113 if (!reduced_mem) {
8114 buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
8115 if (!buf2)
8116 goto fail;
8117 limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
8118 NB_MODS - nb_mods, nb_mods);
8119 if (!(mul_flags & FFT_MUL_R_NORESIZE))
8120 bf_resize(res, 0); /* in case res == b */
8121 } else {
8122 buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len);
8123 if (!buf2)
8124 goto fail;
8125 }
8126 for(j = 0; j < nb_mods; j++) {
8127 if (reduced_mem) {
8128 limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
8129 NB_MODS - nb_mods + j, 1);
8130 ptr = buf2;
8131 } else {
8132 ptr = buf2 + fft_len * j;
8133 }
8134 if (ntt_conv(s, buf1 + fft_len * j, ptr,
8135 fft_len_log2, fft_len_log2, j + NB_MODS - nb_mods))
8136 goto fail;
8137 }
8138 if (!(mul_flags & FFT_MUL_R_NORESIZE))
8139 bf_resize(res, 0); /* in case res == b and reduced mem */
8140 ntt_free(s, buf2);
8141 buf2 = NULL;
8142 if (!(mul_flags & FFT_MUL_R_NORESIZE)) {
8143 if (bf_resize(res, len))
8144 goto fail;
8145 }
8146 ntt_to_limb(s, res->tab, len, buf1, fft_len_log2, dpl, nb_mods);
8147 ntt_free(s, buf1);
8148 #if defined(USE_MUL_CHECK)
8149 hr = mp_mod1(res->tab, len, BF_CHKSUM_MOD, 0);
8150 h_ref = mul_mod(ha, hb, BF_CHKSUM_MOD);
8151 if (hr != h_ref) {
8152 printf("ntt_mul_error: len=%" PRId_LIMB " fft_len_log2=%d dpl=%d nb_mods=%d\n",
8153 len, fft_len_log2, dpl, nb_mods);
8154 // printf("ha=0x" FMT_LIMB" hb=0x" FMT_LIMB " hr=0x" FMT_LIMB " expected=0x" FMT_LIMB "\n", ha, hb, hr, h_ref);
8155 exit(1);
8156 }
8157 #endif
8158 return 0;
8159 fail:
8160 ntt_free(s, buf1);
8161 ntt_free(s, buf2);
8162 return -1;
8163 }
8164
8165 #else /* USE_FFT_MUL */
8166
8167 int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
8168 {
8169 return 0;
8170 }
8171
8172 #endif /* !USE_FFT_MUL */
8173