1 /*
2  * Tiny arbitrary precision floating point library
3  *
4  * Copyright (c) 2017-2020 Fabrice Bellard
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  */
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <inttypes.h>
27 #include <math.h>
28 #include <string.h>
29 #include <assert.h>
30 
31 #ifdef __AVX2__
32 #include <immintrin.h>
33 #endif
34 
35 #include "cutils.h"
36 #include "libbf.h"
37 
38 /* enable it to check the multiplication result */
39 //#define USE_MUL_CHECK
40 /* enable it to use FFT/NTT multiplication */
41 #define USE_FFT_MUL
42 /* enable decimal floating point support */
43 #define USE_BF_DEC
44 
45 //#define inline __attribute__((always_inline))
46 
47 #ifdef __AVX2__
48 #define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
49 #else
50 #define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
51 #endif
52 
53 /* XXX: adjust */
54 #define DIVNORM_LARGE_THRESHOLD 50
55 #define UDIV1NORM_THRESHOLD 3
56 
57 #if LIMB_BITS == 64
58 #define FMT_LIMB1 "%" PRIx64
59 #define FMT_LIMB "%016" PRIx64
60 #define PRId_LIMB PRId64
61 #define PRIu_LIMB PRIu64
62 
63 #else
64 
65 #define FMT_LIMB1 "%x"
66 #define FMT_LIMB "%08x"
67 #define PRId_LIMB "d"
68 #define PRIu_LIMB "u"
69 
70 #endif
71 
72 typedef intptr_t mp_size_t;
73 
74 typedef int bf_op2_func_t(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
75                           bf_flags_t flags);
76 
77 #ifdef USE_FFT_MUL
78 
79 #define FFT_MUL_R_OVERLAP_A (1 << 0)
80 #define FFT_MUL_R_OVERLAP_B (1 << 1)
81 #define FFT_MUL_R_NORESIZE  (1 << 2)
82 
83 static no_inline int fft_mul(bf_context_t *s,
84                              bf_t *res, limb_t *a_tab, limb_t a_len,
85                              limb_t *b_tab, limb_t b_len, int mul_flags);
86 static void fft_clear_cache(bf_context_t *s);
87 #endif
88 #ifdef USE_BF_DEC
89 static void mp_pow_init(void);
90 static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos);
91 #endif
92 
93 
94 /* could leading zeros */
clz(limb_t a)95 static inline int clz(limb_t a)
96 {
97     if (a == 0) {
98         return LIMB_BITS;
99     } else {
100 #if LIMB_BITS == 64
101         return clz64(a);
102 #else
103         return clz32(a);
104 #endif
105     }
106 }
107 
/* count trailing zero bits; returns LIMB_BITS when a == 0 */
static inline int ctz(limb_t a)
{
    if (a == 0)
        return LIMB_BITS;
#if LIMB_BITS == 64
    return ctz64(a);
#else
    return ctz32(a);
#endif
}
120 
/* smallest n such that 2^n >= a (returns 0 for a <= 1) */
static inline int ceil_log2(limb_t a)
{
    return (a <= 1) ? 0 : LIMB_BITS - clz(a - 1);
}
128 
#if 0 //unused
/* division rounded toward +infinity; 'b' must be >= 1 */
static inline slimb_t ceil_div(slimb_t a, slimb_t b)
{
    /* C division truncates toward zero, so only the non-negative
       case needs the bias */
    if (a >= 0)
        return (a + b - 1) / b;
    else
        return a / b;
}
#endif
139 
140 /* b must be >= 1 */
floor_div(slimb_t a,slimb_t b)141 static inline slimb_t floor_div(slimb_t a, slimb_t b)
142 {
143     if (a >= 0) {
144         return a / b;
145     } else {
146         return (a - b + 1) / b;
147     }
148 }
149 
150 /* return r = a modulo b (0 <= r <= b - 1. b must be >= 1 */
smod(slimb_t a,slimb_t b)151 static inline limb_t smod(slimb_t a, slimb_t b)
152 {
153     a = a % (slimb_t)b;
154     if (a < 0)
155         a += b;
156     return a;
157 }
158 
159 #define malloc(s) malloc_is_forbidden(s)
160 #define free(p) free_is_forbidden(p)
161 #define realloc(p, s) realloc_is_forbidden(p, s)
162 
/* Initialize the library context 's'. 'realloc_func' performs all
   memory management for the library; 'realloc_opaque' is passed back
   to it unchanged on every call. */
void bf_context_init(bf_context_t *s, bf_realloc_func_t *realloc_func,
                     void *realloc_opaque)
{
    memset(s, 0, sizeof(*s));
    s->realloc_func = realloc_func;
    s->realloc_opaque = realloc_opaque;
#ifdef USE_BF_DEC
    /* one-time initialization used by the decimal support
       (see mp_pow_init) */
    mp_pow_init();
#endif
}
173 
/* Release the allocations cached by the context (via bf_clear_cache).
   The context structure itself is not freed. */
void bf_context_end(bf_context_t *s)
{
    bf_clear_cache(s);
}
178 
/* Initialize 'r' to +0, attached to the context 's'. No memory is
   allocated until the number is resized. */
void bf_init(bf_context_t *s, bf_t *r)
{
    r->ctx = s;
    r->tab = NULL;
    r->len = 0;
    r->sign = 0;
    r->expn = BF_EXP_ZERO;
}
187 
188 /* return 0 if OK, -1 if alloc error */
bf_resize(bf_t * r,limb_t len)189 int bf_resize(bf_t *r, limb_t len)
190 {
191     limb_t *tab;
192 
193     if (len != r->len) {
194         tab = bf_realloc(r->ctx, r->tab, len * sizeof(limb_t));
195         if (!tab && len != 0)
196             return -1;
197         r->tab = tab;
198         r->len = len;
199     }
200     return 0;
201 }
202 
203 /* return 0 or BF_ST_MEM_ERROR */
bf_set_ui(bf_t * r,uint64_t a)204 int bf_set_ui(bf_t *r, uint64_t a)
205 {
206     r->sign = 0;
207     if (a == 0) {
208         r->expn = BF_EXP_ZERO;
209         bf_resize(r, 0); /* cannot fail */
210     }
211 #if LIMB_BITS == 32
212     else if (a <= 0xffffffff)
213 #else
214     else
215 #endif
216     {
217         int shift;
218         if (bf_resize(r, 1))
219             goto fail;
220         shift = clz(a);
221         r->tab[0] = a << shift;
222         r->expn = LIMB_BITS - shift;
223     }
224 #if LIMB_BITS == 32
225     else {
226         uint32_t a1, a0;
227         int shift;
228         if (bf_resize(r, 2))
229             goto fail;
230         a0 = a;
231         a1 = a >> 32;
232         shift = clz(a1);
233         r->tab[0] = a0 << shift;
234         r->tab[1] = (a1 << shift) | (a0 >> (LIMB_BITS - shift));
235         r->expn = 2 * LIMB_BITS - shift;
236     }
237 #endif
238     return 0;
239  fail:
240     bf_set_nan(r);
241     return BF_ST_MEM_ERROR;
242 }
243 
244 /* return 0 or BF_ST_MEM_ERROR */
bf_set_si(bf_t * r,int64_t a)245 int bf_set_si(bf_t *r, int64_t a)
246 {
247     int ret;
248 
249     if (a < 0) {
250         ret = bf_set_ui(r, -a);
251         r->sign = 1;
252     } else {
253         ret = bf_set_ui(r, a);
254     }
255     return ret;
256 }
257 
/* Set 'r' to NaN (zero-length mantissa, special exponent). */
void bf_set_nan(bf_t *r)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = 0;
    r->expn = BF_EXP_NAN;
}
264 
/* Set 'r' to zero with the given sign (is_neg != 0 gives -0). */
void bf_set_zero(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = is_neg;
    r->expn = BF_EXP_ZERO;
}
271 
/* Set 'r' to infinity with the given sign. */
void bf_set_inf(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = is_neg;
    r->expn = BF_EXP_INF;
}
278 
279 /* return 0 or BF_ST_MEM_ERROR */
bf_set(bf_t * r,const bf_t * a)280 int bf_set(bf_t *r, const bf_t *a)
281 {
282     if (r == a)
283         return 0;
284     if (bf_resize(r, a->len)) {
285         bf_set_nan(r);
286         return BF_ST_MEM_ERROR;
287     }
288     r->sign = a->sign;
289     r->expn = a->expn;
290     memcpy(r->tab, a->tab, a->len * sizeof(limb_t));
291     return 0;
292 }
293 
294 /* equivalent to bf_set(r, a); bf_delete(a) */
bf_move(bf_t * r,bf_t * a)295 void bf_move(bf_t *r, bf_t *a)
296 {
297     bf_context_t *s = r->ctx;
298     if (r == a)
299         return;
300     bf_free(s, r->tab);
301     *r = *a;
302 }
303 
/* return limb 'idx' of 'a', or 0 when 'idx' is out of range (the
   index is unsigned, so a "negative" index also reads as 0) */
static limb_t get_limbz(const bf_t *a, limb_t idx)
{
    return (idx < a->len) ? a->tab[idx] : 0;
}
311 
/* get LIMB_BITS bits starting at bit position 'pos' in 'tab'; bits
   outside the 'len' limb array read as zero.
   NOTE(review): 'i' is unsigned, so a negative 'pos' wraps to a huge
   index and the function returns 0 — callers appear to rely on this. */
static inline limb_t get_bits(const limb_t *tab, limb_t len, slimb_t pos)
{
    limb_t i, a0, a1;
    int p;

    /* limb index and bit offset inside that limb */
    i = pos >> LIMB_LOG2_BITS;
    p = pos & (LIMB_BITS - 1);
    if (i < len)
        a0 = tab[i];
    else
        a0 = 0;
    if (p == 0) {
        /* aligned: a single limb suffices */
        return a0;
    } else {
        /* combine the high bits of limb 'i' with the low bits of
           limb 'i + 1' */
        i++;
        if (i < len)
            a1 = tab[i];
        else
            a1 = 0;
        return (a0 >> p) | (a1 << (LIMB_BITS - p));
    }
}
335 
/* return the bit at position 'pos' in 'tab', or 0 when 'pos' is out
   of the [0, len * LIMB_BITS - 1] range */
static inline limb_t get_bit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t idx = pos >> LIMB_LOG2_BITS;

    if (idx < 0 || idx >= len)
        return 0;
    return (tab[idx] >> (pos & (LIMB_BITS - 1))) & 1;
}
344 
limb_mask(int start,int last)345 static inline limb_t limb_mask(int start, int last)
346 {
347     limb_t v;
348     int n;
349     n = last - start + 1;
350     if (n == LIMB_BITS)
351         v = -1;
352     else
353         v = (((limb_t)1 << n) - 1) << start;
354     return v;
355 }
356 
/* return 1 if at least one of the 'n' limbs of 'tab' is non zero,
   0 otherwise */
static limb_t mp_scan_nz(const limb_t *tab, mp_size_t n)
{
    mp_size_t i = n;

    while (i-- > 0) {
        if (tab[i] != 0)
            return 1;
    }
    return 0;
}
366 
367 /* return != 0 if one bit between 0 and bit_pos inclusive is not zero. */
scan_bit_nz(const bf_t * r,slimb_t bit_pos)368 static inline limb_t scan_bit_nz(const bf_t *r, slimb_t bit_pos)
369 {
370     slimb_t pos;
371     limb_t v;
372 
373     pos = bit_pos >> LIMB_LOG2_BITS;
374     if (pos < 0)
375         return 0;
376     v = r->tab[pos] & limb_mask(0, bit_pos & (LIMB_BITS - 1));
377     if (v != 0)
378         return 1;
379     pos--;
380     while (pos >= 0) {
381         if (r->tab[pos] != 0)
382             return 1;
383         pos--;
384     }
385     return 0;
386 }
387 
/* Return the rounding addend (0 or 1) to apply at bit 'prec - 1' when
   rounding the 'l' limb number 'r' to 'prec' bits with mode
   'rnd_mode'. If any discarded bit is non zero, BF_ST_INEXACT is
   or'ed into '*pret'. Note that prec can be <= 0 (for
   BF_FLAG_RADPNT_PREC). */
static int bf_get_rnd_add(int *pret, const bf_t *r, limb_t l,
                          slimb_t prec, int rnd_mode)
{
    int add_one, inexact;
    limb_t bit1, bit0;

    if (rnd_mode == BF_RNDF) {
        bit0 = 1; /* faithful rounding does not honor the INEXACT flag */
    } else {
        /* bit0 != 0 if any bit strictly below bit 'prec' is set */
        bit0 = scan_bit_nz(r, l * LIMB_BITS - 1 - bf_max(0, prec + 1));
    }

    /* get the bit at 'prec' (the first discarded bit) */
    bit1 = get_bit(r->tab, l, l * LIMB_BITS - 1 - prec);
    inexact = (bit1 | bit0) != 0;

    add_one = 0;
    switch(rnd_mode) {
    case BF_RNDZ:
        /* truncate: never round up */
        break;
    case BF_RNDN:
        if (bit1) {
            if (bit0) {
                add_one = 1;
            } else {
                /* exactly half way: round to even (look at the last
                   kept bit) */
                add_one =
                    get_bit(r->tab, l, l * LIMB_BITS - 1 - (prec - 1));
            }
        }
        break;
    case BF_RNDD:
    case BF_RNDU:
        /* round away from zero only when the directed mode matches
           the sign */
        if (r->sign == (rnd_mode == BF_RNDD))
            add_one = inexact;
        break;
    case BF_RNDNA:
    case BF_RNDF:
        /* round half away from zero (either neighbor is acceptable
           for RNDF) */
        add_one = bit1;
        break;
    case BF_RNDNU:
        /* round half toward +infinity */
        if (bit1) {
            if (r->sign)
                add_one = bit0;
            else
                add_one = 1;
        }
        break;
    default:
        abort();
    }

    if (inexact)
        *pret |= BF_ST_INEXACT;
    return add_one;
}
447 
/* Set 'r' to the overflow result for a number of sign 'sign': either
   infinity or the largest finite number of 'prec' bits, depending on
   the rounding mode. Return the status flags (always includes
   BF_ST_OVERFLOW | BF_ST_INEXACT, or BF_ST_MEM_ERROR on allocation
   failure). */
static int bf_set_overflow(bf_t *r, int sign, limb_t prec, bf_flags_t flags)
{
    slimb_t i, l, e_max;
    int rnd_mode;

    rnd_mode = flags & BF_RND_MASK;
    if (prec == BF_PREC_INF ||
        rnd_mode == BF_RNDN ||
        rnd_mode == BF_RNDNA ||
        rnd_mode == BF_RNDNU ||
        (rnd_mode == BF_RNDD && sign == 1) ||
        (rnd_mode == BF_RNDU && sign == 0)) {
        /* rounding goes away from zero: produce infinity */
        bf_set_inf(r, sign);
    } else {
        /* set to maximum finite number: 'prec' one bits, left-aligned
           in 'l' limbs */
        l = (prec + LIMB_BITS - 1) / LIMB_BITS;
        if (bf_resize(r, l)) {
            bf_set_nan(r);
            return BF_ST_MEM_ERROR;
        }
        /* the lowest limb may be only partially filled */
        r->tab[0] = limb_mask((-prec) & (LIMB_BITS - 1),
                              LIMB_BITS - 1);
        for(i = 1; i < l; i++)
            r->tab[i] = (limb_t)-1;
        e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
        r->expn = e_max;
        r->sign = sign;
    }
    return BF_ST_OVERFLOW | BF_ST_INEXACT;
}
478 
/* round to prec1 bits assuming 'r' is non zero and finite. 'r' is
   assumed to have length 'l' (1 <= l <= r->len). Note: 'prec1' can be
   infinite (BF_PREC_INF). Can fail with BF_ST_MEM_ERROR in case of
   overflow not returning infinity. */
static int __bf_round(bf_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
{
    limb_t v, a;
    int shift, add_one, ret, rnd_mode;
    slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;

    /* e_min and e_max are computed to match the IEEE 754 conventions */
    e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    e_min = -e_range + 3;
    e_max = e_range;

    if (flags & BF_FLAG_RADPNT_PREC) {
        /* 'prec' is the precision after the radix point */
        if (prec1 != BF_PREC_INF)
            prec = r->expn + prec1;
        else
            prec = prec1;
    } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
        /* restrict the precision in case of potentially subnormal
           result */
        assert(prec1 != BF_PREC_INF);
        prec = prec1 - (e_min - r->expn);
    } else {
        prec = prec1;
    }

    /* round to prec bits */
    rnd_mode = flags & BF_RND_MASK;
    ret = 0;
    add_one = bf_get_rnd_add(&ret, r, l, prec, rnd_mode);

    if (prec <= 0) {
        if (add_one) {
            /* rounded up to the smallest representable magnitude:
               a single limb with only its MSB set */
            bf_resize(r, 1); /* cannot fail */
            r->tab[0] = (limb_t)1 << (LIMB_BITS - 1);
            r->expn += 1 - prec;
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        } else {
            goto underflow;
        }
    } else if (add_one) {
        limb_t carry;

        /* add one starting at digit 'prec - 1' */
        bit_pos = l * LIMB_BITS - 1 - (prec - 1);
        pos = bit_pos >> LIMB_LOG2_BITS;
        carry = (limb_t)1 << (bit_pos & (LIMB_BITS - 1));

        for(i = pos; i < l; i++) {
            v = r->tab[i] + carry;
            carry = (v < carry);
            r->tab[i] = v;
            if (carry == 0)
                break;
        }
        if (carry) {
            /* carry propagated out of the top limb: shift right by
               one digit and renormalize the exponent */
            v = 1;
            for(i = l - 1; i >= pos; i--) {
                a = r->tab[i];
                r->tab[i] = (a >> 1) | (v << (LIMB_BITS - 1));
                v = a;
            }
            r->expn++;
        }
    }

    /* check underflow */
    if (unlikely(r->expn < e_min)) {
        if (flags & BF_FLAG_SUBNORMAL) {
            /* if inexact, also set the underflow flag */
            if (ret & BF_ST_INEXACT)
                ret |= BF_ST_UNDERFLOW;
        } else {
        underflow:
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            bf_set_zero(r, r->sign);
            return ret;
        }
    }

    /* check overflow */
    if (unlikely(r->expn > e_max))
        return bf_set_overflow(r, r->sign, prec1, flags);

    /* keep the bits starting at 'prec - 1' */
    bit_pos = l * LIMB_BITS - 1 - (prec - 1);
    i = bit_pos >> LIMB_LOG2_BITS;
    if (i >= 0) {
        /* clear the discarded low bits of the boundary limb */
        shift = bit_pos & (LIMB_BITS - 1);
        if (shift != 0)
            r->tab[i] &= limb_mask(shift, LIMB_BITS - 1);
    } else {
        i = 0;
    }
    /* remove the low-order all-zero limbs */
    while (r->tab[i] == 0)
        i++;
    if (i > 0) {
        l -= i;
        memmove(r->tab, r->tab + i, l * sizeof(limb_t));
    }
    bf_resize(r, l); /* cannot fail */
    return ret;
}
589 
/* Normalize 'r' (strip high zero limbs and shift so that the MSB of
   the top limb is set, adjusting the exponent accordingly), then
   round to 'prec1' bits. 'r' must be a finite number. */
int bf_normalize_and_round(bf_t *r, limb_t prec1, bf_flags_t flags)
{
    limb_t l, v, a;
    int shift, ret;
    slimb_t i;

    //    bf_print_str("bf_renorm", r);
    l = r->len;
    /* strip the most significant zero limbs */
    while (l > 0 && r->tab[l - 1] == 0)
        l--;
    if (l == 0) {
        /* zero */
        r->expn = BF_EXP_ZERO;
        bf_resize(r, 0); /* cannot fail */
        ret = 0;
    } else {
        r->expn -= (r->len - l) * LIMB_BITS;
        /* shift to have the MSB set to '1' */
        v = r->tab[l - 1];
        shift = clz(v);
        if (shift != 0) {
            /* left shift the whole number by 'shift' bits */
            v = 0;
            for(i = 0; i < l; i++) {
                a = r->tab[i];
                r->tab[i] = (a << shift) | (v >> (LIMB_BITS - shift));
                v = a;
            }
            r->expn -= shift;
        }
        ret = __bf_round(r, prec1, flags, l);
    }
    //    bf_print_str("r_final", r);
    return ret;
}
625 
/* return true if rounding can be done at precision 'prec' assuming
   the exact result r is such that |r-a| <= 2^(EXP(a)-k): i.e. the
   error cannot flip the rounding decision because the bits after the
   rounding position do not form an ambiguous pattern. */
/* XXX: check the case where the exponent would be incremented by the
   rounding */
int bf_can_round(const bf_t *a, slimb_t prec, bf_rnd_t rnd_mode, slimb_t k)
{
    BOOL is_rndn;
    slimb_t bit_pos, n;
    limb_t bit;

    if (a->expn == BF_EXP_INF || a->expn == BF_EXP_NAN)
        return FALSE;
    if (rnd_mode == BF_RNDF) {
        /* faithful rounding only needs the error below the rounding
           position */
        return (k >= (prec + 1));
    }
    if (a->expn == BF_EXP_ZERO)
        return FALSE;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
               rnd_mode == BF_RNDNU);
    if (k < (prec + 2))
        return FALSE;
    bit_pos = a->len * LIMB_BITS - 1 - prec;
    n = k - prec;
    /* bit pattern for RNDN or RNDNA: 0111.. or 1000...
       for other rounding modes: 000... or 111...
    */
    bit = get_bit(a->tab, a->len, bit_pos);
    bit_pos--;
    n--;
    /* after this xor, the ambiguous pattern is a run of 'bit' values */
    bit ^= is_rndn;
    /* XXX: slow, but a few iterations on average */
    while (n != 0) {
        if (get_bit(a->tab, a->len, bit_pos) != bit)
            return TRUE;
        bit_pos--;
        n--;
    }
    return FALSE;
}
665 
666 /* Cannot fail with BF_ST_MEM_ERROR. */
bf_round(bf_t * r,limb_t prec,bf_flags_t flags)667 int bf_round(bf_t *r, limb_t prec, bf_flags_t flags)
668 {
669     if (r->len == 0)
670         return 0;
671     return __bf_round(r, prec, flags, r->len);
672 }
673 
674 /* for debugging */
dump_limbs(const char * str,const limb_t * tab,limb_t n)675 static __maybe_unused void dump_limbs(const char *str, const limb_t *tab, limb_t n)
676 {
677     limb_t i;
678     printf("%s: len=%" PRId_LIMB "\n", str, n);
679     for(i = 0; i < n; i++) {
680         printf("%" PRId_LIMB ": " FMT_LIMB "\n",
681                i, tab[i]);
682     }
683 }
684 
/* for debugging: print the 'n' limb number 'tab' in hexadecimal,
   most significant limb first, limbs separated by '_' */
void mp_print_str(const char *str, const limb_t *tab, limb_t n)
{
    slimb_t k;

    printf("%s= 0x", str);
    for(k = n - 1; k >= 0; k--) {
        if (k != (n - 1))
            putchar('_');
        printf(FMT_LIMB, tab[k]);
    }
    putchar('\n');
}
696 
mp_print_str_h(const char * str,const limb_t * tab,limb_t n,limb_t high)697 static __maybe_unused void mp_print_str_h(const char *str,
698                                           const limb_t *tab, limb_t n,
699                                           limb_t high)
700 {
701     slimb_t i;
702     printf("%s= 0x", str);
703     printf(FMT_LIMB, high);
704     for(i = n - 1; i >= 0; i--) {
705         printf("_");
706         printf(FMT_LIMB, tab[i]);
707     }
708     printf("\n");
709 }
710 
711 /* for debugging */
bf_print_str(const char * str,const bf_t * a)712 void bf_print_str(const char *str, const bf_t *a)
713 {
714     slimb_t i;
715     printf("%s=", str);
716 
717     if (a->expn == BF_EXP_NAN) {
718         printf("NaN");
719     } else {
720         if (a->sign)
721             putchar('-');
722         if (a->expn == BF_EXP_ZERO) {
723             putchar('0');
724         } else if (a->expn == BF_EXP_INF) {
725             printf("Inf");
726         } else {
727             printf("0x0.");
728             for(i = a->len - 1; i >= 0; i--)
729                 printf(FMT_LIMB, a->tab[i]);
730             printf("p%" PRId_LIMB, a->expn);
731         }
732     }
733     printf("\n");
734 }
735 
/* compare the absolute value of 'a' and 'b'. Return < 0 if a < b, 0
   if a = b and > 0 otherwise. Both numbers are assumed normalized
   (MSB of the top limb set), so comparing exponents first is valid. */
int bf_cmpu(const bf_t *a, const bf_t *b)
{
    slimb_t i;
    limb_t len, v1, v2;

    if (a->expn != b->expn) {
        if (a->expn < b->expn)
            return -1;
        else
            return 1;
    }
    /* same exponent: compare limbs from the most significant one; the
       shorter operand is virtually padded with low zero limbs (the
       out-of-range unsigned index makes get_limbz return 0) */
    len = bf_max(a->len, b->len);
    for(i = len - 1; i >= 0; i--) {
        v1 = get_limbz(a, a->len - len + i);
        v2 = get_limbz(b, b->len - len + i);
        if (v1 != v2) {
            if (v1 < v2)
                return -1;
            else
                return 1;
        }
    }
    return 0;
}
762 
763 /* Full order: -0 < 0, NaN == NaN and NaN is larger than all other numbers */
bf_cmp_full(const bf_t * a,const bf_t * b)764 int bf_cmp_full(const bf_t *a, const bf_t *b)
765 {
766     int res;
767 
768     if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
769         if (a->expn == b->expn)
770             res = 0;
771         else if (a->expn == BF_EXP_NAN)
772             res = 1;
773         else
774             res = -1;
775     } else if (a->sign != b->sign) {
776         res = 1 - 2 * a->sign;
777     } else {
778         res = bf_cmpu(a, b);
779         if (a->sign)
780             res = -res;
781     }
782     return res;
783 }
784 
785 #define BF_CMP_EQ 1
786 #define BF_CMP_LT 2
787 #define BF_CMP_LE 3
788 
bf_cmp(const bf_t * a,const bf_t * b,int op)789 static int bf_cmp(const bf_t *a, const bf_t *b, int op)
790 {
791     BOOL is_both_zero;
792     int res;
793 
794     if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN)
795         return 0;
796     if (a->sign != b->sign) {
797         is_both_zero = (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_ZERO);
798         if (is_both_zero) {
799             return op & BF_CMP_EQ;
800         } else if (op & BF_CMP_LT) {
801             return a->sign;
802         } else {
803             return FALSE;
804         }
805     } else {
806         res = bf_cmpu(a, b);
807         if (res == 0) {
808             return op & BF_CMP_EQ;
809         } else if (op & BF_CMP_LT) {
810             return (res < 0) ^ a->sign;
811         } else {
812             return FALSE;
813         }
814     }
815 }
816 
bf_cmp_eq(const bf_t * a,const bf_t * b)817 int bf_cmp_eq(const bf_t *a, const bf_t *b)
818 {
819     return bf_cmp(a, b, BF_CMP_EQ);
820 }
821 
bf_cmp_le(const bf_t * a,const bf_t * b)822 int bf_cmp_le(const bf_t *a, const bf_t *b)
823 {
824     return bf_cmp(a, b, BF_CMP_LE);
825 }
826 
bf_cmp_lt(const bf_t * a,const bf_t * b)827 int bf_cmp_lt(const bf_t *a, const bf_t *b)
828 {
829     return bf_cmp(a, b, BF_CMP_LT);
830 }
831 
/* Compute the number of bits 'n' matching the pattern:
   a= X1000..0
   b= X0111..1

   When computing a-b, the result will have at least n leading zero
   bits.

   Precondition: a > b and a.expn - b.expn = 0 or 1
*/
static limb_t count_cancelled_bits(const bf_t *a, const bf_t *b)
{
    slimb_t bit_offset, b_offset, n;
    int p, p1;
    limb_t v1, v2, mask;

    /* 'bit_offset' scans 'a' limb by limb from its most significant
       bit; adding 'b_offset' gives the bit position in 'b' aligned on
       the same weight (it accounts for the exponent difference and the
       length difference) */
    bit_offset = a->len * LIMB_BITS - 1;
    b_offset = (b->len - a->len) * LIMB_BITS - (LIMB_BITS - 1) +
        a->expn - b->expn;
    n = 0;

    /* first search the equals bits */
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != v2)
            break;
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
    /* find the position of the first different bit */
    p = clz(v1 ^ v2) + 1;
    n += p;
    /* then search for '0' in a and '1' in b */
    p = LIMB_BITS - p;
    if (p > 0) {
        /* search in the trailing p bits of v1 and v2 */
        mask = limb_mask(0, p - 1);
        p1 = bf_min(clz(v1 & mask), clz((~v2) & mask)) - (LIMB_BITS - p);
        n += p1;
        if (p1 != p)
            goto done;
    }
    /* the run continues into the following limbs: count whole limbs
       where a is all-zero and b is all-one */
    bit_offset -= LIMB_BITS;
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != 0 || v2 != -1) {
            /* different: count the matching bits */
            p1 = bf_min(clz(v1), clz(~v2));
            n += p1;
            break;
        }
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
 done:
    return n;
}
892 
/* Compute r = a + b (b_neg = 0) or r = a - b (b_neg = 1), rounded to
   'prec' bits according to 'flags'. Handles NaN/infinity/zero
   specially; otherwise the operands are aligned on their exponents
   and added/subtracted limb by limb, keeping only enough limbs for
   the requested precision plus the cancelled bits and two rounding
   bits. Returns the status flags (BF_ST_MEM_ERROR on allocation
   failure, with 'r' set to NaN). */
static int bf_add_internal(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                           bf_flags_t flags, int b_neg)
{
    const bf_t *tmp;
    int is_sub, ret, cmp_res, a_sign, b_sign;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    is_sub = a_sign ^ b_sign;
    cmp_res = bf_cmpu(a, b);
    if (cmp_res < 0) {
        /* swap so that abs(a) >= abs(b) */
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* zero result: the sign depends on the rounding direction */
        bf_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bf_set_nan(r);
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, a_sign);
            }
        } else {
            /* at least one zero and not subtract */
            bf_set(r, a);
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_bit_offset, i, cancelled_bits;
        limb_t carry, v1, v2, u, r_len, carry1, precl, tot_len, z, sub_mask;

        r->sign = a_sign;
        r->expn = a->expn;
        /* exponent difference: how far 'b' must be shifted right */
        d = a->expn - b->expn;
        /* must add more precision for the leading cancelled bits in
           subtraction */
        if (is_sub) {
            if (d <= 1)
                cancelled_bits = count_cancelled_bits(a, b);
            else
                cancelled_bits = 1;
        } else {
            cancelled_bits = 0;
        }

        /* add two extra bits for rounding */
        precl = (cancelled_bits + prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
        tot_len = bf_max(a->len, b->len + (d + LIMB_BITS - 1) / LIMB_BITS);
        r_len = bf_min(precl, tot_len);
        if (bf_resize(r, r_len))
            goto fail;
        /* offsets mapping result limb 'i' to the source operands */
        a_offset = a->len - r_len;
        b_bit_offset = (b->len - r_len) * LIMB_BITS + d;

        /* compute the bits before for the rounding */
        carry = is_sub;
        z = 0;
        sub_mask = -is_sub; /* all-ones when subtracting (one's complement) */
        i = r_len - tot_len;
        while (i < 0) {
            slimb_t ap, bp;
            BOOL inflag;

            ap = a_offset + i;
            bp = b_bit_offset + i * LIMB_BITS;
            inflag = FALSE;
            if (ap >= 0 && ap < a->len) {
                v1 = a->tab[ap];
                inflag = TRUE;
            } else {
                v1 = 0;
            }
            if (bp + LIMB_BITS > 0 && bp < (slimb_t)(b->len * LIMB_BITS)) {
                v2 = get_bits(b->tab, b->len, bp);
                inflag = TRUE;
            } else {
                v2 = 0;
            }
            if (!inflag) {
                /* outside 'a' and 'b': go directly to the next value
                   inside a or b so that the running time does not
                   depend on the exponent difference */
                i = 0;
                if (ap < 0)
                    i = bf_min(i, -a_offset);
                /* b_bit_offset + i * LIMB_BITS + LIMB_BITS >= 1
                   equivalent to
                   i >= ceil(-b_bit_offset + 1 - LIMB_BITS) / LIMB_BITS)
                */
                if (bp + LIMB_BITS <= 0)
                    i = bf_min(i, (-b_bit_offset) >> LIMB_LOG2_BITS);
            } else {
                i++;
            }
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            z |= u; /* remember whether any discarded bit was non zero */
        }
        /* and the result */
        for(i = 0; i < r_len; i++) {
            v1 = get_limbz(a, a_offset + i);
            v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS);
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            r->tab[i] = u;
        }
        /* set the extra bits for the rounding */
        r->tab[0] |= (z != 0);

        /* carry is only possible in add case */
        if (!is_sub && carry) {
            if (bf_resize(r, r_len + 1))
                goto fail;
            r->tab[r_len] = 1;
            r->expn += LIMB_BITS;
        }
    renorm:
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
1036 
/* r = a + b rounded to 'prec' bits according to 'flags'; thin wrapper
   over bf_add_internal() with is_sub = 0 */
static int __bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                     bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 0);
}
1042 
/* r = a - b rounded to 'prec' bits according to 'flags'; thin wrapper
   over bf_add_internal() with is_sub = 1 */
static int __bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                     bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 1);
}
1048 
/* res[] = op1[] + op2[] + carry over n limbs. 'carry' must be 0 or 1.
   Returns the outgoing carry (0 or 1). res may alias op1 or op2. */
limb_t mp_add(limb_t *res, const limb_t *op1, const limb_t *op2,
              limb_t n, limb_t carry)
{
    slimb_t j;

    for(j = 0; j < n; j++) {
        limb_t x, sum, c1;
        x = op1[j];
        sum = x + op2[j];
        c1 = sum < x;               /* carry out of the limb addition */
        sum += carry;
        carry = (sum < carry) | c1; /* at most one of the two carries is set */
        res[j] = sum;
    }
    return carry;
}
1066 
/* tab[] += b in place over n limbs; stops as soon as the carry dies
   out. Returns the final carry (0 or 1 once b has been absorbed). */
limb_t mp_add_ui(limb_t *tab, limb_t b, size_t n)
{
    size_t j = 0;
    limb_t c = b;

    while (j < n && c != 0) {
        limb_t s = tab[j] + c;
        c = (s < c);
        tab[j] = s;
        j++;
    }
    return c;
}
1082 
/* res[] = op1[] - op2[] - carry over n limbs. 'carry' is the incoming
   borrow and must be 0 or 1. Returns the outgoing borrow (0 or 1).
   res may alias op1 or op2. */
limb_t mp_sub(limb_t *res, const limb_t *op1, const limb_t *op2,
              mp_size_t n, limb_t carry)
{
    /* use mp_size_t for the index: the previous 'int' index could
       truncate when 'n' exceeds INT_MAX (and matches mp_sub_ui) */
    mp_size_t i;
    limb_t k, a, v, k1;

    k = carry;
    for(i = 0; i < n; i++) {
        v = op1[i];
        a = v - op2[i];
        k1 = a > v;        /* borrow out of the limb subtraction */
        v = a - k;
        k = (v > a) | k1;  /* borrow from subtracting the incoming borrow */
        res[i] = v;
    }
    return k;
}
1100 
1101 /* compute 0 - op2 */
mp_neg(limb_t * res,const limb_t * op2,mp_size_t n,limb_t carry)1102 static limb_t mp_neg(limb_t *res, const limb_t *op2, mp_size_t n, limb_t carry)
1103 {
1104     int i;
1105     limb_t k, a, v, k1;
1106 
1107     k = carry;
1108     for(i=0;i<n;i++) {
1109         v = 0;
1110         a = v - op2[i];
1111         k1 = a > v;
1112         v = a - k;
1113         k = (v > a) | k1;
1114         res[i] = v;
1115     }
1116     return k;
1117 }
1118 
/* tab[] -= b in place over n limbs; stops as soon as the borrow dies
   out. Returns the final borrow (0 or 1). */
limb_t mp_sub_ui(limb_t *tab, limb_t b, mp_size_t n)
{
    mp_size_t j;
    limb_t borrow = b;

    for(j = 0; j < n && borrow != 0; j++) {
        limb_t old = tab[j];
        tab[j] = old - borrow;
        borrow = (tab[j] > old);
    }
    return borrow;
}
1135 
1136 /* r = (a + high*B^n) >> shift. Return the remainder r (0 <= r < 2^shift).
1137    1 <= shift <= LIMB_BITS - 1 */
mp_shr(limb_t * tab_r,const limb_t * tab,mp_size_t n,int shift,limb_t high)1138 static limb_t mp_shr(limb_t *tab_r, const limb_t *tab, mp_size_t n,
1139                      int shift, limb_t high)
1140 {
1141     mp_size_t i;
1142     limb_t l, a;
1143 
1144     assert(shift >= 1 && shift < LIMB_BITS);
1145     l = high;
1146     for(i = n - 1; i >= 0; i--) {
1147         a = tab[i];
1148         tab_r[i] = (a >> shift) | (l << (LIMB_BITS - shift));
1149         l = a;
1150     }
1151     return l & (((limb_t)1 << shift) - 1);
1152 }
1153 
1154 /* tabr[] = taba[] * b + l. Return the high carry */
mp_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b,limb_t l)1155 static limb_t mp_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1156                       limb_t b, limb_t l)
1157 {
1158     limb_t i;
1159     dlimb_t t;
1160 
1161     for(i = 0; i < n; i++) {
1162         t = (dlimb_t)taba[i] * (dlimb_t)b + l;
1163         tabr[i] = t;
1164         l = t >> LIMB_BITS;
1165     }
1166     return l;
1167 }
1168 
1169 /* tabr[] += taba[] * b, return the high word. */
mp_add_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b)1170 static limb_t mp_add_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1171                           limb_t b)
1172 {
1173     limb_t i, l;
1174     dlimb_t t;
1175 
1176     l = 0;
1177     for(i = 0; i < n; i++) {
1178         t = (dlimb_t)taba[i] * (dlimb_t)b + l + tabr[i];
1179         tabr[i] = t;
1180         l = t >> LIMB_BITS;
1181     }
1182     return l;
1183 }
1184 
1185 /* size of the result : op1_size + op2_size. */
mp_mul_basecase(limb_t * result,const limb_t * op1,limb_t op1_size,const limb_t * op2,limb_t op2_size)1186 static void mp_mul_basecase(limb_t *result,
1187                             const limb_t *op1, limb_t op1_size,
1188                             const limb_t *op2, limb_t op2_size)
1189 {
1190     limb_t i, r;
1191 
1192     result[op1_size] = mp_mul1(result, op1, op1_size, op2[0], 0);
1193     for(i=1;i<op2_size;i++) {
1194         r = mp_add_mul1(result + i, op1, op1_size, op2[i]);
1195         result[i + op1_size] = r;
1196     }
1197 }
1198 
/* result[] = op1[] * op2[]; result must have op1_size + op2_size limbs.
   Uses FFT multiplication above FFT_MUL_THRESHOLD, schoolbook below.
   return 0 if OK, -1 if memory error */
/* XXX: change API so that result can be allocated */
int mp_mul(bf_context_t *s, limb_t *result,
           const limb_t *op1, limb_t op1_size,
           const limb_t *op2, limb_t op2_size)
{
#ifdef USE_FFT_MUL
    if (unlikely(bf_min(op1_size, op2_size) >= FFT_MUL_THRESHOLD)) {
        /* wrap 'result' in a fake bf_t so fft_mul writes into it
           directly (FFT_MUL_R_NORESIZE) */
        bf_t r_s, *r = &r_s;
        r->tab = result;
        /* XXX: optimize memory usage in API */
        /* NOTE(review): const is cast away here; fft_mul presumably
           only reads the operands — confirm against its definition */
        if (fft_mul(s, r, (limb_t *)op1, op1_size,
                    (limb_t *)op2, op2_size, FFT_MUL_R_NORESIZE))
            return -1;
    } else
#endif
    {
        mp_mul_basecase(result, op1, op1_size, op2, op2_size);
    }
    return 0;
}
1220 
1221 /* tabr[] -= taba[] * b. Return the value to substract to the high
1222    word. */
mp_sub_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b)1223 static limb_t mp_sub_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1224                           limb_t b)
1225 {
1226     limb_t i, l;
1227     dlimb_t t;
1228 
1229     l = 0;
1230     for(i = 0; i < n; i++) {
1231         t = tabr[i] - (dlimb_t)taba[i] * (dlimb_t)b - l;
1232         tabr[i] = t;
1233         l = -(t >> LIMB_BITS);
1234     }
1235     return l;
1236 }
1237 
1238 /* WARNING: d must be >= 2^(LIMB_BITS-1) */
udiv1norm_init(limb_t d)1239 static inline limb_t udiv1norm_init(limb_t d)
1240 {
1241     limb_t a0, a1;
1242     a1 = -d - 1;
1243     a0 = -1;
1244     return (((dlimb_t)a1 << LIMB_BITS) | a0) / d;
1245 }
1246 
/* return the quotient and the remainder in '*pr' of 'a1*2^LIMB_BITS+a0
   / d' with 0 <= a1 < d. 'd' must be normalized (d >= 2^(LIMB_BITS-1))
   and 'd_inv' must come from udiv1norm_init(d). Division by invariant
   integer via multiplication by the pseudo inverse. */
static inline limb_t udiv1norm(limb_t *pr, limb_t a1, limb_t a0,
                                limb_t d, limb_t d_inv)
{
    limb_t n1m, n_adj, q, r, ah;
    dlimb_t a;
    /* n1m: all-ones mask if the top bit of a0 is set, else 0 */
    n1m = ((slimb_t)a0 >> (LIMB_BITS - 1));
    n_adj = a0 + (n1m & d);
    /* first quotient estimate from the pseudo inverse */
    a = (dlimb_t)d_inv * (a1 - n1m) + n_adj;
    q = (a >> LIMB_BITS) + a1;
    /* compute a - q * d and update q so that the remainder is
       between 0 and d - 1 */
    a = ((dlimb_t)a1 << LIMB_BITS) | a0;
    a = a - (dlimb_t)q * d - d;
    ah = a >> LIMB_BITS;       /* 0 or all-ones (i.e. -1) */
    q += 1 + ah;
    r = (limb_t)a + (ah & d);  /* add d back when the estimate was too large */
    *pr = r;
    return q;
}
1268 
1269 /* b must be >= 1 << (LIMB_BITS - 1) */
mp_div1norm(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b,limb_t r)1270 static limb_t mp_div1norm(limb_t *tabr, const limb_t *taba, limb_t n,
1271                           limb_t b, limb_t r)
1272 {
1273     slimb_t i;
1274 
1275     if (n >= UDIV1NORM_THRESHOLD) {
1276         limb_t b_inv;
1277         b_inv = udiv1norm_init(b);
1278         for(i = n - 1; i >= 0; i--) {
1279             tabr[i] = udiv1norm(&r, r, taba[i], b, b_inv);
1280         }
1281     } else {
1282         dlimb_t a1;
1283         for(i = n - 1; i >= 0; i--) {
1284             a1 = ((dlimb_t)r << LIMB_BITS) | taba[i];
1285             tabr[i] = a1 / b;
1286             r = a1 % b;
1287         }
1288     }
1289     return r;
1290 }
1291 
1292 static int mp_divnorm_large(bf_context_t *s,
1293                             limb_t *tabq, limb_t *taba, limb_t na,
1294                             const limb_t *tabb, limb_t nb);
1295 
/* base case division: divides taba[0..na-1] by tabb[0..nb-1]. tabb[nb
   - 1] must be >= 1 << (LIMB_BITS - 1). na - nb must be >= 0. 'taba'
   is modified and contains the remainder (nb limbs). tabq[0..na-nb]
   contains the quotient with tabq[na - nb] <= 1.
   Quotient limbs are estimated from the top two limbs of the running
   remainder and corrected by an add-back loop (Knuth-style schoolbook
   division). Dispatches to mp_divnorm_large() above the threshold. */
static int mp_divnorm(bf_context_t *s, limb_t *tabq, limb_t *taba, limb_t na,
                      const limb_t *tabb, limb_t nb)
{
    limb_t r, a, c, q, v, b1, b1_inv, n, dummy_r;
    slimb_t i, j;

    b1 = tabb[nb - 1]; /* normalized leading divisor limb */
    if (nb == 1) {
        /* single limb divisor: delegate to the 1-limb routine */
        taba[0] = mp_div1norm(tabq, taba, na, b1, 0);
        return 0;
    }
    n = na - nb;
    if (bf_min(n, nb) >= DIVNORM_LARGE_THRESHOLD) {
        return mp_divnorm_large(s, tabq, taba, na, tabb, nb);
    }

    /* precompute the pseudo inverse of b1 when there are enough
       quotient limbs to amortize it */
    if (n >= UDIV1NORM_THRESHOLD)
        b1_inv = udiv1norm_init(b1);
    else
        b1_inv = 0;

    /* first iteration: the quotient is only 0 or 1 */
    q = 1;
    for(j = nb - 1; j >= 0; j--) {
        if (taba[n + j] != tabb[j]) {
            if (taba[n + j] < tabb[j])
                q = 0;
            break;
        }
    }
    tabq[n] = q;
    if (q) {
        mp_sub(taba + n, taba + n, tabb, nb, 0);
    }

    for(i = n - 1; i >= 0; i--) {
        /* estimate the quotient limb from the top two remainder limbs */
        if (unlikely(taba[i + nb] >= b1)) {
            q = -1; /* saturate the estimate; fixed by the add-back loop */
        } else if (b1_inv) {
            q = udiv1norm(&dummy_r, taba[i + nb], taba[i + nb - 1], b1, b1_inv);
        } else {
            dlimb_t al;
            al = ((dlimb_t)taba[i + nb] << LIMB_BITS) | taba[i + nb - 1];
            q = al / b1;
            r = al % b1;
        }
        /* subtract q * b from the remainder window */
        r = mp_sub_mul1(taba + i, tabb, nb, q);

        v = taba[i + nb];
        a = v - r;
        c = (a > v);
        taba[i + nb] = a;

        if (c != 0) {
            /* negative result: the estimate q was too large; add b
               back until the remainder becomes non negative */
            for(;;) {
                q--;
                c = mp_add(taba + i, taba + i, tabb, nb, 0);
                /* propagate carry and test if positive result */
                if (c != 0) {
                    if (++taba[i + nb] == 0) {
                        break;
                    }
                }
            }
        }
        tabq[i] = q;
    }
    return 0;
}
1370 
/* compute r=B^(2*n)/a such as a*r < B^(2*n) < a*r + 2 with n >= 1. 'a'
   has n limbs with a[n-1] >= B/2 and 'r' has n+1 limbs with r[n] = 1.

   See Modern Computer Arithmetic by Richard P. Brent and Paul
   Zimmermann, algorithm 3.5 (divide and conquer reciprocal).
   Returns 0 if OK, -1 on memory error. */
int mp_recip(bf_context_t *s, limb_t *tabr, const limb_t *taba, limb_t n)
{
    mp_size_t l, h, k, i;
    limb_t *tabxh, *tabt, c, *tabu;

    if (n <= 2) {
        /* base case: return ceil(B^(2*n)/a) - 1 via a direct division
           of B^(2*n) by a */
        /* XXX: could avoid allocation */
        tabu = bf_malloc(s, sizeof(limb_t) * (2 * n + 1));
        tabt = bf_malloc(s, sizeof(limb_t) * (n + 2));
        if (!tabt || !tabu)
            goto fail;
        /* tabu = B^(2*n) */
        for(i = 0; i < 2 * n; i++)
            tabu[i] = 0;
        tabu[2 * n] = 1;
        if (mp_divnorm(s, tabt, tabu, 2 * n + 1, taba, n))
            goto fail;
        for(i = 0; i < n + 1; i++)
            tabr[i] = tabt[i];
        if (mp_scan_nz(tabu, n) == 0) {
            /* exact division: only happens for a=B^n/2 */
            mp_sub_ui(tabr, 1, n + 1);
        }
    } else {
        /* recursive case: compute the reciprocal of the high h limbs,
           then refine with one Newton-like iteration */
        l = (n - 1) / 2;
        h = n - l;
        /* n=2p  -> l=p-1, h = p + 1, k = p + 3
           n=2p+1-> l=p,  h = p + 1; k = p + 2
        */
        tabt = bf_malloc(s, sizeof(limb_t) * (n + h + 1));
        tabu = bf_malloc(s, sizeof(limb_t) * (n + 2 * h - l + 2));
        if (!tabt || !tabu)
            goto fail;
        /* the partial reciprocal Xh is built in place in the high part
           of the result */
        tabxh = tabr + l;
        if (mp_recip(s, tabxh, taba + l, h))
            goto fail;
        if (mp_mul(s, tabt, taba, n, tabxh, h + 1)) /* n + h + 1 limbs */
            goto fail;
        /* adjust Xh downwards until A * Xh <= B^(n+h) */
        while (tabt[n + h] != 0) {
            mp_sub_ui(tabxh, 1, h + 1);
            c = mp_sub(tabt, tabt, taba, n, 0);
            mp_sub_ui(tabt + n, c, h + 1);
        }
        /* T = B^(n+h) - T */
        mp_neg(tabt, tabt, n + h + 1, 0);
        tabt[n + h]++;
        if (mp_mul(s, tabu, tabt + l, n + h + 1 - l, tabxh, h + 1))
            goto fail;
        /* n + 2*h - l + 2 limbs */
        k = 2 * h - l;
        /* combine the correction term with Xh to form the full result */
        for(i = 0; i < l; i++)
            tabr[i] = tabu[i + k];
        mp_add(tabr + l, tabr + l, tabu + 2 * h, h, 0);
    }
    bf_free(s, tabt);
    bf_free(s, tabu);
    return 0;
 fail:
    bf_free(s, tabt);
    bf_free(s, tabu);
    return -1;
}
1438 
1439 /* return -1, 0 or 1 */
mp_cmp(const limb_t * taba,const limb_t * tabb,mp_size_t n)1440 static int mp_cmp(const limb_t *taba, const limb_t *tabb, mp_size_t n)
1441 {
1442     mp_size_t i;
1443     for(i = n - 1; i >= 0; i--) {
1444         if (taba[i] != tabb[i]) {
1445             if (taba[i] < tabb[i])
1446                 return -1;
1447             else
1448                 return 1;
1449         }
1450     }
1451     return 0;
1452 }
1453 
1454 //#define DEBUG_DIVNORM_LARGE
1455 //#define DEBUG_DIVNORM_LARGE2
1456 
/* subquadratic divnorm: same contract as mp_divnorm(). Computes an
   approximate reciprocal of the divisor, multiplies to get an
   approximate quotient, then fixes it up (Barrett-style division). */
static int mp_divnorm_large(bf_context_t *s,
                            limb_t *tabq, limb_t *taba, limb_t na,
                            const limb_t *tabb, limb_t nb)
{
    limb_t *tabb_inv, nq, *tabt, i, n;
    nq = na - nb; /* number of quotient limbs (minus the top one) */
#ifdef DEBUG_DIVNORM_LARGE
    printf("na=%d nb=%d nq=%d\n", (int)na, (int)nb, (int)nq);
    mp_print_str("a", taba, na);
    mp_print_str("b", tabb, nb);
#endif
    assert(nq >= 1);
    /* n = number of limbs used for the reciprocal */
    n = nq;
    if (nq < nb)
        n++;
    tabb_inv = bf_malloc(s, sizeof(limb_t) * (n + 1));
    tabt = bf_malloc(s, sizeof(limb_t) * 2 * (n + 1));
    if (!tabb_inv || !tabt)
        goto fail;

    if (n >= nb) {
        /* pad the divisor with low zero limbs up to n limbs */
        for(i = 0; i < n - nb; i++)
            tabt[i] = 0;
        for(i = 0; i < nb; i++)
            tabt[i + n - nb] = tabb[i];
    } else {
        /* truncate B: need to increment it so that the approximate
           inverse is smaller than the exact inverse */
        for(i = 0; i < n; i++)
            tabt[i] = tabb[i + nb - n];
        if (mp_add_ui(tabt, 1, n)) {
            /* tabt = B^n : tabb_inv = B^n */
            memset(tabb_inv, 0, n * sizeof(limb_t));
            tabb_inv[n] = 1;
            goto recip_done;
        }
    }
    if (mp_recip(s, tabb_inv, tabt, n))
        goto fail;
 recip_done:
    /* Q=A*B^-1 : approximate quotient from the high limbs of A */
    if (mp_mul(s, tabt, tabb_inv, n + 1, taba + na - (n + 1), n + 1))
        goto fail;

    /* keep the top nq + 1 limbs of the product as the quotient */
    for(i = 0; i < nq + 1; i++)
        tabq[i] = tabt[i + 2 * (n + 1) - (nq + 1)];
#ifdef DEBUG_DIVNORM_LARGE
    mp_print_str("q", tabq, nq + 1);
#endif

    bf_free(s, tabt);
    bf_free(s, tabb_inv);
    tabb_inv = NULL;

    /* R=A-B*Q */
    tabt = bf_malloc(s, sizeof(limb_t) * (na + 1));
    if (!tabt)
        goto fail;
    if (mp_mul(s, tabt, tabq, nq + 1, tabb, nb))
        goto fail;
    /* we add one more limb for the result */
    mp_sub(taba, taba, tabt, nb + 1, 0);
    bf_free(s, tabt);
    /* the approximated quotient is smaller than the exact one,
       hence we may have to increment it */
#ifdef DEBUG_DIVNORM_LARGE2
    int cnt = 0;
    static int cnt_max;
#endif
    for(;;) {
        /* done when the remainder is strictly smaller than b */
        if (taba[nb] == 0 && mp_cmp(taba, tabb, nb) < 0)
            break;
        taba[nb] -= mp_sub(taba, taba, tabb, nb, 0);
        mp_add_ui(tabq, 1, nq + 1);
#ifdef DEBUG_DIVNORM_LARGE2
        cnt++;
#endif
    }
#ifdef DEBUG_DIVNORM_LARGE2
    if (cnt > cnt_max) {
        cnt_max = cnt;
        printf("\ncnt=%d nq=%d nb=%d\n", cnt_max, (int)nq, (int)nb);
    }
#endif
    return 0;
 fail:
    bf_free(s, tabb_inv);
    bf_free(s, tabt);
    return -1;
}
1548 
/* r = a * b rounded to 'prec' bits according to 'flags'. r may alias
   a or b. Returns the status flags. */
int bf_mul(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    int ret, r_sign;

    /* ensure 'a' is the operand with the most limbs */
    if (a->len < b->len) {
        const bf_t *tmp = a;
        a = b;
        b = tmp;
    }
    r_sign = a->sign ^ b->sign;
    /* here b->len <= a->len */
    if (b->len == 0) {
        /* at least one operand is zero, Inf or NaN */
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            ret = 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
            /* Inf * 0 is an invalid operation */
            if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
                (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, r_sign);
                ret = 0;
            }
        } else {
            bf_set_zero(r, r_sign);
            ret = 0;
        }
    } else {
        bf_t tmp, *r1 = NULL;
        limb_t a_len, b_len, precl;
        limb_t *a_tab, *b_tab;

        a_len = a->len;
        b_len = b->len;

        if ((flags & BF_RND_MASK) == BF_RNDF) {
            /* faithful rounding does not require using the full inputs */
            precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
            a_len = bf_min(a_len, precl);
            b_len = bf_min(b_len, precl);
        }
        /* only the most significant a_len/b_len limbs are used */
        a_tab = a->tab + a->len - a_len;
        b_tab = b->tab + b->len - b_len;

#ifdef USE_FFT_MUL
        if (b_len >= FFT_MUL_THRESHOLD) {
            int mul_flags = 0;
            /* fft_mul handles aliasing itself when told about it */
            if (r == a)
                mul_flags |= FFT_MUL_R_OVERLAP_A;
            if (r == b)
                mul_flags |= FFT_MUL_R_OVERLAP_B;
            if (fft_mul(r->ctx, r, a_tab, a_len, b_tab, b_len, mul_flags))
                goto fail;
        } else
#endif
        {
            /* the basecase product cannot be computed in place: use a
               temporary when r aliases an operand, moved back at 'done' */
            if (r == a || r == b) {
                bf_init(r->ctx, &tmp);
                r1 = r;
                r = &tmp;
            }
            if (bf_resize(r, a_len + b_len)) {
            fail:
                bf_set_nan(r);
                ret = BF_ST_MEM_ERROR;
                goto done;
            }
            mp_mul_basecase(r->tab, a_tab, a_len, b_tab, b_len);
        }
        r->sign = r_sign;
        r->expn = a->expn + b->expn;
        ret = bf_normalize_and_round(r, prec, flags);
    done:
        if (r == &tmp)
            bf_move(r1, &tmp);
    }
    return ret;
}
1629 
1630 /* multiply 'r' by 2^e */
bf_mul_2exp(bf_t * r,slimb_t e,limb_t prec,bf_flags_t flags)1631 int bf_mul_2exp(bf_t *r, slimb_t e, limb_t prec, bf_flags_t flags)
1632 {
1633     slimb_t e_max;
1634     if (r->len == 0)
1635         return 0;
1636     e_max = ((limb_t)1 << BF_EXP_BITS_MAX) - 1;
1637     e = bf_max(e, -e_max);
1638     e = bf_min(e, e_max);
1639     r->expn += e;
1640     return __bf_round(r, prec, flags, r->len);
1641 }
1642 
1643 /* Return e such as a=m*2^e with m odd integer. return 0 if a is zero,
1644    Infinite or Nan. */
bf_get_exp_min(const bf_t * a)1645 slimb_t bf_get_exp_min(const bf_t *a)
1646 {
1647     slimb_t i;
1648     limb_t v;
1649     int k;
1650 
1651     for(i = 0; i < a->len; i++) {
1652         v = a->tab[i];
1653         if (v != 0) {
1654             k = ctz(v);
1655             return a->expn - (a->len - i) * LIMB_BITS + k;
1656         }
1657     }
1658     return 0;
1659 }
1660 
/* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
   integer defined as floor(a/b) and r = a - q * b. */
static void bf_tdivremu(bf_t *q, bf_t *r,
                        const bf_t *a, const bf_t *b)
{
    if (bf_cmpu(a, b) < 0) {
        /* a < b: quotient is 0 and the remainder is a itself */
        bf_set_ui(q, 0);
        bf_set(r, a);
    } else {
        /* compute the quotient with just enough precision to hold its
           integer part, truncate it, then r = a - q * b exactly */
        bf_div(q, a, b, bf_max(a->expn - b->expn + 1, 2), BF_RNDZ);
        bf_rint(q, BF_RNDZ);
        bf_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
        bf_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
    }
}
1676 
/* r = a / b rounded to 'prec' bits according to 'flags'. Returns the
   status flags. NaN/Inf/zero operands follow the usual IEEE-754
   semantics. Fix: the temporary dividend 'taba' was leaked on the
   bf_resize()/mp_divnorm() error paths (it is now freed via 'fail1'). */
static int __bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    int ret, r_sign;
    limb_t n, nb, precl, na, *taba;
    slimb_t d;

    r_sign = a->sign ^ b->sign;
    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            /* Inf / Inf */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else if (a->expn == BF_EXP_INF) {
            bf_set_inf(r, r_sign);
            return 0;
        } else {
            /* finite / Inf */
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            /* 0 / 0 */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (b->expn == BF_EXP_ZERO) {
        bf_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* number of limbs of the quotient (2 extra bits for rounding) */
    precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
    nb = b->len;
    n = bf_max(a->len, precl);

    /* scale the dividend to na = n + nb limbs so that the quotient has
       n + 1 limbs */
    na = n + nb;
    taba = bf_malloc(s, (na + 1) * sizeof(limb_t));
    if (!taba)
        goto fail;
    d = na - a->len;
    memset(taba, 0, d * sizeof(limb_t));
    memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
    if (bf_resize(r, n + 1))
        goto fail1;
    if (mp_divnorm(s, r->tab, taba, na, b->tab, nb))
        goto fail1;

    /* non zero remainder: set the sticky bit for correct rounding */
    if (mp_scan_nz(taba, nb))
        r->tab[0] |= 1;
    bf_free(r->ctx, taba);
    r->expn = a->expn - b->expn + LIMB_BITS;
    r->sign = r_sign;
    ret = bf_normalize_and_round(r, prec, flags);
    return ret;
 fail1:
    bf_free(s, taba);
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
1746 
/* division and remainder.

   rnd_mode is the rounding mode for the quotient. The additional
   rounding mode BF_RND_EUCLIDIAN is supported.

   'q' is an integer. 'r' is rounded with prec and flags (prec can be
   BF_PREC_INF).
*/
int bf_divrem(bf_t *q, bf_t *r, const bf_t *a, const bf_t *b,
              limb_t prec, bf_flags_t flags, int rnd_mode)
{
    bf_t a1_s, *a1 = &a1_s;
    bf_t b1_s, *b1 = &b1_s;
    int q_sign, ret;
    BOOL is_ceil, is_rndn;

    /* in-place operation is not supported */
    assert(q != a && q != b);
    assert(r != a && r != b);
    assert(q != r);

    if (a->len == 0 || b->len == 0) {
        /* at least one operand is zero, Inf or NaN */
        bf_set_zero(q, 0);
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            /* a is zero or b is Inf: r = a, q = 0 */
            bf_set(r, a);
            return bf_round(r, prec, flags);
        }
    }

    q_sign = a->sign ^ b->sign;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
               rnd_mode == BF_RNDNU);
    /* is_ceil: whether the magnitude of the quotient must be rounded up */
    switch(rnd_mode) {
    default:
    case BF_RNDZ:
    case BF_RNDN:
    case BF_RNDNA:
        is_ceil = FALSE;
        break;
    case BF_RNDD:
        is_ceil = q_sign;
        break;
    case BF_RNDU:
        is_ceil = q_sign ^ 1;
        break;
    case BF_DIVREM_EUCLIDIAN:
        /* euclidian division: the remainder is non negative */
        is_ceil = a->sign;
        break;
    case BF_RNDNU:
        /* XXX: unsupported yet */
        abort();
    }

    /* shallow copies of |a| and |b| sharing the digit arrays */
    a1->expn = a->expn;
    a1->tab = a->tab;
    a1->len = a->len;
    a1->sign = 0;

    b1->expn = b->expn;
    b1->tab = b->tab;
    b1->len = b->len;
    b1->sign = 0;

    /* XXX: could improve to avoid having a large 'q' */
    bf_tdivremu(q, r, a1, b1);
    if (bf_is_nan(q) || bf_is_nan(r))
        goto fail;

    if (r->len != 0) {
        if (is_rndn) {
            /* round to nearest: compare 2*r with |b| by halving b's
               exponent temporarily */
            int res;
            b1->expn--;
            res = bf_cmpu(r, b1);
            b1->expn++;
            if (res > 0 ||
                (res == 0 &&
                 (rnd_mode == BF_RNDNA ||
                  /* tie: round away (RNDNA) or to even (low bit of q) */
                  get_bit(q->tab, q->len, q->len * LIMB_BITS - q->expn)))) {
                goto do_sub_r;
            }
        } else if (is_ceil) {
        do_sub_r:
            /* increment the quotient and adjust the remainder to match */
            ret = bf_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
            ret |= bf_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
            if (ret & BF_ST_MEM_ERROR)
                goto fail;
        }
    }

    /* restore the signs of the results */
    r->sign ^= a->sign;
    q->sign = q_sign;
    return bf_round(r, prec, flags);
 fail:
    bf_set_nan(q);
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
1849 
/* r = fmod(a, b): remainder of the division with the quotient rounded
   toward zero (BF_RNDZ). Returns the status flags. */
int bf_fmod(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
            bf_flags_t flags)
{
    bf_t quot;
    int ret;

    bf_init(r->ctx, &quot);
    ret = bf_divrem(&quot, r, a, b, prec, flags, BF_RNDZ);
    bf_delete(&quot);
    return ret;
}
1861 
/* r = remainder(a, b): remainder of the division with the quotient
   rounded to nearest (BF_RNDN). Returns the status flags. */
int bf_remainder(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                 bf_flags_t flags)
{
    bf_t quot;
    int ret;

    bf_init(r->ctx, &quot);
    ret = bf_divrem(&quot, r, a, b, prec, flags, BF_RNDN);
    bf_delete(&quot);
    return ret;
}
1873 
/* get the value of 'a' as an slimb_t, dispatching to bf_get_int32() or
   bf_get_int64() depending on the limb width; 'flags' is passed
   through (e.g. BF_GET_INT_MOD) */
static inline int bf_get_limb(slimb_t *pres, const bf_t *a, int flags)
{
#if LIMB_BITS == 32
    return bf_get_int32(pres, a, flags);
#else
    return bf_get_int64(pres, a, flags);
#endif
}
1882 
/* remainder with the quotient rounded to nearest; the low bits of the
   integer quotient are stored in *pq (BF_GET_INT_MOD). Returns the
   status flags of the division. */
int bf_remquo(slimb_t *pq, bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
              bf_flags_t flags)
{
    bf_t quot;
    int ret;

    bf_init(r->ctx, &quot);
    ret = bf_divrem(&quot, r, a, b, prec, flags, BF_RNDN);
    /* reduce the quotient to a machine integer, modulo its width */
    bf_get_limb(pq, &quot, BF_GET_INT_MOD);
    bf_delete(&quot);
    return ret;
}
1895 
mul_mod(limb_t a,limb_t b,limb_t m)1896 static __maybe_unused inline limb_t mul_mod(limb_t a, limb_t b, limb_t m)
1897 {
1898     dlimb_t t;
1899     t = (dlimb_t)a * (dlimb_t)b;
1900     return t % m;
1901 }
1902 
#if defined(USE_MUL_CHECK)
/* return (r*B^n + tab[n-1..0]) mod m, scanning from the most
   significant limb; the incoming 'r' must satisfy r < m. */
static limb_t mp_mod1(const limb_t *tab, limb_t n, limb_t m, limb_t r)
{
    slimb_t j;

    for(j = n - 1; j >= 0; j--) {
        dlimb_t acc = ((dlimb_t)r << LIMB_BITS) | tab[j];
        r = (limb_t)(acc % m);
    }
    return r;
}
#endif
1916 
/* sqrt_table[i] is the initial 8 bit square root estimate used by
   mp_sqrtrem1() for a top argument byte of i + 64 (i.e. for the 16 ->
   8 bit square root step) */
static const uint16_t sqrt_table[192] = {
128,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,144,145,146,147,148,149,150,150,151,152,153,154,155,155,156,157,158,159,160,160,161,162,163,163,164,165,166,167,167,168,169,170,170,171,172,173,173,174,175,176,176,177,178,178,179,180,181,181,182,183,183,184,185,185,186,187,187,188,189,189,190,191,192,192,193,193,194,195,195,196,197,197,198,199,199,200,201,201,202,203,203,204,204,205,206,206,207,208,208,209,209,210,211,211,212,212,213,214,214,215,215,216,217,217,218,218,219,219,220,221,221,222,222,223,224,224,225,225,226,226,227,227,228,229,229,230,230,231,231,232,232,233,234,234,235,235,236,236,237,237,238,238,239,240,240,241,241,242,242,243,243,244,244,245,245,246,246,247,247,248,248,249,249,250,250,251,251,252,252,253,253,254,254,255,
};
1920 
/* a >= 2^(LIMB_BITS - 2).  Return (s, r) with s=floor(sqrt(a)) and
   r=a-s^2. 0 <= r <= 2 * s. The remainder is stored in *pr.
   The root is built digit by digit: an 8 bit table estimate, then one
   (LIMB_BITS == 32) or two (LIMB_BITS == 64) refinement steps that
   each double the number of correct bits. */
static limb_t mp_sqrtrem1(limb_t *pr, limb_t a)
{
    limb_t s1, r1, s, r, q, u, num;

    /* use a table for the 16 -> 8 bit sqrt */
    s1 = sqrt_table[(a >> (LIMB_BITS - 8)) - 64];
    r1 = (a >> (LIMB_BITS - 16)) - s1 * s1;
    if (r1 > 2 * s1) {
        /* the table estimate was one too small */
        r1 -= 2 * s1 + 1;
        s1++;
    }

    /* one iteration to get a 32 -> 16 bit sqrt */
    /* extend the remainder with the next 8 argument bits and divide by
       2*s1 to get the next root digit */
    num = (r1 << 8) | ((a >> (LIMB_BITS - 32 + 8)) & 0xff);
    q = num / (2 * s1); /* q <= 2^8 */
    u = num % (2 * s1);
    s = (s1 << 8) + q;
    r = (u << 8) | ((a >> (LIMB_BITS - 32)) & 0xff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        /* the digit estimate was one too large */
        s--;
        r += 2 * s + 1;
    }

#if LIMB_BITS == 64
    s1 = s;
    r1 = r;
    /* one more iteration for 64 -> 32 bit sqrt */
    num = (r1 << 16) | ((a >> (LIMB_BITS - 64 + 16)) & 0xffff);
    q = num / (2 * s1); /* q <= 2^16 */
    u = num % (2 * s1);
    s = (s1 << 16) + q;
    r = (u << 16) | ((a >> (LIMB_BITS - 64)) & 0xffff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        s--;
        r += 2 * s + 1;
    }
#endif
    *pr = r;
    return s;
}
1965 
1966 /* return floor(sqrt(a)) */
bf_isqrt(limb_t a)1967 limb_t bf_isqrt(limb_t a)
1968 {
1969     limb_t s, r;
1970     int k;
1971 
1972     if (a == 0)
1973         return 0;
1974     k = clz(a) & ~1;
1975     s = mp_sqrtrem1(&r, a << k);
1976     s >>= (k >> 1);
1977     return s;
1978 }
1979 
/* Square root with remainder of the two-limb operand taba[1]:taba[0]
   (taba[1] >= 2^(LIMB_BITS - 2) is required by mp_sqrtrem1). On
   return tabs[0] = floor(sqrt(a)), taba[0] = low limb of the
   remainder and the high limb of the remainder (0 or 1) is the
   return value. */
static limb_t mp_sqrtrem2(limb_t *tabs, limb_t *taba)
{
    limb_t s1, r1, s, q, u, a0, a1;
    dlimb_t r, num;
    int l;

    a0 = taba[0];
    a1 = taba[1];
    /* square root of the high limb */
    s1 = mp_sqrtrem1(&r1, a1);
    l = LIMB_BITS / 2;
    /* one refinement step in base 2^(LIMB_BITS/2): divide the next
       half limb of the operand by 2 * s1 */
    num = ((dlimb_t)r1 << l) | (a0 >> l);
    q = num / (2 * s1);
    u = num % (2 * s1);
    s = (s1 << l) + q;
    r = ((dlimb_t)u << l) | (a0 & (((limb_t)1 << l) - 1));
    if (unlikely((q >> l) != 0))
        r -= (dlimb_t)1 << LIMB_BITS; /* special case when q=2^l */
    else
        r -= q * q;
    if ((slimb_t)(r >> LIMB_BITS) < 0) {
        /* q was one too large: decrement s and fix the remainder */
        s--;
        r += 2 * (dlimb_t)s + 1;
    }
    tabs[0] = s;
    taba[0] = r;
    return r >> LIMB_BITS;
}
2007 
2008 //#define DEBUG_SQRTREM
2009 
/* tmp_buf must contain (n / 2 + 1 limbs). *prh contains the highest
   limb of the remainder. Recursive Karatsuba square root of the
   2*n-limb operand in 'taba'; the n-limb root is stored in 'tabs'
   and the low n limbs of the remainder are left in 'taba'.
   Return 0 if OK, -1 if memory error. */
static int mp_sqrtrem_rec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n,
                          limb_t *tmp_buf, limb_t *prh)
{
    limb_t l, h, rh, ql, qh, c, i;

    if (n == 1) {
        /* base case: two-limb operand */
        *prh = mp_sqrtrem2(tabs, taba);
        return 0;
    }
#ifdef DEBUG_SQRTREM
    mp_print_str("a", taba, 2 * n);
#endif
    /* split the root into a low part of l limbs and a high part of h
       limbs */
    l = n / 2;
    h = n - l;
    /* recursively compute the square root of the high 2*h limbs */
    if (mp_sqrtrem_rec(s, tabs + l, taba + 2 * l, h, tmp_buf, &qh))
        return -1;
#ifdef DEBUG_SQRTREM
    mp_print_str("s1", tabs + l, h);
    mp_print_str_h("r1", taba + 2 * l, h, qh);
    mp_print_str_h("r2", taba + l, n, qh);
#endif

    /* the remainder is in taba + 2 * l. Its high bit is in qh */
    if (qh) {
        mp_sub(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
    }
    /* instead of dividing by 2*s, divide by s (which is normalized)
       and update q and r */
    if (mp_divnorm(s, tmp_buf, taba + l, n, tabs + l, h))
        return -1;
    qh += tmp_buf[l];
    for(i = 0; i < l; i++)
        tabs[i] = tmp_buf[i];
    /* halve the quotient; the shifted-out low bit goes into ql */
    ql = mp_shr(tabs, tabs, l, 1, qh & 1);
    qh = qh >> 1; /* 0 or 1 */
    if (ql)
        rh = mp_add(taba + l, taba + l, tabs + l, h, 0);
    else
        rh = 0;
#ifdef DEBUG_SQRTREM
    mp_print_str_h("q", tabs, l, qh);
    mp_print_str_h("u", taba + l, h, rh);
#endif

    mp_add_ui(tabs + l, qh, h);
#ifdef DEBUG_SQRTREM
    mp_print_str_h("s2", tabs, n, sh);
#endif

    /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
    /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
    if (qh) {
        c = qh;
    } else {
        if (mp_mul(s, taba + n, tabs, l, tabs, l))
            return -1;
        c = mp_sub(taba, taba, taba + n, 2 * l, 0);
    }
    rh -= mp_sub_ui(taba + 2 * l, c, n - 2 * l);
    if ((slimb_t)rh < 0) {
        /* the estimated root was one too large: s -= 1 and
           r += 2 * s + 1 */
        mp_sub_ui(tabs, 1, n);
        rh += mp_add_mul1(taba, tabs, n, 2);
        rh += mp_add_ui(taba, 1, n);
    }
    *prh = rh;
    return 0;
}
2079 
2080 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= 2 ^ (LIMB_BITS
2081    - 2). Return (s, r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2
2082    * s. tabs has n limbs. r is returned in the lower n limbs of
2083    taba. Its r[n] is the returned value of the function. */
2084 /* Algorithm from the article "Karatsuba Square Root" by Paul Zimmermann and
2085    inspirated from its GMP implementation */
/* Entry point of the multi-limb square root: allocates the scratch
   buffer (n / 2 + 1 limbs) needed by the recursion, using the stack
   for small sizes. Return 0 if OK, -1 if memory error. */
int mp_sqrtrem(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
{
    limb_t stack_buf[8];
    limb_t *buf;
    mp_size_t buf_len;
    int ret;

    buf_len = n / 2 + 1;
    if (buf_len <= countof(stack_buf)) {
        /* small case: no heap allocation needed */
        buf = stack_buf;
    } else {
        buf = bf_malloc(s, sizeof(limb_t) * buf_len);
        if (!buf)
            return -1;
    }
    ret = mp_sqrtrem_rec(s, tabs, taba, n, buf, taba + n);
    if (buf != stack_buf)
        bf_free(s, buf);
    return ret;
}
2105 
2106 /* Integer square root with remainder. 'a' must be an integer. r =
2107    floor(sqrt(a)) and rem = a - r^2.  BF_ST_INEXACT is set if the result
2108    is inexact. 'rem' can be NULL if the remainder is not needed. */
int bf_sqrtrem(bf_t *r, bf_t *rem1, const bf_t *a)
{
    int ret;

    if (a->len == 0) {
        /* zero, infinity or NaN operand */
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
        } else if (a->expn == BF_EXP_INF && a->sign) {
            /* sqrt(-inf) is an invalid operation */
            goto invalid_op;
        } else {
            /* +/-0 and +inf pass through unchanged */
            bf_set(r, a);
        }
        if (rem1)
            bf_set_ui(rem1, 0);
        ret = 0;
    } else if (a->sign) {
        /* sqrt of a negative number */
 invalid_op:
        bf_set_nan(r);
        if (rem1)
            bf_set_ui(rem1, 0);
        ret = BF_ST_INVALID_OP;
    } else {
        bf_t rem_s, *rem;

        /* (a->expn + 1) / 2 bits hold the integer part of the root */
        bf_sqrt(r, a, (a->expn + 1) / 2, BF_RNDZ);
        bf_rint(r, BF_RNDZ);
        /* see if the result is exact by computing the remainder */
        if (rem1) {
            rem = rem1;
        } else {
            /* caller does not want the remainder: use a temporary */
            rem = &rem_s;
            bf_init(r->ctx, rem);
        }
        /* XXX: could avoid recomputing the remainder */
        bf_mul(rem, r, r, BF_PREC_INF, BF_RNDZ);
        bf_neg(rem);
        bf_add(rem, rem, a, BF_PREC_INF, BF_RNDZ);
        if (bf_is_nan(rem)) {
            /* a memory error in the operations above yields NaN */
            ret = BF_ST_MEM_ERROR;
            goto done;
        }
        if (rem->len != 0) {
            /* non-zero remainder: the square root was not exact */
            ret = BF_ST_INEXACT;
        } else {
            ret = 0;
        }
    done:
        if (!rem1)
            bf_delete(rem);
    }
    return ret;
}
2161 
/* r = sqrt(a) rounded to 'prec' bits with the rounding mode in
   'flags'. 'r' must be distinct from 'a'. */
int bf_sqrt(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = a->ctx;
    int ret;

    assert(r != a);

    if (a->len == 0) {
        /* zero, infinity or NaN operand */
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
        } else if (a->expn == BF_EXP_INF && a->sign) {
            goto invalid_op;
        } else {
            bf_set(r, a);
        }
        ret = 0;
    } else if (a->sign) {
        /* sqrt of a negative number */
 invalid_op:
        bf_set_nan(r);
        ret = BF_ST_INVALID_OP;
    } else {
        limb_t *a1;
        slimb_t n, n1;
        limb_t res;

        /* convert the mantissa to an integer with at least 2 *
           prec + 4 bits */
        n = (2 * (prec + 2) + 2 * LIMB_BITS - 1) / (2 * LIMB_BITS);
        if (bf_resize(r, n))
            goto fail;
        a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
        if (!a1)
            goto fail;
        /* zero-pad the low limbs, copy the n1 most significant limbs */
        n1 = bf_min(2 * n, a->len);
        memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
        memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
        if (a->expn & 1) {
            /* odd exponent: shift right one bit so the exponent
               becomes even; 'res' records the shifted-out bit */
            res = mp_shr(a1, a1, 2 * n, 1, 0);
        } else {
            res = 0;
        }
        if (mp_sqrtrem(s, r->tab, a1, n)) {
            bf_free(s, a1);
            goto fail;
        }
        /* check whether any non-zero bits were discarded ... */
        if (!res) {
            /* ... in the remainder */
            res = mp_scan_nz(a1, n + 1);
        }
        bf_free(s, a1);
        if (!res) {
            /* ... or in the uncopied low limbs of 'a' */
            res = mp_scan_nz(a->tab, a->len - n1);
        }
        if (res != 0)
            r->tab[0] |= 1; /* set a sticky bit for correct rounding */
        r->sign = 0;
        r->expn = (a->expn + 1) >> 1;
        ret = bf_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
2225 
bf_op2(bf_t * r,const bf_t * a,const bf_t * b,limb_t prec,bf_flags_t flags,bf_op2_func_t * func)2226 static no_inline int bf_op2(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2227                             bf_flags_t flags, bf_op2_func_t *func)
2228 {
2229     bf_t tmp;
2230     int ret;
2231 
2232     if (r == a || r == b) {
2233         bf_init(r->ctx, &tmp);
2234         ret = func(&tmp, a, b, prec, flags);
2235         bf_move(r, &tmp);
2236     } else {
2237         ret = func(r, a, b, prec, flags);
2238     }
2239     return ret;
2240 }
2241 
/* r = round(a + b); 'r' may alias 'a' or 'b' (bf_op2 handles it) */
int bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
            bf_flags_t flags)
{
    return bf_op2(r, a, b, prec, flags, __bf_add);
}
2247 
/* r = round(a - b); 'r' may alias 'a' or 'b' (bf_op2 handles it) */
int bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
            bf_flags_t flags)
{
    return bf_op2(r, a, b, prec, flags, __bf_sub);
}
2253 
/* r = round(a / b); 'r' may alias 'a' or 'b' (bf_op2 handles it) */
int bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    return bf_op2(r, a, b, prec, flags, __bf_div);
}
2259 
/* r = round(a * b1) where b1 is an unsigned 64 bit integer */
int bf_mul_ui(bf_t *r, const bf_t *a, uint64_t b1, limb_t prec,
               bf_flags_t flags)
{
    bf_t tmp;
    int st;

    /* promote the integer operand to a bf_t and multiply */
    bf_init(r->ctx, &tmp);
    st = bf_set_ui(&tmp, b1);
    st |= bf_mul(r, a, &tmp, prec, flags);
    bf_delete(&tmp);
    return st;
}
2271 
/* r = round(a * b1) where b1 is a signed 64 bit integer */
int bf_mul_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
               bf_flags_t flags)
{
    bf_t tmp;
    int st;

    /* promote the integer operand to a bf_t and multiply */
    bf_init(r->ctx, &tmp);
    st = bf_set_si(&tmp, b1);
    st |= bf_mul(r, a, &tmp, prec, flags);
    bf_delete(&tmp);
    return st;
}
2283 
/* r = round(a + b1) where b1 is a signed 64 bit integer */
int bf_add_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
              bf_flags_t flags)
{
    bf_t tmp;
    int st;

    /* promote the integer operand to a bf_t and add */
    bf_init(r->ctx, &tmp);
    st = bf_set_si(&tmp, b1);
    st |= bf_add(r, a, &tmp, prec, flags);
    bf_delete(&tmp);
    return st;
}
2296 
/* r = round(a^b) computed by left-to-right binary exponentiation.
   'r' must be distinct from 'a'. */
static int bf_pow_ui(bf_t *r, const bf_t *a, limb_t b, limb_t prec,
                     bf_flags_t flags)
{
    int st, bit;

    assert(r != a);
    if (b == 0)
        return bf_set_ui(r, 1); /* a^0 = 1 */
    /* start from the most significant set bit of the exponent */
    st = bf_set(r, a);
    for (bit = LIMB_BITS - clz(b) - 2; bit >= 0; bit--) {
        st |= bf_mul(r, r, r, prec, flags); /* square */
        if (b & ((limb_t)1 << bit))
            st |= bf_mul(r, r, a, prec, flags); /* multiply */
    }
    return st;
}
2314 
/* r = a1^b for small unsigned integers */
static int bf_pow_ui_ui(bf_t *r, limb_t a1, limb_t b,
                        limb_t prec, bf_flags_t flags)
{
    int st;

    if (a1 == 10 && b <= LIMB_DIGITS) {
        /* use precomputed powers. We do not round at this point
           because we expect the caller to do it */
        st = bf_set_ui(r, mp_pow_dec[b]);
    } else {
        bf_t base;

        bf_init(r->ctx, &base);
        st = bf_set_ui(&base, a1);
        st |= bf_pow_ui(r, &base, b, prec, flags);
        bf_delete(&base);
    }
    return st;
}
2333 
/* convert to integer (infinite precision) */
int bf_rint(bf_t *r, int rnd_mode)
{
    /* prec = 0 with BF_FLAG_RADPNT_PREC rounds at the radix point,
       i.e. removes the fractional part according to 'rnd_mode' */
    return bf_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
}
2339 
2340 /* logical operations */
2341 #define BF_LOGIC_OR  0
2342 #define BF_LOGIC_XOR 1
2343 #define BF_LOGIC_AND 2
2344 
/* apply the selected bitwise operation to a pair of limbs */
static inline limb_t bf_logic_op1(limb_t a, limb_t b, int op)
{
    if (op == BF_LOGIC_OR)
        return a | b;
    else if (op == BF_LOGIC_XOR)
        return a ^ b;
    else /* BF_LOGIC_AND (and any other value) */
        return a & b;
}
2357 
/* Bitwise operation 'op' on the integers 'a1' and 'b1' using their
   two's complement representation. Negative operands are handled by
   the identity -x = ~(x - 1): each is replaced by x + 1 and its
   limbs are complemented on the fly with an XOR mask. Return 0 or
   BF_ST_MEM_ERROR. */
static int bf_logic_op(bf_t *r, const bf_t *a1, const bf_t *b1, int op)
{
    bf_t b1_s, a1_s, *a, *b;
    limb_t a_sign, b_sign, r_sign;
    slimb_t l, i, a_bit_offset, b_bit_offset;
    limb_t v1, v2, v1_mask, v2_mask, r_mask;
    int ret;

    assert(r != a1 && r != b1);

    if (a1->expn <= 0)
        a_sign = 0; /* minus zero is considered as positive */
    else
        a_sign = a1->sign;

    if (b1->expn <= 0)
        b_sign = 0; /* minus zero is considered as positive */
    else
        b_sign = b1->sign;

    if (a_sign) {
        /* a = a1 + 1 so that -a1 = ~a (with the XOR mask below) */
        a = &a1_s;
        bf_init(r->ctx, a);
        if (bf_add_si(a, a1, 1, BF_PREC_INF, BF_RNDZ)) {
            b = NULL;
            goto fail;
        }
    } else {
        a = (bf_t *)a1;
    }

    if (b_sign) {
        /* same transformation for the second operand */
        b = &b1_s;
        bf_init(r->ctx, b);
        if (bf_add_si(b, b1, 1, BF_PREC_INF, BF_RNDZ))
            goto fail;
    } else {
        b = (bf_t *)b1;
    }

    /* sign of the result follows the two's complement semantics */
    r_sign = bf_logic_op1(a_sign, b_sign, op);
    if (op == BF_LOGIC_AND && r_sign == 0) {
        /* no need to compute extra zeros for and */
        if (a_sign == 0 && b_sign == 0)
            l = bf_min(a->expn, b->expn);
        else if (a_sign == 0)
            l = a->expn;
        else
            l = b->expn;
    } else {
        l = bf_max(a->expn, b->expn);
    }
    /* Note: a or b can be zero */
    l = (bf_max(l, 1) + LIMB_BITS - 1) / LIMB_BITS;
    if (bf_resize(r, l))
        goto fail;
    a_bit_offset = a->len * LIMB_BITS - a->expn;
    b_bit_offset = b->len * LIMB_BITS - b->expn;
    /* all-ones masks complement the limbs of negative operands */
    v1_mask = -a_sign;
    v2_mask = -b_sign;
    r_mask = -r_sign;
    for(i = 0; i < l; i++) {
        v1 = get_bits(a->tab, a->len, a_bit_offset + i * LIMB_BITS) ^ v1_mask;
        v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS) ^ v2_mask;
        r->tab[i] = bf_logic_op1(v1, v2, op) ^ r_mask;
    }
    r->expn = l * LIMB_BITS;
    r->sign = r_sign;
    bf_normalize_and_round(r, BF_PREC_INF, BF_RNDZ); /* cannot fail */
    if (r_sign) {
        /* undo the two's complement transformation on the result */
        if (bf_add_si(r, r, -1, BF_PREC_INF, BF_RNDZ))
            goto fail;
    }
    ret = 0;
 done:
    /* free the temporaries created for negative operands */
    if (a == &a1_s)
        bf_delete(a);
    if (b == &b1_s)
        bf_delete(b);
    return ret;
 fail:
    bf_set_nan(r);
    ret = BF_ST_MEM_ERROR;
    goto done;
}
2443 
/* r = a | b. 'a' and 'b' must be integers and distinct from 'r'.
   Return 0 or BF_ST_MEM_ERROR. */
int bf_logic_or(bf_t *r, const bf_t *a, const bf_t *b)
{
    return bf_logic_op(r, a, b, BF_LOGIC_OR);
}
2449 
/* r = a ^ b. 'a' and 'b' must be integers and distinct from 'r'.
   Return 0 or BF_ST_MEM_ERROR. */
int bf_logic_xor(bf_t *r, const bf_t *a, const bf_t *b)
{
    return bf_logic_op(r, a, b, BF_LOGIC_XOR);
}
2455 
/* r = a & b. 'a' and 'b' must be integers and distinct from 'r'.
   Return 0 or BF_ST_MEM_ERROR. */
int bf_logic_and(bf_t *r, const bf_t *a, const bf_t *b)
{
    return bf_logic_op(r, a, b, BF_LOGIC_AND);
}
2461 
2462 /* conversion between fixed size types */
2463 
/* union used to type-pun between a double and its IEEE-754 binary64
   bit pattern without violating strict aliasing */
typedef union {
    double d;
    uint64_t u;
} Float64Union;
2468 
/* Convert 'a' to a double with rounding mode 'rnd_mode'. Returns the
   rounding status flags (result of bf_round). */
int bf_get_float64(const bf_t *a, double *pres, bf_rnd_t rnd_mode)
{
    Float64Union u;
    int e, ret;
    uint64_t m;

    ret = 0;
    if (a->expn == BF_EXP_NAN) {
        u.u = 0x7ff8000000000000; /* quiet nan */
    } else {
        bf_t b_s, *b = &b_s;

        /* round a copy to the binary64 format: 53 bit mantissa,
           11 bit exponent, subnormal results allowed */
        bf_init(a->ctx, b);
        bf_set(b, a);
        if (bf_is_finite(b)) {
            ret = bf_round(b, 53, rnd_mode | BF_FLAG_SUBNORMAL | bf_set_exp_bits(11));
        }
        if (b->expn == BF_EXP_INF) {
            e = (1 << 11) - 1; /* biased exponent of infinity */
            m = 0;
        } else if (b->expn == BF_EXP_ZERO) {
            e = 0;
            m = 0;
        } else {
            e = b->expn + 1023 - 1; /* IEEE-754 biased exponent */
#if LIMB_BITS == 32
            /* assemble the 64 bit mantissa from one or two limbs */
            if (b->len == 2) {
                m = ((uint64_t)b->tab[1] << 32) | b->tab[0];
            } else {
                m = ((uint64_t)b->tab[0] << 32);
            }
#else
            m = b->tab[0];
#endif
            if (e <= 0) {
                /* subnormal */
                m = m >> (12 - e);
                e = 0;
            } else {
                /* drop the implicit leading one, keep the top 52 bits */
                m = (m << 1) >> 12;
            }
        }
        u.u = m | ((uint64_t)e << 52) | ((uint64_t)b->sign << 63);
        bf_delete(b);
    }
    *pres = u.d;
    return ret;
}
2517 
/* Set 'a' from the double 'd' (always exact). Return 0 if OK or
   BF_ST_MEM_ERROR. */
int bf_set_float64(bf_t *a, double d)
{
    Float64Union u;
    uint64_t m;
    int shift, e, sgn;

    u.d = d;
    /* decompose the IEEE-754 binary64 value: sign bit, 11 bit biased
       exponent and 52 bit mantissa */
    sgn = u.u >> 63;
    e = (u.u >> 52) & ((1 << 11) - 1);
    m = u.u & (((uint64_t)1 << 52) - 1);
    if (e == ((1 << 11) - 1)) {
        /* maximum exponent: NaN or infinity */
        if (m != 0) {
            bf_set_nan(a);
        } else {
            bf_set_inf(a, sgn);
        }
    } else if (e == 0) {
        if (m == 0) {
            bf_set_zero(a, sgn);
        } else {
            /* subnormal number */
            m <<= 12;
            shift = clz64(m);
            m <<= shift; /* normalize: leading one at bit 63 */
            e = -shift;
            goto norm;
        }
    } else {
        /* normal number: add back the implicit leading one */
        m = (m << 11) | ((uint64_t)1 << 63);
    norm:
        a->expn = e - 1023 + 1;
#if LIMB_BITS == 32
        if (bf_resize(a, 2))
            goto fail;
        a->tab[0] = m;
        a->tab[1] = m >> 32;
#else
        if (bf_resize(a, 1))
            goto fail;
        a->tab[0] = m;
#endif
        a->sign = sgn;
    }
    return 0;
fail:
    bf_set_nan(a);
    return BF_ST_MEM_ERROR;
}
2566 
2567 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
2568    is an overflow and 0 otherwise. */
int bf_get_int32(int *pres, const bf_t *a, int flags)
{
    uint32_t v;
    int ret;
    if (a->expn >= BF_EXP_INF) {
        /* NaN or infinity */
        ret = 0;
        if (flags & BF_GET_INT_MOD) {
            v = 0;
        } else if (a->expn == BF_EXP_INF) {
            /* +inf -> INT32_MAX, -inf -> INT32_MIN */
            v = (uint32_t)INT32_MAX + a->sign;
            /* XXX: return overflow ? */
        } else {
            v = INT32_MAX;
        }
    } else if (a->expn <= 0) {
        /* abs(a) < 1: truncates to zero */
        v = 0;
        ret = 0;
    } else if (a->expn <= 31) {
        /* the value fits: the a->expn top bits are the integer part */
        v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
        if (a->sign)
            v = -v;
        ret = 0;
    } else if (!(flags & BF_GET_INT_MOD)) {
        /* out of range: saturate */
        ret = BF_ST_OVERFLOW;
        if (a->sign) {
            v = (uint32_t)INT32_MAX + 1;
            /* exactly INT32_MIN is representable: not an overflow */
            if (a->expn == 32 &&
                (a->tab[a->len - 1] >> (LIMB_BITS - 32)) == v) {
                ret = 0;
            }
        } else {
            v = INT32_MAX;
        }
    } else {
        /* BF_GET_INT_MOD: wrap modulo 2^32 */
        v = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
        if (a->sign)
            v = -v;
        ret = 0;
    }
    *pres = v;
    return ret;
}
2611 
2612 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
2613    is an overflow and 0 otherwise. */
int bf_get_int64(int64_t *pres, const bf_t *a, int flags)
{
    uint64_t v;
    int ret;
    if (a->expn >= BF_EXP_INF) {
        /* NaN or infinity */
        ret = 0;
        if (flags & BF_GET_INT_MOD) {
            v = 0;
        } else if (a->expn == BF_EXP_INF) {
            /* +inf -> INT64_MAX, -inf -> INT64_MIN */
            v = (uint64_t)INT64_MAX + a->sign;
        } else {
            v = INT64_MAX;
        }
    } else if (a->expn <= 0) {
        /* abs(a) < 1: truncates to zero */
        v = 0;
        ret = 0;
    } else if (a->expn <= 63) {
        /* the value fits: the a->expn top bits are the integer part */
#if LIMB_BITS == 32
        if (a->expn <= 32)
            v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
        else
            v = (((uint64_t)a->tab[a->len - 1] << 32) |
                 get_limbz(a, a->len - 2)) >> (64 - a->expn);
#else
        v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
#endif
        if (a->sign)
            v = -v;
        ret = 0;
    } else if (!(flags & BF_GET_INT_MOD)) {
        /* out of range: saturate */
        ret = BF_ST_OVERFLOW;
        if (a->sign) {
            uint64_t v1;
            v = (uint64_t)INT64_MAX + 1;
            /* exactly INT64_MIN is representable: not an overflow */
            if (a->expn == 64) {
                v1 = a->tab[a->len - 1];
#if LIMB_BITS == 32
                v1 = (v1 << 32) | get_limbz(a, a->len - 2);
#endif
                if (v1 == v)
                    ret = 0;
            }
        } else {
            v = INT64_MAX;
        }
    } else {
        /* BF_GET_INT_MOD: wrap modulo 2^64 */
        slimb_t bit_pos = a->len * LIMB_BITS - a->expn;
        v = get_bits(a->tab, a->len, bit_pos);
#if LIMB_BITS == 32
        v |= (uint64_t)get_bits(a->tab, a->len, bit_pos + 32) << 32;
#endif
        if (a->sign)
            v = -v;
        ret = 0;
    }
    *pres = v;
    return ret;
}
2672 
2673 /* base conversion from radix */
2674 
/* number of digits of base 'radix' that fit in one limb, indexed by
   radix - 2 (i.e. floor(LIMB_BITS / log2(radix))) */
static const uint8_t digits_per_limb_table[BF_RADIX_MAX - 1] = {
#if LIMB_BITS == 32
32,20,16,13,12,11,10,10, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
#else
64,40,32,27,24,22,21,20,19,18,17,17,16,16,16,15,15,15,14,14,14,14,13,13,13,13,13,13,13,12,12,12,12,12,12,
#endif
};
2682 
get_limb_radix(int radix)2683 static limb_t get_limb_radix(int radix)
2684 {
2685     int i, k;
2686     limb_t radixl;
2687 
2688     k = digits_per_limb_table[radix - 2];
2689     radixl = radix;
2690     for(i = 1; i < k; i++)
2691         radixl *= radix;
2692     return radixl;
2693 }
2694 
2695 /* return != 0 if error */
/* return != 0 if error */
/* Compute r = sum(tab[i] * radix^i, i = 0..n-1) (least significant
   limb first) by divide and conquer. 'n0' is the total size of the
   initial call; pow_tab caches radix^n2 for each recursion level. */
static int bf_integer_from_radix_rec(bf_t *r, const limb_t *tab,
                                     limb_t n, int level, limb_t n0,
                                     limb_t radix, bf_t *pow_tab)
{
    int ret;
    if (n == 1) {
        ret = bf_set_ui(r, tab[0]);
    } else {
        bf_t T_s, *T = &T_s, *B;
        limb_t n1, n2;

        /* size of the low half; it depends only on n0 and level, so
           every call at a given level uses the same split and can
           share the cached power of radix */
        n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
        n1 = n - n2;
        //        printf("level=%d n0=%ld n1=%ld n2=%ld\n", level, n0, n1, n2);
        B = &pow_tab[level];
        if (B->len == 0) {
            /* lazily compute radix^n2 and cache it for this level */
            ret = bf_pow_ui_ui(B, radix, n2, BF_PREC_INF, BF_RNDZ);
            if (ret)
                return ret;
        }
        /* high half: tab[n2 .. n-1] */
        ret = bf_integer_from_radix_rec(r, tab + n2, n1, level + 1, n0,
                                        radix, pow_tab);
        if (ret)
            return ret;
        ret = bf_mul(r, r, B, BF_PREC_INF, BF_RNDZ);
        if (ret)
            return ret;
        /* low half: tab[0 .. n2-1], added to (high half) * radix^n2 */
        bf_init(r->ctx, T);
        ret = bf_integer_from_radix_rec(T, tab, n2, level + 1, n0,
                                        radix, pow_tab);
        if (!ret)
            ret = bf_add(r, r, T, BF_PREC_INF, BF_RNDZ);
        bf_delete(T);
    }
    return ret;
    //    bf_print_str("  r=", r);
}
2733 
2734 /* return 0 if OK != 0 if memory error */
/* return 0 if OK != 0 if memory error */
static int bf_integer_from_radix(bf_t *r, const limb_t *tab,
                                 limb_t n, limb_t radix)
{
    bf_context_t *ctx = r->ctx;
    bf_t *powers;
    limb_t radixl;
    int depth, i, st;

    radixl = get_limb_radix(radix);
    /* one cached power of radixl per recursion level */
    depth = ceil_log2(n) + 2; /* XXX: check */
    powers = bf_malloc(ctx, sizeof(powers[0]) * depth);
    if (!powers)
        return -1;
    for (i = 0; i < depth; i++)
        bf_init(r->ctx, &powers[i]);
    st = bf_integer_from_radix_rec(r, tab, n, 0, n, radixl, powers);
    for (i = 0; i < depth; i++)
        bf_delete(&powers[i]);
    bf_free(ctx, powers);
    return st;
}
2757 
2758 /* compute and round T * radix^expn. */
/* compute and round T * radix^expn. */
int bf_mul_pow_radix(bf_t *r, const bf_t *T, limb_t radix,
                     slimb_t expn, limb_t prec, bf_flags_t flags)
{
    int ret, expn_sign, overflow;
    slimb_t e, extra_bits, prec1, ziv_extra_bits;
    bf_t B_s, *B = &B_s;

    if (T->len == 0) {
        /* zero, infinity or NaN: unchanged */
        return bf_set(r, T);
    } else if (expn == 0) {
        /* radix^0 = 1: just round T */
        ret = bf_set(r, T);
        ret |= bf_round(r, prec, flags);
        return ret;
    }

    /* split the exponent into magnitude and sign; a negative
       exponent becomes a division below */
    e = expn;
    expn_sign = 0;
    if (e < 0) {
        e = -e;
        expn_sign = 1;
    }
    bf_init(r->ctx, B);
    if (prec == BF_PREC_INF) {
        /* infinite precision: only used if the result is known to be exact */
        ret = bf_pow_ui_ui(B, radix, e, BF_PREC_INF, BF_RNDN);
        if (expn_sign) {
            ret |= bf_div(r, T, B, T->len * LIMB_BITS, BF_RNDN);
        } else {
            ret |= bf_mul(r, T, B, BF_PREC_INF, BF_RNDN);
        }
    } else {
        /* Ziv rounding loop: compute with increasing extra precision
           until the result can be rounded correctly */
        ziv_extra_bits = 16;
        for(;;) {
            prec1 = prec + ziv_extra_bits;
            /* XXX: correct overflow/underflow handling */
            /* XXX: rigorous error analysis needed */
            extra_bits = ceil_log2(e) * 2 + 1;
            ret = bf_pow_ui_ui(B, radix, e, prec1 + extra_bits, BF_RNDN);
            overflow = !bf_is_finite(B);
            /* XXX: if bf_pow_ui_ui returns an exact result, can stop
               after the next operation */
            if (expn_sign)
                ret |= bf_div(r, T, B, prec1 + extra_bits, BF_RNDN);
            else
                ret |= bf_mul(r, T, B, prec1 + extra_bits, BF_RNDN);
            if (ret & BF_ST_MEM_ERROR)
                break;
            if ((ret & BF_ST_INEXACT) &&
                !bf_can_round(r, prec, flags & BF_RND_MASK, prec1) &&
                !overflow) {
                /* add more precision and retry */
                ziv_extra_bits = ziv_extra_bits  + (ziv_extra_bits / 2);
            } else {
                ret = bf_round(r, prec, flags) | (ret & BF_ST_INEXACT);
                break;
            }
        }
    }
    bf_delete(B);
    return ret;
}
2820 
/* decode an ASCII character as a digit in bases up to 36; returns 36
   when 'c' is not a valid digit character */
static inline int to_digit(int c)
{
    if (c >= '0' && c <= '9')
        return c - '0';
    if (c >= 'a' && c <= 'z')
        return c - 'a' + 10;
    if (c >= 'A' && c <= 'Z')
        return c - 'A' + 10;
    return 36;
}
2832 
/* add a limb at 'pos' and decrement pos. new space is created if
   needed. Return 0 if OK, -1 if memory error */
static int bf_add_limb(bf_t *a, slimb_t *ppos, limb_t v)
{
    slimb_t pos;
    pos = *ppos;
    if (unlikely(pos < 0)) {
        /* no room left at the bottom: grow the array by ~50% and
           shift the existing limbs to the top so that writing can
           continue downward from the newly created space */
        limb_t new_size, d, *new_tab;
        new_size = bf_max(a->len + 1, a->len * 3 / 2);
        new_tab = bf_realloc(a->ctx, a->tab, sizeof(limb_t) * new_size);
        if (!new_tab)
            return -1;
        a->tab = new_tab;
        d = new_size - a->len;
        memmove(a->tab + d, a->tab, a->len * sizeof(limb_t));
        a->len = new_size;
        pos += d;
    }
    a->tab[pos--] = v;
    *ppos = pos;
    return 0;
}
2855 
/* ASCII-only lower casing (no locale dependency, unlike tolower()) */
static int bf_tolower(int c)
{
    if (c >= 'A' && c <= 'Z')
        return c - 'A' + 'a';
    return c;
}
2862 
/* Return non-zero if 'str' starts with 'val', comparing 'str' case
   insensitively ('val' is expected in lower case). On a match, *ptr
   (if not NULL) is set to the first character of 'str' after the
   prefix. */
static int strcasestart(const char *str, const char *val, const char **ptr)
{
    const char *s = str;
    const char *v = val;

    for (; *v != '\0'; s++, v++) {
        if (bf_tolower(*s) != *v)
            return 0;
    }
    if (ptr)
        *ptr = s;
    return 1;
}
2878 
/* Parse a floating point number from 'str' into 'r'.
   - 'radix' = 0 means auto-detect (0x/0o/0b prefixes, default 10).
   - If BF_ATOF_EXPONENT is set and the radix is not a power of two,
     the integer mantissa is returned in 'r' and the radix exponent in
     *pexponent; otherwise *pexponent is set to 0.
   - 'is_dec' selects the bfdec_t (decimal limb) representation; in
     that case radix must be 10.
   - On return, if 'pnext' is non-NULL, *pnext points past the parsed
     text. The return value is the floating point status (0, or a
     combination of BF_ST_* flags; NaN is stored in 'r' on syntax or
     memory error). */
static int bf_atof_internal(bf_t *r, slimb_t *pexponent,
                            const char *str, const char **pnext, int radix,
                            limb_t prec, bf_flags_t flags, BOOL is_dec)
{
    const char *p, *p_start;
    int is_neg, radix_bits, exp_is_neg, ret, digits_per_limb, shift;
    limb_t cur_limb;
    slimb_t pos, expn, int_len, digit_count;
    BOOL has_decpt, is_bin_exp;
    bf_t a_s, *a;

    *pexponent = 0;
    p = str;
    /* "nan" is only recognized for radix <= 16 (otherwise 'n' could be
       a digit) */
    if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
        strcasestart(p, "nan", &p)) {
        bf_set_nan(r);
        ret = 0;
        goto done;
    }
    is_neg = 0;

    /* optional sign; p_start marks the first character after it so we
       can later detect an empty mantissa */
    if (p[0] == '+') {
        p++;
        p_start = p;
    } else if (p[0] == '-') {
        is_neg = 1;
        p++;
        p_start = p;
    } else {
        p_start = p;
    }
    /* radix prefixes: 0x/0X (hex), 0o/0O (octal), 0b/0B (binary) */
    if (p[0] == '0') {
        if ((p[1] == 'x' || p[1] == 'X') &&
            (radix == 0 || radix == 16) &&
            !(flags & BF_ATOF_NO_HEX)) {
            radix = 16;
            p += 2;
        } else if ((p[1] == 'o' || p[1] == 'O') &&
                   radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
            p += 2;
            radix = 8;
        } else if ((p[1] == 'b' || p[1] == 'B') &&
                   radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
            p += 2;
            radix = 2;
        } else {
            goto no_prefix;
        }
        /* there must be a digit after the prefix */
        if (to_digit((uint8_t)*p) >= radix) {
            bf_set_nan(r);
            ret = 0;
            goto done;
        }
    no_prefix: ;
    } else {
        if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
            strcasestart(p, "inf", &p)) {
            bf_set_inf(r, is_neg);
            ret = 0;
            goto done;
        }
    }

    if (radix == 0)
        radix = 10;
    /* select the working number: for a non power-of-two radix the
       digits are first accumulated in a temporary 'a' and converted to
       binary at the end; otherwise 'r' is filled directly */
    if (is_dec) {
        assert(radix == 10);
        radix_bits = 0;
        a = r;
    } else if ((radix & (radix - 1)) != 0) {
        radix_bits = 0; /* base is not a power of two */
        a = &a_s;
        bf_init(r->ctx, a);
    } else {
        radix_bits = ceil_log2(radix);
        a = r;
    }

    /* skip leading zeros */
    /* XXX: could also skip zeros after the decimal point */
    while (*p == '0')
        p++;

    if (radix_bits) {
        shift = digits_per_limb = LIMB_BITS;
    } else {
        radix_bits = 0;
        shift = digits_per_limb = digits_per_limb_table[radix - 2];
    }
    /* accumulate the mantissa digits, most significant first, packing
       'digits_per_limb' of them into each limb */
    cur_limb = 0;
    bf_resize(a, 1);
    pos = 0;
    has_decpt = FALSE;
    int_len = digit_count = 0;
    for(;;) {
        limb_t c;
        /* a '.' is only accepted after at least one mantissa character
           or when followed by a digit */
        if (*p == '.' && (p > p_start || to_digit(p[1]) < radix)) {
            if (has_decpt)
                break;
            has_decpt = TRUE;
            int_len = digit_count; /* digits before the decimal point */
            p++;
        }
        c = to_digit(*p);
        if (c >= radix)
            break;
        digit_count++;
        p++;
        if (radix_bits) {
            /* power-of-two radix: shift the digit bits into the limb */
            shift -= radix_bits;
            if (shift <= 0) {
                cur_limb |= c >> (-shift);
                if (bf_add_limb(a, &pos, cur_limb))
                    goto mem_error;
                if (shift < 0)
                    cur_limb = c << (LIMB_BITS + shift);
                else
                    cur_limb = 0;
                shift += LIMB_BITS;
            } else {
                cur_limb |= c << shift;
            }
        } else {
            /* general radix: multiply-accumulate into the limb */
            cur_limb = cur_limb * radix + c;
            shift--;
            if (shift == 0) {
                if (bf_add_limb(a, &pos, cur_limb))
                    goto mem_error;
                shift = digits_per_limb;
                cur_limb = 0;
            }
        }
    }
    if (!has_decpt)
        int_len = digit_count;

    /* add the last limb and pad with zeros */
    if (shift != digits_per_limb) {
        if (radix_bits == 0) {
            while (shift != 0) {
                cur_limb *= radix;
                shift--;
            }
        }
        if (bf_add_limb(a, &pos, cur_limb)) {
        mem_error:
            ret = BF_ST_MEM_ERROR;
            if (!radix_bits)
                bf_delete(a);
            bf_set_nan(r);
            goto done;
        }
    }

    /* reset the next limbs to zero (we prefer to reallocate in the
       renormalization) */
    memset(a->tab, 0, (pos + 1) * sizeof(limb_t));

    /* empty mantissa (no digit, no prefix): syntax error */
    if (p == p_start) {
        ret = 0;
        if (!radix_bits)
            bf_delete(a);
        bf_set_nan(r);
        goto done;
    }

    /* parse the exponent, if any */
    expn = 0;
    is_bin_exp = FALSE;
    /* 'e'/'E' for radix 10, '@' for other radixes, 'p'/'P' (binary
       exponent) for power-of-two radixes */
    if (((radix == 10 && (*p == 'e' || *p == 'E')) ||
         (radix != 10 && (*p == '@' ||
                          (radix_bits && (*p == 'p' || *p == 'P'))))) &&
        p > p_start) {
        is_bin_exp = (*p == 'p' || *p == 'P');
        p++;
        exp_is_neg = 0;
        if (*p == '+') {
            p++;
        } else if (*p == '-') {
            exp_is_neg = 1;
            p++;
        }
        for(;;) {
            int c;
            c = to_digit(*p);
            if (c >= 10)
                break;
            /* saturate early: a larger exponent necessarily
               underflows/overflows */
            if (unlikely(expn > ((EXP_MAX - 2 - 9) / 10))) {
                /* exponent overflow */
                if (exp_is_neg) {
                    bf_set_zero(r, is_neg);
                    ret = BF_ST_UNDERFLOW | BF_ST_INEXACT;
                } else {
                    bf_set_inf(r, is_neg);
                    ret = BF_ST_OVERFLOW | BF_ST_INEXACT;
                }
                goto done;
            }
            p++;
            expn = expn * 10 + c;
        }
        if (exp_is_neg)
            expn = -expn;
    }
    /* combine the mantissa, exponent and sign, then round */
    if (is_dec) {
        a->expn = expn + int_len;
        a->sign = is_neg;
        ret = bfdec_normalize_and_round((bfdec_t *)a, prec, flags);
    } else if (radix_bits) {
        /* XXX: may overflow */
        if (!is_bin_exp)
            expn *= radix_bits;
        a->expn = expn + (int_len * radix_bits);
        a->sign = is_neg;
        ret = bf_normalize_and_round(a, prec, flags);
    } else {
        limb_t l;
        pos++;
        l = a->len - pos; /* number of limbs */
        if (l == 0) {
            bf_set_zero(r, is_neg);
            ret = 0;
        } else {
            bf_t T_s, *T = &T_s;

            /* exponent relative to the integer mantissa T */
            expn -= l * digits_per_limb - int_len;
            bf_init(r->ctx, T);
            if (bf_integer_from_radix(T, a->tab + pos, l, radix)) {
                bf_set_nan(r);
                ret = BF_ST_MEM_ERROR;
            } else {
                T->sign = is_neg;
                if (flags & BF_ATOF_EXPONENT) {
                    /* return the exponent */
                    *pexponent = expn;
                    ret = bf_set(r, T);
                } else {
                    ret = bf_mul_pow_radix(r, T, radix, expn, prec, flags);
                }
            }
            bf_delete(T);
        }
        bf_delete(a);
    }
 done:
    if (pnext)
        *pnext = p;
    return ret;
}
3129 
/*
   Return (status, n, exp). 'status' is the floating point status. 'n'
   is the parsed number.

   If (flags & BF_ATOF_EXPONENT) and if the radix is not a power of
   two, the parsed number is equal to r * radix^(*pexponent).
   Otherwise *pexponent = 0.
*/
/* Binary (bf_t) front end of the parser; see bf_atof_internal() for
   the meaning of the parameters and of *pexponent. */
int bf_atof2(bf_t *r, slimb_t *pexponent,
             const char *str, const char **pnext, int radix,
             limb_t prec, bf_flags_t flags)
{
    return bf_atof_internal(r, pexponent, str, pnext, radix, prec,
                            flags, FALSE);
}
3145 
/* Convenience wrapper around bf_atof_internal() which discards the
   radix exponent output. */
int bf_atof(bf_t *r, const char *str, const char **pnext, int radix,
            limb_t prec, bf_flags_t flags)
{
    slimb_t unused_exponent;
    return bf_atof_internal(r, &unused_exponent, str, pnext, radix, prec,
                            flags, FALSE);
}
3152 
3153 /* base conversion to radix */
3154 
3155 #if LIMB_BITS == 64
3156 #define RADIXL_10 UINT64_C(10000000000000000000)
3157 #else
3158 #define RADIXL_10 UINT64_C(1000000000)
3159 #endif
3160 
/* inv_log2_radix[radix - 2] = 1/log2(radix) as an unsigned fixed point
   number scaled by 2^(LIMB_BITS-1), stored as 32 bit words, most
   significant first (LIMB_BITS/32 + 1 words, i.e. LIMB_BITS + 32 bits
   of precision). Used by bf_mul_log2_radix() for the is_inv = 1
   case. */
static const uint32_t inv_log2_radix[BF_RADIX_MAX - 1][LIMB_BITS / 32 + 1] = {
#if LIMB_BITS == 32
{ 0x80000000, 0x00000000,},
{ 0x50c24e60, 0xd4d4f4a7,},
{ 0x40000000, 0x00000000,},
{ 0x372068d2, 0x0a1ee5ca,},
{ 0x3184648d, 0xb8153e7a,},
{ 0x2d983275, 0x9d5369c4,},
{ 0x2aaaaaaa, 0xaaaaaaab,},
{ 0x28612730, 0x6a6a7a54,},
{ 0x268826a1, 0x3ef3fde6,},
{ 0x25001383, 0xbac8a744,},
{ 0x23b46706, 0x82c0c709,},
{ 0x229729f1, 0xb2c83ded,},
{ 0x219e7ffd, 0xa5ad572b,},
{ 0x20c33b88, 0xda7c29ab,},
{ 0x20000000, 0x00000000,},
{ 0x1f50b57e, 0xac5884b3,},
{ 0x1eb22cc6, 0x8aa6e26f,},
{ 0x1e21e118, 0x0c5daab2,},
{ 0x1d9dcd21, 0x439834e4,},
{ 0x1d244c78, 0x367a0d65,},
{ 0x1cb40589, 0xac173e0c,},
{ 0x1c4bd95b, 0xa8d72b0d,},
{ 0x1bead768, 0x98f8ce4c,},
{ 0x1b903469, 0x050f72e5,},
{ 0x1b3b433f, 0x2eb06f15,},
{ 0x1aeb6f75, 0x9c46fc38,},
{ 0x1aa038eb, 0x0e3bfd17,},
{ 0x1a593062, 0xb38d8c56,},
{ 0x1a15f4c3, 0x2b95a2e6,},
{ 0x19d630dc, 0xcc7ddef9,},
{ 0x19999999, 0x9999999a,},
{ 0x195fec80, 0x8a609431,},
{ 0x1928ee7b, 0x0b4f22f9,},
{ 0x18f46acf, 0x8c06e318,},
{ 0x18c23246, 0xdc0a9f3d,},
#else
{ 0x80000000, 0x00000000, 0x00000000,},
{ 0x50c24e60, 0xd4d4f4a7, 0x021f57bc,},
{ 0x40000000, 0x00000000, 0x00000000,},
{ 0x372068d2, 0x0a1ee5ca, 0x19ea911b,},
{ 0x3184648d, 0xb8153e7a, 0x7fc2d2e1,},
{ 0x2d983275, 0x9d5369c4, 0x4dec1661,},
{ 0x2aaaaaaa, 0xaaaaaaaa, 0xaaaaaaab,},
{ 0x28612730, 0x6a6a7a53, 0x810fabde,},
{ 0x268826a1, 0x3ef3fde6, 0x23e2566b,},
{ 0x25001383, 0xbac8a744, 0x385a3349,},
{ 0x23b46706, 0x82c0c709, 0x3f891718,},
{ 0x229729f1, 0xb2c83ded, 0x15fba800,},
{ 0x219e7ffd, 0xa5ad572a, 0xe169744b,},
{ 0x20c33b88, 0xda7c29aa, 0x9bddee52,},
{ 0x20000000, 0x00000000, 0x00000000,},
{ 0x1f50b57e, 0xac5884b3, 0x70e28eee,},
{ 0x1eb22cc6, 0x8aa6e26f, 0x06d1a2a2,},
{ 0x1e21e118, 0x0c5daab1, 0x81b4f4bf,},
{ 0x1d9dcd21, 0x439834e3, 0x81667575,},
{ 0x1d244c78, 0x367a0d64, 0xc8204d6d,},
{ 0x1cb40589, 0xac173e0c, 0x3b7b16ba,},
{ 0x1c4bd95b, 0xa8d72b0d, 0x5879f25a,},
{ 0x1bead768, 0x98f8ce4c, 0x66cc2858,},
{ 0x1b903469, 0x050f72e5, 0x0cf5488e,},
{ 0x1b3b433f, 0x2eb06f14, 0x8c89719c,},
{ 0x1aeb6f75, 0x9c46fc37, 0xab5fc7e9,},
{ 0x1aa038eb, 0x0e3bfd17, 0x1bd62080,},
{ 0x1a593062, 0xb38d8c56, 0x7998ab45,},
{ 0x1a15f4c3, 0x2b95a2e6, 0x46aed6a0,},
{ 0x19d630dc, 0xcc7ddef9, 0x5aadd61b,},
{ 0x19999999, 0x99999999, 0x9999999a,},
{ 0x195fec80, 0x8a609430, 0xe1106014,},
{ 0x1928ee7b, 0x0b4f22f9, 0x5f69791d,},
{ 0x18f46acf, 0x8c06e318, 0x4d2aeb2c,},
{ 0x18c23246, 0xdc0a9f3d, 0x3fe16970,},
#endif
};
3236 
/* log2_radix[radix - 2] = log2(radix) as an unsigned fixed point
   number scaled by 2^(LIMB_BITS-3) (e.g. radix 4 gives exactly
   2 << (LIMB_BITS - 3)). Used by bf_mul_log2_radix() for the
   is_inv = 0 case. */
static const limb_t log2_radix[BF_RADIX_MAX - 1] = {
#if LIMB_BITS == 32
0x20000000,
0x32b80347,
0x40000000,
0x4a4d3c26,
0x52b80347,
0x59d5d9fd,
0x60000000,
0x6570068e,
0x6a4d3c26,
0x6eb3a9f0,
0x72b80347,
0x766a008e,
0x79d5d9fd,
0x7d053f6d,
0x80000000,
0x82cc7edf,
0x8570068e,
0x87ef05ae,
0x8a4d3c26,
0x8c8ddd45,
0x8eb3a9f0,
0x90c10501,
0x92b80347,
0x949a784c,
0x966a008e,
0x982809d6,
0x99d5d9fd,
0x9b74948f,
0x9d053f6d,
0x9e88c6b3,
0xa0000000,
0xa16bad37,
0xa2cc7edf,
0xa4231623,
0xa570068e,
#else
0x2000000000000000,
0x32b803473f7ad0f4,
0x4000000000000000,
0x4a4d3c25e68dc57f,
0x52b803473f7ad0f4,
0x59d5d9fd5010b366,
0x6000000000000000,
0x6570068e7ef5a1e8,
0x6a4d3c25e68dc57f,
0x6eb3a9f01975077f,
0x72b803473f7ad0f4,
0x766a008e4788cbcd,
0x79d5d9fd5010b366,
0x7d053f6d26089673,
0x8000000000000000,
0x82cc7edf592262d0,
0x8570068e7ef5a1e8,
0x87ef05ae409a0289,
0x8a4d3c25e68dc57f,
0x8c8ddd448f8b845a,
0x8eb3a9f01975077f,
0x90c10500d63aa659,
0x92b803473f7ad0f4,
0x949a784bcd1b8afe,
0x966a008e4788cbcd,
0x982809d5be7072dc,
0x99d5d9fd5010b366,
0x9b74948f5532da4b,
0x9d053f6d26089673,
0x9e88c6b3626a72aa,
0xa000000000000000,
0xa16bad3758efd873,
0xa2cc7edf592262d0,
0xa4231623369e78e6,
0xa570068e7ef5a1e8,
#endif
};
3312 
3313 /* compute floor(a*b) or ceil(a*b) with b = log2(radix) or
3314    b=1/log2(radix). For is_inv = 0, strict accuracy is not guaranteed
3315    when radix is not a power of two. */
/* Compute floor(a1*b) (is_ceil1 = 0) or ceil(a1*b) (is_ceil1 = 1)
   with b = log2(radix) (is_inv = 0) or b = 1/log2(radix)
   (is_inv = 1), using the fixed point tables above. For is_inv = 0,
   strict accuracy is not guaranteed when radix is not a power of
   two. */
slimb_t bf_mul_log2_radix(slimb_t a1, unsigned int radix, int is_inv,
                          int is_ceil1)
{
    int is_neg;
    limb_t a;
    BOOL is_ceil;

    is_ceil = is_ceil1;
    a = a1;
    /* work on |a1|; for a negative argument, floor and ceil are
       swapped, hence the XOR below */
    if (a1 < 0) {
        a = -a;
        is_neg = 1;
    } else {
        is_neg = 0;
    }
    is_ceil ^= is_neg;
    if ((radix & (radix - 1)) == 0) {
        int radix_bits;
        /* radix is a power of two: log2(radix) is an exact integer */
        radix_bits = ceil_log2(radix);
        if (is_inv) {
            /* ceiling division by radix_bits when requested */
            if (is_ceil)
                a += radix_bits - 1;
            a = a / radix_bits;
        } else {
            a = a * radix_bits;
        }
    } else {
        const uint32_t *tab;
        limb_t b0, b1;
        dlimb_t t;

        if (is_inv) {
            /* b1:b0 = 1/log2(radix) scaled by 2^(LIMB_BITS-1), with
               LIMB_BITS + 32 significant bits */
            tab = inv_log2_radix[radix - 2];
#if LIMB_BITS == 32
            b1 = tab[0];
            b0 = tab[1];
#else
            b1 = ((limb_t)tab[0] << 32) | tab[1];
            b0 = (limb_t)tab[2] << 32;
#endif
            /* a * (b1:b0), keeping only the integer part */
            t = (dlimb_t)b0 * (dlimb_t)a;
            t = (dlimb_t)b1 * (dlimb_t)a + (t >> LIMB_BITS);
            a = t >> (LIMB_BITS - 1);
        } else {
            /* b0 = log2(radix) scaled by 2^(LIMB_BITS-3) */
            b0 = log2_radix[radix - 2];
            t = (dlimb_t)b0 * (dlimb_t)a;
            a = t >> (LIMB_BITS - 3);
        }
        /* a = floor(result) and 'result' cannot be an integer */
        a += is_ceil;
    }
    if (is_neg)
        a = -a;
    return a;
}
3372 
3373 /* 'n' is the number of output limbs */
/* Recursive divide-and-conquer conversion of the integer 'a' into 'n'
   limbs of base 'radixl', least significant first, stored in 'out'.
   'n0' is the total limb count of the top-level call and 'level' the
   recursion depth; together they determine the split point so that
   pow_tab[2*level] can cache radixl^n2 (and pow_tab[2*level+1] its
   reciprocal) across sibling calls. 'radixl_bits' = ceil_log2(radixl). */
static void bf_integer_to_radix_rec(bf_t *pow_tab,
                                    limb_t *out, const bf_t *a, limb_t n,
                                    int level, limb_t n0, limb_t radixl,
                                    unsigned int radixl_bits)
{
    limb_t n1, n2, q_prec;
    assert(n >= 1);
    if (n == 1) {
        /* base case: the value fits one output limb */
        out[0] = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
    } else if (n == 2) {
        /* base case: use a double-limb hardware division */
        dlimb_t t;
        slimb_t pos;
        pos = a->len * LIMB_BITS - a->expn;
        t = ((dlimb_t)get_bits(a->tab, a->len, pos + LIMB_BITS) << LIMB_BITS) |
            get_bits(a->tab, a->len, pos);
        if (likely(radixl == RADIXL_10)) {
            /* use division by a constant when possible */
            out[0] = t % RADIXL_10;
            out[1] = t / RADIXL_10;
        } else {
            out[0] = t % radixl;
            out[1] = t / radixl;
        }
    } else {
        /* split a = Q * B + R with B = radixl^n2, then recurse on Q
           (high n1 limbs) and R (low n2 limbs) */
        bf_t Q, R, *B, *B_inv;
        int q_add;
        bf_init(a->ctx, &Q);
        bf_init(a->ctx, &R);
        n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
        n1 = n - n2;
        B = &pow_tab[2 * level];
        B_inv = &pow_tab[2 * level + 1];
        if (B->len == 0) {
            /* compute BASE^n2 */
            bf_pow_ui_ui(B, radixl, n2, BF_PREC_INF, BF_RNDZ);
            /* we use enough bits for the maximum possible 'n1' value,
               i.e. n2 + 1 */
            bf_set_ui(&R, 1);
            bf_div(B_inv, &R, B, (n2 + 1) * radixl_bits + 2, BF_RNDN);
        }
        //        printf("%d: n1=% " PRId64 " n2=%" PRId64 "\n", level, n1, n2);
        /* approximate quotient Q = round(a * B_inv), then exact
           remainder R = a - Q * B */
        q_prec = n1 * radixl_bits;
        bf_mul(&Q, a, B_inv, q_prec, BF_RNDN);
        bf_rint(&Q, BF_RNDZ);

        bf_mul(&R, &Q, B, BF_PREC_INF, BF_RNDZ);
        bf_sub(&R, a, &R, BF_PREC_INF, BF_RNDZ);
        /* adjust if necessary: the reciprocal is approximate, so Q may
           be off by a few units; fix it so that 0 <= R < B */
        q_add = 0;
        while (R.sign && R.len != 0) {
            bf_add(&R, &R, B, BF_PREC_INF, BF_RNDZ);
            q_add--;
        }
        while (bf_cmpu(&R, B) >= 0) {
            bf_sub(&R, &R, B, BF_PREC_INF, BF_RNDZ);
            q_add++;
        }
        if (q_add != 0) {
            bf_add_si(&Q, &Q, q_add, BF_PREC_INF, BF_RNDZ);
        }
        bf_integer_to_radix_rec(pow_tab, out + n2, &Q, n1, level + 1, n0,
                                radixl, radixl_bits);
        bf_integer_to_radix_rec(pow_tab, out, &R, n2, level + 1, n0,
                                radixl, radixl_bits);
        bf_delete(&Q);
        bf_delete(&R);
    }
}
3442 
/* Convert the integer 'a' to base 'radixl' limbs stored in r->tab
   (r->len limbs, least significant first). A scratch table of bf_t's
   is allocated so the recursion can cache the powers of 'radixl' and
   their reciprocals. */
static void bf_integer_to_radix(bf_t *r, const bf_t *a, limb_t radixl)
{
    bf_context_t *ctx = r->ctx;
    limb_t n_limbs = r->len;
    int i, n_pow;
    bf_t *pow_tab;

    /* two cached entries (power and reciprocal) per recursion level */
    n_pow = (ceil_log2(n_limbs) + 2) * 2; /* XXX: check */
    pow_tab = bf_malloc(ctx, sizeof(pow_tab[0]) * n_pow);
    for (i = 0; i < n_pow; i++)
        bf_init(ctx, &pow_tab[i]);

    bf_integer_to_radix_rec(pow_tab, r->tab, a, n_limbs, 0, n_limbs, radixl,
                            ceil_log2(radixl));

    for (i = 0; i < n_pow; i++)
        bf_delete(&pow_tab[i]);
    bf_free(ctx, pow_tab);
}
3464 
3465 /* a must be >= 0. 'P' is the wanted number of digits in radix
3466    'radix'. 'r' is the mantissa represented as an integer. *pE
3467    contains the exponent. Return != 0 if memory error. */
/* a must be >= 0. 'P' is the wanted number of digits in radix
   'radix'. 'r' is the mantissa represented as an integer. *pE
   contains the exponent. Return != 0 if memory error. */
static int bf_convert_to_radix(bf_t *r, slimb_t *pE,
                               const bf_t *a, int radix,
                               limb_t P, bf_rnd_t rnd_mode,
                               BOOL is_fixed_exponent)
{
    slimb_t E, e, prec, extra_bits, ziv_extra_bits, prec0;
    bf_t B_s, *B = &B_s;
    int e_sign, ret, res;

    if (a->len == 0) {
        /* zero case */
        *pE = 0;
        return bf_set(r, a);
    }

    if (is_fixed_exponent) {
        E = *pE;
    } else {
        /* compute the new exponent: E ~ 1 + floor(log_radix(a)) */
        E = 1 + bf_mul_log2_radix(a->expn - 1, radix, TRUE, FALSE);
    }
    //    bf_print_str("a", a);
    //    printf("E=%ld P=%ld radix=%d\n", E, P, radix);

    /* outer loop: E may have been underestimated by one; retried with
       E+1 if the rounded mantissa reaches radix^P */
    for(;;) {
        /* compute r = round(a * radix^(P - E)) so that r has exactly P
           digits before rounding */
        e = P - E;
        e_sign = 0;
        if (e < 0) {
            e = -e;
            e_sign = 1;
        }
        /* Note: precision for log2(radix) is not critical here */
        prec0 = bf_mul_log2_radix(P, radix, FALSE, TRUE);
        ziv_extra_bits = 16;
        /* inner loop: Ziv's strategy — retry with more precision until
           the result can be safely rounded to an integer */
        for(;;) {
            prec = prec0 + ziv_extra_bits;
            /* XXX: rigorous error analysis needed */
            extra_bits = ceil_log2(e) * 2 + 1;
            ret = bf_pow_ui_ui(r, radix, e, prec + extra_bits, BF_RNDN);
            if (!e_sign)
                ret |= bf_mul(r, r, a, prec + extra_bits, BF_RNDN);
            else
                ret |= bf_div(r, a, r, prec + extra_bits, BF_RNDN);
            if (ret & BF_ST_MEM_ERROR)
                return BF_ST_MEM_ERROR;
            /* if the result is not exact, check that it can be safely
               rounded to an integer */
            if ((ret & BF_ST_INEXACT) &&
                !bf_can_round(r, r->expn, rnd_mode, prec)) {
                /* add more precision and retry */
                ziv_extra_bits = ziv_extra_bits  + (ziv_extra_bits / 2);
                continue;
            } else {
                ret = bf_rint(r, rnd_mode);
                if (ret & BF_ST_MEM_ERROR)
                    return BF_ST_MEM_ERROR;
                break;
            }
        }
        if (is_fixed_exponent)
            break;
        /* check that the result is < B^P */
        /* XXX: do a fast approximate test first ? */
        bf_init(r->ctx, B);
        ret = bf_pow_ui_ui(B, radix, P, BF_PREC_INF, BF_RNDZ);
        if (ret) {
            bf_delete(B);
            return ret;
        }
        res = bf_cmpu(r, B);
        bf_delete(B);
        if (res < 0)
            break;
        /* try a larger exponent */
        E++;
    }
    *pE = E;
    return 0;
}
3547 
/* Write the 'len' least significant base-'radix' digits of 'n' into
   buf[0..len-1], most significant digit first (no terminating '\0').
   Digits >= 10 use lower-case letters. */
static void limb_to_a(char *buf, limb_t n, unsigned int radix, int len)
{
    int i;

    if (radix == 10) {
        /* specific case with constant divisor (faster division) */
        for (i = len; i-- > 0;) {
            buf[i] = (char)('0' + (int)(n % 10));
            n /= 10;
        }
    } else {
        for (i = len; i-- > 0;) {
            int d = (int)(n % radix);
            n /= radix;
            buf[i] = (char)(d < 10 ? '0' + d : 'a' + d - 10);
        }
    }
}
3571 
3572 /* for power of 2 radixes */
/* Same as limb_to_a() but for power-of-two radixes: emit 'len' groups
   of 'radix_bits' bits of 'n' as digits, most significant first. */
static void limb_to_a2(char *buf, limb_t n, unsigned int radix_bits, int len)
{
    const unsigned int mask = (1 << radix_bits) - 1;
    int i;

    for (i = len; i-- > 0;) {
        int d = (int)(n & mask);
        n >>= radix_bits;
        buf[i] = (char)(d < 10 ? '0' + d : 'a' + d - 10);
    }
}
3589 
3590 /* 'a' must be an integer if the is_dec = FALSE or if the radix is not
3591    a power of two. A dot is added before the 'dot_pos' digit. dot_pos
3592    = n_digits does not display the dot. 0 <= dot_pos <=
3593    n_digits. n_digits >= 1. */
/* 'a' must be an integer if the is_dec = FALSE or if the radix is not
   a power of two. A dot is added before the 'dot_pos' digit. dot_pos
   = n_digits does not display the dot. 0 <= dot_pos <=
   n_digits. n_digits >= 1. */
static void output_digits(DynBuf *s, const bf_t *a1, int radix, limb_t n_digits,
                          limb_t dot_pos, BOOL is_dec)
{
    limb_t i, v, l;
    slimb_t pos, pos_incr;
    int digits_per_limb, buf_pos, radix_bits, first_buf_pos;
    char buf[65];
    bf_t a_s, *a;

    /* select how digits are extracted from the limbs:
       - is_dec: each limb already holds LIMB_DIGITS decimal digits
       - non power-of-two radix: convert first to base radixl limbs
       - power-of-two radix: read groups of radix_bits bits directly */
    if (is_dec) {
        digits_per_limb = LIMB_DIGITS;
        a = (bf_t *)a1;
        radix_bits = 0;
        pos = a->len;
        pos_incr = 1;
        first_buf_pos = 0;
    } else if ((radix & (radix - 1)) != 0) {
        limb_t n, radixl;

        digits_per_limb = digits_per_limb_table[radix - 2];
        radixl = get_limb_radix(radix);
        a = &a_s;
        bf_init(a1->ctx, a);
        n = (n_digits + digits_per_limb - 1) / digits_per_limb;
        bf_resize(a, n);
        bf_integer_to_radix(a, a1, radixl);
        radix_bits = 0;
        pos = n;
        pos_incr = 1;
        /* the top limb may be partially filled: skip its leading
           digits */
        first_buf_pos = pos * digits_per_limb - n_digits;
    } else {
        a = (bf_t *)a1;
        radix_bits = ceil_log2(radix);
        digits_per_limb = LIMB_BITS / radix_bits;
        pos_incr = digits_per_limb * radix_bits;
        /* position of the most significant digit, in bits */
        pos = a->len * LIMB_BITS - a->expn + n_digits * radix_bits;
        first_buf_pos = 0;
    }
    /* emit the digits most significant first, one limb's worth at a
       time via the ASCII buffer 'buf' */
    buf_pos = digits_per_limb;
    i = 0;
    while (i < n_digits) {
        if (buf_pos == digits_per_limb) {
            /* refill 'buf' with the next limb's digits */
            pos -= pos_incr;
            if (radix_bits == 0) {
                v = get_limbz(a, pos);
                limb_to_a(buf, v, radix, digits_per_limb);
            } else {
                v = get_bits(a->tab, a->len, pos);
                limb_to_a2(buf, v, radix_bits, digits_per_limb);
            }
            buf_pos = first_buf_pos;
            first_buf_pos = 0;
        }
        /* copy digits up to the dot position or the end */
        if (i < dot_pos) {
            l = dot_pos;
        } else {
            if (i == dot_pos)
                dbuf_putc(s, '.');
            l = n_digits;
        }
        l = bf_min(digits_per_limb - buf_pos, l - i);
        dbuf_put(s, (uint8_t *)(buf + buf_pos), l);
        buf_pos += l;
        i += l;
    }
    /* free the temporary used for the non power-of-two case */
    if (a != a1)
        bf_delete(a);
}
3662 
bf_dbuf_realloc(void * opaque,void * ptr,size_t size)3663 static void *bf_dbuf_realloc(void *opaque, void *ptr, size_t size)
3664 {
3665     bf_context_t *s = opaque;
3666     return bf_realloc(s, ptr, size);
3667 }
3668 
3669 /* return the length in bytes. A trailing '\0' is added */
bf_ftoa_internal(size_t * plen,const bf_t * a2,int radix,limb_t prec,bf_flags_t flags,BOOL is_dec)3670 static char *bf_ftoa_internal(size_t *plen, const bf_t *a2, int radix,
3671                               limb_t prec, bf_flags_t flags, BOOL is_dec)
3672 {
3673     DynBuf s_s, *s = &s_s;
3674     int radix_bits;
3675 
3676     //    bf_print_str("ftoa", a2);
3677     //    printf("radix=%d\n", radix);
3678     dbuf_init2(s, a2->ctx, bf_dbuf_realloc);
3679     if (a2->expn == BF_EXP_NAN) {
3680         dbuf_putstr(s, "NaN");
3681     } else {
3682         if (a2->sign)
3683             dbuf_putc(s, '-');
3684         if (a2->expn == BF_EXP_INF) {
3685             if (flags & BF_FTOA_JS_QUIRKS)
3686                 dbuf_putstr(s, "Infinity");
3687             else
3688                 dbuf_putstr(s, "Inf");
3689         } else {
3690             int fmt, ret;
3691             slimb_t n_digits, n, i, n_max, n1;
3692             bf_t a1_s, *a1;
3693             bf_t a_s, *a = &a_s;
3694 
3695             /* make a positive number */
3696             a->tab = a2->tab;
3697             a->len = a2->len;
3698             a->expn = a2->expn;
3699             a->sign = 0;
3700 
3701             if ((radix & (radix - 1)) != 0)
3702                 radix_bits = 0;
3703             else
3704                 radix_bits = ceil_log2(radix);
3705 
3706             fmt = flags & BF_FTOA_FORMAT_MASK;
3707             a1 = &a1_s;
3708             bf_init(a2->ctx, a1);
3709             if (fmt == BF_FTOA_FORMAT_FRAC) {
3710                 size_t pos, start;
3711                 assert(!is_dec);
3712                 /* one more digit for the rounding */
3713                 n = 1 + bf_mul_log2_radix(bf_max(a->expn, 0), radix, TRUE, TRUE);
3714                 n_digits = n + prec;
3715                 n1 = n;
3716                 if (bf_convert_to_radix(a1, &n1, a, radix, n_digits,
3717                                         flags & BF_RND_MASK, TRUE))
3718                     goto fail1;
3719                 start = s->size;
3720                 output_digits(s, a1, radix, n_digits, n, is_dec);
3721                 /* remove leading zeros because we allocated one more digit */
3722                 pos = start;
3723                 while ((pos + 1) < s->size && s->buf[pos] == '0' &&
3724                        s->buf[pos + 1] != '.')
3725                     pos++;
3726                 if (pos > start) {
3727                     memmove(s->buf + start, s->buf + pos, s->size - pos);
3728                     s->size -= (pos - start);
3729                 }
3730             } else {
3731 #ifdef USE_BF_DEC
3732                 if (is_dec) {
3733                     if (fmt == BF_FTOA_FORMAT_FIXED) {
3734                         n_digits = prec;
3735                         n_max = n_digits;
3736                     } else {
3737                         /* prec is ignored */
3738                         prec = n_digits = a->len * LIMB_DIGITS;
3739                         /* remove trailing zero digits */
3740                         while (n_digits > 1 &&
3741                                get_digit(a->tab, a->len, prec - n_digits) == 0) {
3742                             n_digits--;
3743                         }
3744                         n_max = n_digits + 4;
3745                     }
3746                     bf_init(a2->ctx, a1);
3747                     bf_set(a1, a);
3748                     n = a1->expn;
3749                 } else
3750 #endif
3751                 {
3752                     if (fmt == BF_FTOA_FORMAT_FIXED) {
3753                         n_digits = prec;
3754                         n_max = n_digits;
3755                     } else {
3756                         slimb_t n_digits_max, n_digits_min;
3757 
3758                         if (prec == BF_PREC_INF) {
3759                             assert(radix_bits != 0);
3760                             /* XXX: could use the exact number of bits */
3761                             prec = a->len * LIMB_BITS;
3762                         }
3763                         n_digits = 1 + bf_mul_log2_radix(prec, radix, TRUE, TRUE);
3764                         /* max number of digits for non exponential
3765                            notation. The rational is to have the same rule
3766                            as JS i.e. n_max = 21 for 64 bit float in base 10. */
3767                         n_max = n_digits + 4;
3768                         if (fmt == BF_FTOA_FORMAT_FREE_MIN) {
3769                             bf_t b_s, *b = &b_s;
3770 
3771                             /* find the minimum number of digits by
3772                                dichotomy. */
3773                             /* XXX: inefficient */
3774                             n_digits_max = n_digits;
3775                             n_digits_min = 1;
3776                             bf_init(a2->ctx, b);
3777                             while (n_digits_min < n_digits_max) {
3778                                 n_digits = (n_digits_min + n_digits_max) / 2;
3779                                 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3780                                                         flags & BF_RND_MASK, FALSE)) {
3781                                     bf_delete(b);
3782                                     goto fail1;
3783                                 }
3784                                 /* convert back to a number and compare */
3785                                 ret = bf_mul_pow_radix(b, a1, radix, n - n_digits,
3786                                                        prec,
3787                                                        (flags & ~BF_RND_MASK) |
3788                                                        BF_RNDN);
3789                                 if (ret & BF_ST_MEM_ERROR) {
3790                                     bf_delete(b);
3791                                     goto fail1;
3792                                 }
3793                                 if (bf_cmpu(b, a) == 0) {
3794                                     n_digits_max = n_digits;
3795                                 } else {
3796                                     n_digits_min = n_digits + 1;
3797                                 }
3798                             }
3799                             bf_delete(b);
3800                             n_digits = n_digits_max;
3801                         }
3802                     }
3803                     if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3804                                             flags & BF_RND_MASK, FALSE)) {
3805                     fail1:
3806                         bf_delete(a1);
3807                         goto fail;
3808                     }
3809                 }
3810                 if (a1->expn == BF_EXP_ZERO &&
3811                     fmt != BF_FTOA_FORMAT_FIXED &&
3812                     !(flags & BF_FTOA_FORCE_EXP)) {
3813                     /* just output zero */
3814                     dbuf_putstr(s, "0");
3815                 } else {
3816                     if (flags & BF_FTOA_ADD_PREFIX) {
3817                         if (radix == 16)
3818                             dbuf_putstr(s, "0x");
3819                         else if (radix == 8)
3820                             dbuf_putstr(s, "0o");
3821                         else if (radix == 2)
3822                             dbuf_putstr(s, "0b");
3823                     }
3824                     if (a1->expn == BF_EXP_ZERO)
3825                         n = 1;
3826                     if ((flags & BF_FTOA_FORCE_EXP) ||
3827                         n <= -6 || n > n_max) {
3828                         const char *fmt;
3829                         /* exponential notation */
3830                         output_digits(s, a1, radix, n_digits, 1, is_dec);
3831                         if (radix_bits != 0 && radix <= 16) {
3832                             if (flags & BF_FTOA_JS_QUIRKS)
3833                                 fmt = "p%+" PRId_LIMB;
3834                             else
3835                                 fmt = "p%" PRId_LIMB;
3836                             dbuf_printf(s, fmt, (n - 1) * radix_bits);
3837                         } else {
3838                             if (flags & BF_FTOA_JS_QUIRKS)
3839                                 fmt = "%c%+" PRId_LIMB;
3840                             else
3841                                 fmt = "%c%" PRId_LIMB;
3842                             dbuf_printf(s, fmt,
3843                                         radix <= 10 ? 'e' : '@', n - 1);
3844                         }
3845                     } else if (n <= 0) {
3846                         /* 0.x */
3847                         dbuf_putstr(s, "0.");
3848                         for(i = 0; i < -n; i++) {
3849                             dbuf_putc(s, '0');
3850                         }
3851                         output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3852                     } else {
3853                         if (n_digits <= n) {
3854                             /* no dot */
3855                             output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3856                             for(i = 0; i < (n - n_digits); i++)
3857                                 dbuf_putc(s, '0');
3858                         } else {
3859                             output_digits(s, a1, radix, n_digits, n, is_dec);
3860                         }
3861                     }
3862                 }
3863             }
3864             bf_delete(a1);
3865         }
3866     }
3867     dbuf_putc(s, '\0');
3868     if (dbuf_error(s))
3869         goto fail;
3870     if (plen)
3871         *plen = s->size - 1;
3872     return (char *)s->buf;
3873  fail:
3874     bf_free(a2->ctx, s->buf);
3875     if (plen)
3876         *plen = 0;
3877     return NULL;
3878 }
3879 
/* Convert 'a' to a string in base 'radix' with 'prec' digits, using
   the rounding mode and formatting options in 'flags'. The returned
   buffer is allocated with the context allocator (see the fail path
   of bf_ftoa_internal, which releases it with bf_free()) and is
   NUL-terminated; on success *plen (if non-NULL) receives the string
   length excluding the terminator. Returns NULL on memory error. */
char *bf_ftoa(size_t *plen, const bf_t *a, int radix, limb_t prec,
              bf_flags_t flags)
{
    return bf_ftoa_internal(plen, a, radix, prec, flags, FALSE);
}
3885 
3886 /***************************************************************/
3887 /* transcendental functions */
3888 
/* Note: the algorithm is from MPFR */
/* Binary splitting recursion for the log(2) series over the half-open
   term range [n1, n2). The caller (bf_const_log2_internal) obtains the
   final value as T / Q. 'need_P' tells whether the product P must be
   fully computed for the parent level; skipping it at the top level
   saves one large multiplication. */
static void bf_const_log2_rec(bf_t *T, bf_t *P, bf_t *Q, limb_t n1,
                              limb_t n2, BOOL need_P)
{
    bf_context_t *s = T->ctx;
    if ((n2 - n1) == 1) {
        /* base case: a single term */
        if (n1 == 0) {
            bf_set_ui(P, 3);
        } else {
            /* negative numerator for the other terms */
            bf_set_ui(P, n1);
            P->sign = 1;
        }
        /* Q = 4*(2*n1+1): adding 2 to the exponent multiplies by 4 */
        bf_set_ui(Q, 2 * n1 + 1);
        Q->expn += 2;
        bf_set(T, P);
    } else {
        /* recursive case: split the range at the midpoint and merge
           the halves: T = T1*Q2 + T2*P1, P = P1*P2, Q = Q1*Q2 */
        limb_t m;
        bf_t T1_s, *T1 = &T1_s;
        bf_t P1_s, *P1 = &P1_s;
        bf_t Q1_s, *Q1 = &Q1_s;

        m = n1 + ((n2 - n1) >> 1);
        bf_const_log2_rec(T, P, Q, n1, m, TRUE);
        bf_init(s, T1);
        bf_init(s, P1);
        bf_init(s, Q1);
        bf_const_log2_rec(T1, P1, Q1, m, n2, need_P);
        /* exact integer arithmetic: infinite precision, truncation */
        bf_mul(T, T, Q1, BF_PREC_INF, BF_RNDZ);
        bf_mul(T1, T1, P, BF_PREC_INF, BF_RNDZ);
        bf_add(T, T, T1, BF_PREC_INF, BF_RNDZ);
        if (need_P)
            bf_mul(P, P, P1, BF_PREC_INF, BF_RNDZ);
        bf_mul(Q, Q, Q1, BF_PREC_INF, BF_RNDZ);
        bf_delete(T1);
        bf_delete(P1);
        bf_delete(Q1);
    }
}
3927 
3928 /* compute log(2) with faithful rounding at precision 'prec' */
bf_const_log2_internal(bf_t * T,limb_t prec)3929 static void bf_const_log2_internal(bf_t *T, limb_t prec)
3930 {
3931     limb_t w, N;
3932     bf_t P_s, *P = &P_s;
3933     bf_t Q_s, *Q = &Q_s;
3934 
3935     w = prec + 15;
3936     N = w / 3 + 1;
3937     bf_init(T->ctx, P);
3938     bf_init(T->ctx, Q);
3939     bf_const_log2_rec(T, P, Q, 0, N, FALSE);
3940     bf_div(T, T, Q, prec, BF_RNDN);
3941     bf_delete(P);
3942     bf_delete(Q);
3943 }
3944 
3945 /* PI constant */
3946 
3947 #define CHUD_A 13591409
3948 #define CHUD_B 545140134
3949 #define CHUD_C 640320
3950 #define CHUD_BITS_PER_TERM 47
3951 
/* Binary splitting of the Chudnovsky series over the half-open term
   range [a, b). On return: G holds the product of the
   (2k-1)*(6k-1)*(6k-5) factors, Q the product of the k^3 * C^3/24
   denominators, and P the partial numerator sum (the sign alternates
   with the term index). G is only fully computed when 'need_g' is
   non-zero, saving one big multiplication per level. */
static void chud_bs(bf_t *P, bf_t *Q, bf_t *G, int64_t a, int64_t b, int need_g,
                    limb_t prec)
{
    bf_context_t *s = P->ctx;
    int64_t c;

    if (a == (b - 1)) {
        /* base case: single term of index b */
        bf_t T0, T1;

        bf_init(s, &T0);
        bf_init(s, &T1);
        /* G = (2b-1)*(6b-1)*(6b-5) */
        bf_set_ui(G, 2 * b - 1);
        bf_mul_ui(G, G, 6 * b - 1, prec, BF_RNDN);
        bf_mul_ui(G, G, 6 * b - 5, prec, BF_RNDN);
        /* T0 = A + B*b */
        bf_set_ui(&T0, CHUD_B);
        bf_mul_ui(&T0, &T0, b, prec, BF_RNDN);
        bf_set_ui(&T1, CHUD_A);
        bf_add(&T0, &T0, &T1, prec, BF_RNDN);
        /* P = G*(A + B*b), with the sign of (-1)^b */
        bf_mul(P, G, &T0, prec, BF_RNDN);
        P->sign = b & 1;

        /* Q = b^3 * C^3/24 */
        bf_set_ui(Q, b);
        bf_mul_ui(Q, Q, b, prec, BF_RNDN);
        bf_mul_ui(Q, Q, b, prec, BF_RNDN);
        bf_mul_ui(Q, Q, (uint64_t)CHUD_C * CHUD_C * CHUD_C / 24, prec, BF_RNDN);
        bf_delete(&T0);
        bf_delete(&T1);
    } else {
        bf_t P2, Q2, G2;

        bf_init(s, &P2);
        bf_init(s, &Q2);
        bf_init(s, &G2);

        /* split at the midpoint; the left half always needs its G
           product for the merge below */
        c = (a + b) / 2;
        chud_bs(P, Q, G, a, c, 1, prec);
        chud_bs(&P2, &Q2, &G2, c, b, need_g, prec);

        /* Q = Q1 * Q2 */
        /* G = G1 * G2 */
        /* P = P1 * Q2 + P2 * G1 */
        bf_mul(&P2, &P2, G, prec, BF_RNDN);
        if (!need_g)
            bf_set_ui(G, 0);
        bf_mul(P, P, &Q2, prec, BF_RNDN);
        bf_add(P, P, &P2, prec, BF_RNDN);
        bf_delete(&P2);

        bf_mul(Q, Q, &Q2, prec, BF_RNDN);
        bf_delete(&Q2);
        if (need_g)
            bf_mul(G, G, &G2, prec, BF_RNDN);
        bf_delete(&G2);
    }
}
4007 
4008 /* compute Pi with faithful rounding at precision 'prec' using the
4009    Chudnovsky formula */
bf_const_pi_internal(bf_t * Q,limb_t prec)4010 static void bf_const_pi_internal(bf_t *Q, limb_t prec)
4011 {
4012     bf_context_t *s = Q->ctx;
4013     int64_t n, prec1;
4014     bf_t P, G;
4015 
4016     /* number of serie terms */
4017     n = prec / CHUD_BITS_PER_TERM + 1;
4018     /* XXX: precision analysis */
4019     prec1 = prec + 32;
4020 
4021     bf_init(s, &P);
4022     bf_init(s, &G);
4023 
4024     chud_bs(&P, Q, &G, 0, n, 0, BF_PREC_INF);
4025 
4026     bf_mul_ui(&G, Q, CHUD_A, prec1, BF_RNDN);
4027     bf_add(&P, &G, &P, prec1, BF_RNDN);
4028     bf_div(Q, Q, &P, prec1, BF_RNDF);
4029 
4030     bf_set_ui(&P, CHUD_C);
4031     bf_sqrt(&G, &P, prec1, BF_RNDF);
4032     bf_mul_ui(&G, &G, (uint64_t)CHUD_C / 12, prec1, BF_RNDF);
4033     bf_mul(Q, Q, &G, prec, BF_RNDN);
4034     bf_delete(&P);
4035     bf_delete(&G);
4036 }
4037 
/* Return in T the cached constant rounded to 'prec' bits with the
   rounding mode in 'flags'. 'func' recomputes the constant (with
   faithful rounding) when the cache 'c' is not precise enough. */
static int bf_const_get(bf_t *T, limb_t prec, bf_flags_t flags,
                        BFConstCache *c,
                        void (*func)(bf_t *res, limb_t prec))
{
    limb_t extra_bits, prec1;

    /* Ziv-style loop: grow the working precision until the cached
       value can be correctly rounded to 'prec' bits */
    extra_bits = 32;
    for(;;) {
        prec1 = prec + extra_bits;
        if (c->prec >= prec1) {
            /* the cached value is already precise enough */
            prec1 = c->prec;
        } else {
            /* (re)compute and cache the constant at prec1 bits;
               lazily initialize the cache slot on first use */
            if (c->val.len == 0)
                bf_init(T->ctx, &c->val);
            func(&c->val, prec1);
            c->prec = prec1;
        }
        bf_set(T, &c->val);
        if (bf_can_round(T, prec, flags & BF_RND_MASK, prec1))
            break;
        /* add more precision and retry */
        extra_bits = extra_bits + (extra_bits / 2);
    }
    return bf_round(T, prec, flags);
}
4065 
/* Free the cached constant value and reset the cache to its initial
   empty state (the memset must come after bf_delete, which reads
   c->val). */
static void bf_const_free(BFConstCache *c)
{
    bf_delete(&c->val);
    memset(c, 0, sizeof(*c));
}
4071 
/* log(2) rounded to 'prec' bits with the rounding mode in 'flags'.
   The value is cached in the context, so repeated calls at the same
   or lower precision are cheap. */
int bf_const_log2(bf_t *T, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = T->ctx;
    return bf_const_get(T, prec, flags, &s->log2_cache, bf_const_log2_internal);
}
4077 
/* Pi rounded to 'prec' bits with the rounding mode in 'flags'. The
   value is cached in the context, so repeated calls at the same or
   lower precision are cheap. */
int bf_const_pi(bf_t *T, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = T->ctx;
    return bf_const_get(T, prec, flags, &s->pi_cache, bf_const_pi_internal);
}
4083 
/* Release all the memory cached in the context: the FFT
   multiplication tables (when FFT multiplication is compiled in) and
   the log(2) and pi constant caches. */
void bf_clear_cache(bf_context_t *s)
{
#ifdef USE_FFT_MUL
    fft_clear_cache(s);
#endif
    bf_const_free(&s->log2_cache);
    bf_const_free(&s->pi_cache);
}
4092 
4093 /* ZivFunc should compute the result 'r' with faithful rounding at
4094    precision 'prec'. For efficiency purposes, the final bf_round()
4095    does not need to be done in the function. */
4096 typedef int ZivFunc(bf_t *r, const bf_t *a, limb_t prec, void *opaque);
4097 
/* Ziv rounding loop: call 'f' (which returns a faithfully rounded
   result at the requested precision) with increasing extra precision
   until the intermediate result can be correctly rounded to 'prec'
   bits in the rounding mode of 'flags'. The final bf_round() is done
   here. */
static int bf_ziv_rounding(bf_t *r, const bf_t *a,
                           limb_t prec, bf_flags_t flags,
                           ZivFunc *f, void *opaque)
{
    int rnd_mode, ret;
    slimb_t prec1, ziv_extra_bits;

    rnd_mode = flags & BF_RND_MASK;
    if (rnd_mode == BF_RNDF) {
        /* faithful rounding requested: one evaluation suffices,
           no need to iterate */
        f(r, a, prec, opaque);
        ret = 0;
    } else {
        ziv_extra_bits = 32;
        for(;;) {
            prec1 = prec + ziv_extra_bits;
            ret = f(r, a, prec1, opaque);
            if (ret & (BF_ST_OVERFLOW | BF_ST_UNDERFLOW | BF_ST_MEM_ERROR)) {
                /* overflow or underflow should never happen because
                   it indicates the rounding cannot be done correctly,
                   but we do not catch all the cases */
                return ret;
            }
            /* if the result is exact, we can stop */
            if (!(ret & BF_ST_INEXACT)) {
                ret = 0;
                break;
            }
            if (bf_can_round(r, prec, rnd_mode, prec1)) {
                ret = BF_ST_INEXACT;
                break;
            }
            /* double the extra precision and retry */
            ziv_extra_bits = ziv_extra_bits * 2;
        }
    }
    return bf_round(r, prec, flags) | ret;
}
4135 
/* Compute the exponential using faithful rounding at precision 'prec'.
   Note: the algorithm is from MPFR */
static int bf_exp_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_context_t *s = r->ctx;
    bf_t T_s, *T = &T_s;
    slimb_t n, K, l, i, prec1;

    assert(r != a);

    /* argument reduction:
       T = a - n*log(2) with 0 <= T < log(2) and n integer.
    */
    bf_init(s, T);
    if (a->expn <= -1) {
        /* 0 <= abs(a) <= 0.5 */
        if (a->sign)
            n = -1;
        else
            n = 0;
    } else {
        /* n = floor(a / log(2)) */
        bf_const_log2(T, LIMB_BITS, BF_RNDZ);
        bf_div(T, a, T, LIMB_BITS, BF_RNDD);
        bf_get_limb(&n, T, 0);
    }

    /* K: number of argument halvings; l: order of the Taylor
       expansion. Both are ~sqrt(prec) to balance the two loops. */
    K = bf_isqrt((prec + 1) / 2);
    l = (prec - 1) / K + 1;
    /* XXX: precision analysis ? */
    prec1 = prec + (K + 2 * l + 18) + K + 8;
    if (a->expn > 0)
        prec1 += a->expn;
    //    printf("n=%ld K=%ld prec1=%ld\n", n, K, prec1);

    bf_const_log2(T, prec1, BF_RNDF);
    bf_mul_si(T, T, n, prec1, BF_RNDN);
    bf_sub(T, a, T, prec1, BF_RNDN);

    /* reduce the range of T: T = T / 2^K */
    bf_mul_2exp(T, -K, BF_PREC_INF, BF_RNDZ);

    /* Taylor expansion around zero :
     1 + x + x^2/2 + ... + x^n/n!
     = (1 + x * (1 + x/2 * (1 + ... (x/n))))
    */
    {
        bf_t U_s, *U = &U_s;

        bf_init(s, U);
        bf_set_ui(r, 1);
        /* Horner-style evaluation from the innermost factor outwards */
        for(i = l ; i >= 1; i--) {
            bf_set_ui(U, i);
            bf_div(U, T, U, prec1, BF_RNDN);
            bf_mul(r, r, U, prec1, BF_RNDN);
            bf_add_si(r, r, 1, prec1, BF_RNDN);
        }
        bf_delete(U);
    }
    bf_delete(T);

    /* undo the range reduction: exp(x) = exp(x/2^K)^(2^K) */
    for(i = 0; i < K; i++) {
        bf_mul(r, r, r, prec1, BF_RNDN);
    }

    /* undo the argument reduction: multiply by 2^n */
    bf_mul_2exp(r, n, BF_PREC_INF, BF_RNDZ);

    /* the final rounding is left to the Ziv loop in the caller */
    return BF_ST_INEXACT;
}
4206 
/* Exponential of 'a': handle the special values, do crude overflow
   and underflow checks, then run the Ziv rounding loop on
   bf_exp_internal(). 'r' must be distinct from 'a'. */
int bf_exp(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    assert(r != a);
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);              /* exp(NaN) = NaN */
        } else if (a->expn == BF_EXP_INF) {
            if (a->sign)
                bf_set_zero(r, 0);      /* exp(-inf) = +0 */
            else
                bf_set_inf(r, 0);       /* exp(+inf) = +inf */
        } else {
            bf_set_ui(r, 1);            /* exp(+/-0) = 1, exact */
        }
        return 0;
    }

    /* crude overflow and underflow tests */
    if (a->expn > 0) {
        bf_t T_s, *T = &T_s;
        bf_t log2_s, *log2 = &log2_s;
        slimb_t e_min, e_max;
        e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
        e_min = -e_max + 3;
        if (flags & BF_FLAG_SUBNORMAL)
            e_min -= (prec - 1);

        bf_init(s, T);
        bf_init(s, log2);
        bf_const_log2(log2, LIMB_BITS, BF_RNDU);
        bf_mul_ui(T, log2, e_max, LIMB_BITS, BF_RNDU);
        /* a > e_max * log(2) implies exp(a) > 2^e_max */
        /* note: use bf_cmp_lt() as a plain boolean, consistent with
           the underflow test below (was 'bf_cmp_lt(T, a) > 0') */
        if (bf_cmp_lt(T, a)) {
            /* overflow */
            bf_delete(T);
            bf_delete(log2);
            return bf_set_overflow(r, 0, prec, flags);
        }
        /* a < e_min * log(2) implies exp(a) < 2^e_min */
        bf_mul_si(T, log2, e_min, LIMB_BITS, BF_RNDD);
        if (bf_cmp_lt(a, T)) {
            int rnd_mode = flags & BF_RND_MASK;

            /* underflow */
            bf_delete(T);
            bf_delete(log2);
            if (rnd_mode == BF_RNDU) {
                /* set the smallest value */
                bf_set_ui(r, 1);
                r->expn = e_min;
            } else {
                bf_set_zero(r, 0);
            }
            return BF_ST_UNDERFLOW | BF_ST_INEXACT;
        }
        bf_delete(log2);
        bf_delete(T);
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_exp_internal, NULL);
}
4269 
/* Compute log(a) with faithful rounding at precision 'prec' (a finite
   and > 0, already checked by bf_log()). Uses two argument
   reductions followed by an atanh-style Taylor series. */
static int bf_log_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_context_t *s = r->ctx;
    bf_t T_s, *T = &T_s;
    bf_t U_s, *U = &U_s;
    bf_t V_s, *V = &V_s;
    slimb_t n, prec1, l, i, K;

    assert(r != a);

    bf_init(s, T);
    /* argument reduction 1 */
    /* T=a*2^n with 2/3 <= T <= 4/3 */
    {
        /* note: this U deliberately shadows the outer one; it only
           holds the ~2/3 threshold */
        bf_t U_s, *U = &U_s;
        bf_set(T, a);
        n = T->expn;
        T->expn = 0;
        /* U= ~ 2/3 */
        bf_init(s, U);
        bf_set_ui(U, 0xaaaaaaaa);
        U->expn = 0;
        if (bf_cmp_lt(T, U)) {
            /* T < 2/3: double T and compensate in the exponent */
            T->expn++;
            n--;
        }
        bf_delete(U);
    }
    //    printf("n=%ld\n", n);
    //    bf_print_str("T", T);

    /* XXX: precision analysis */
    /* number of iterations for argument reduction 2 */
    K = bf_isqrt((prec + 1) / 2);
    /* order of Taylor expansion */
    l = prec / (2 * K) + 1;
    /* precision of the intermediate computations */
    prec1 = prec + K + 2 * l + 32;

    bf_init(s, U);
    bf_init(s, V);

    /* Note: cancellation occurs here, so we use more precision (XXX:
       reduce the precision by computing the exact cancellation) */
    bf_add_si(T, T, -1, BF_PREC_INF, BF_RNDN);

    /* argument reduction 2 */
    for(i = 0; i < K; i++) {
        /* T = T / (1 + sqrt(1 + T)) */
        bf_add_si(U, T, 1, prec1, BF_RNDN);
        bf_sqrt(V, U, prec1, BF_RNDF);
        bf_add_si(U, V, 1, prec1, BF_RNDN);
        bf_div(T, T, U, prec1, BF_RNDN);
    }

    {
        bf_t Y_s, *Y = &Y_s;
        bf_t Y2_s, *Y2 = &Y2_s;
        bf_init(s, Y);
        bf_init(s, Y2);

        /* compute ln(1+x) = ln((1+y)/(1-y)) with y=x/(2+x)
           = y + y^3/3 + ... + y^(2*l + 1) / (2*l+1)
           with Y=Y^2
           = y*(1+Y/3+Y^2/5+...) = y*(1+Y*(1/3+Y*(1/5 + ...)))
        */
        bf_add_si(Y, T, 2, prec1, BF_RNDN);
        bf_div(Y, T, Y, prec1, BF_RNDN);

        bf_mul(Y2, Y, Y, prec1, BF_RNDN);
        bf_set_ui(r, 0);
        /* Horner evaluation of the bracketed series in Y2 */
        for(i = l; i >= 1; i--) {
            bf_set_ui(U, 1);
            bf_set_ui(V, 2 * i + 1);
            bf_div(U, U, V, prec1, BF_RNDN);
            bf_add(r, r, U, prec1, BF_RNDN);
            bf_mul(r, r, Y2, prec1, BF_RNDN);
        }
        bf_add_si(r, r, 1, prec1, BF_RNDN);
        bf_mul(r, r, Y, prec1, BF_RNDN);
        bf_delete(Y);
        bf_delete(Y2);
    }
    bf_delete(V);
    bf_delete(U);

    /* multiplication by 2 for the Taylor expansion and undo the
       argument reduction 2*/
    bf_mul_2exp(r, K + 1, BF_PREC_INF, BF_RNDZ);

    /* undo the argument reduction 1 */
    bf_const_log2(T, prec1, BF_RNDF);
    bf_mul_si(T, T, n, prec1, BF_RNDN);
    bf_add(r, r, T, prec1, BF_RNDN);

    bf_delete(T);
    /* the final rounding is left to the Ziv loop in the caller */
    return BF_ST_INEXACT;
}
4368 
/* Natural logarithm of 'a' rounded to 'prec' bits. 'r' must be
   distinct from 'a'. Returns a status mask (BF_ST_INVALID_OP for
   negative arguments and -inf). */
int bf_log(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t one_s, *one = &one_s;
    int is_one;

    assert(r != a);

    /* special values */
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            /* log(NaN) = NaN */
            bf_set_nan(r);
            return 0;
        }
        if (a->expn == BF_EXP_INF) {
            if (a->sign) {
                /* log(-inf) is invalid */
                bf_set_nan(r);
                return BF_ST_INVALID_OP;
            }
            /* log(+inf) = +inf */
            bf_set_inf(r, 0);
            return 0;
        }
        /* log(+/-0) = -inf */
        bf_set_inf(r, 1);
        return 0;
    }
    if (a->sign) {
        /* log of a negative number is invalid */
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    }

    /* log(1) is exactly zero */
    bf_init(s, one);
    bf_set_ui(one, 1);
    is_one = bf_cmp_eq(a, one);
    bf_delete(one);
    if (is_one) {
        bf_set_zero(r, 0);
        return 0;
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_log_internal, NULL);
}
4407 
4408 /* x and y finite and x > 0 */
4409 /* XXX: overflow/underflow handling */
bf_pow_generic(bf_t * r,const bf_t * x,limb_t prec,void * opaque)4410 static int bf_pow_generic(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4411 {
4412     bf_context_t *s = r->ctx;
4413     const bf_t *y = opaque;
4414     bf_t T_s, *T = &T_s;
4415     limb_t prec1;
4416 
4417     bf_init(s, T);
4418     /* XXX: proof for the added precision */
4419     prec1 = prec + 32;
4420     bf_log(T, x, prec1, BF_RNDF);
4421     bf_mul(T, T, y, prec1, BF_RNDF);
4422     bf_exp(r, T, prec1, BF_RNDF);
4423     bf_delete(T);
4424     return BF_ST_INEXACT;
4425 }
4426 
4427 /* x and y finite, x > 0, y integer and y fits on one limb */
4428 /* XXX: overflow/underflow handling */
bf_pow_int(bf_t * r,const bf_t * x,limb_t prec,void * opaque)4429 static int bf_pow_int(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4430 {
4431     bf_context_t *s = r->ctx;
4432     const bf_t *y = opaque;
4433     bf_t T_s, *T = &T_s;
4434     limb_t prec1;
4435     int ret;
4436     slimb_t y1;
4437 
4438     bf_get_limb(&y1, y, 0);
4439     if (y1 < 0)
4440         y1 = -y1;
4441     /* XXX: proof for the added precision */
4442     prec1 = prec + ceil_log2(y1) * 2 + 8;
4443     ret = bf_pow_ui(r, x, y1 < 0 ? -y1 : y1, prec1, BF_RNDN);
4444     if (y->sign) {
4445         bf_init(s, T);
4446         bf_set_ui(T, 1);
4447         ret |= bf_div(r, T, r, prec1, BF_RNDN);
4448         bf_delete(T);
4449     }
4450     return ret;
4451 }
4452 
4453 /* x must be a finite non zero float. Return TRUE if there is a
4454    floating point number r such as x=r^(2^n) and return this floating
4455    point number 'r'. Otherwise return FALSE and r is undefined. */
check_exact_power2n(bf_t * r,const bf_t * x,slimb_t n)4456 static BOOL check_exact_power2n(bf_t *r, const bf_t *x, slimb_t n)
4457 {
4458     bf_context_t *s = r->ctx;
4459     bf_t T_s, *T = &T_s;
4460     slimb_t e, i, er;
4461     limb_t v;
4462 
4463     /* x = m*2^e with m odd integer */
4464     e = bf_get_exp_min(x);
4465     /* fast check on the exponent */
4466     if (n > (LIMB_BITS - 1)) {
4467         if (e != 0)
4468             return FALSE;
4469         er = 0;
4470     } else {
4471         if ((e & (((limb_t)1 << n) - 1)) != 0)
4472             return FALSE;
4473         er = e >> n;
4474     }
4475     /* every perfect odd square = 1 modulo 8 */
4476     v = get_bits(x->tab, x->len, x->len * LIMB_BITS - x->expn + e);
4477     if ((v & 7) != 1)
4478         return FALSE;
4479 
4480     bf_init(s, T);
4481     bf_set(T, x);
4482     T->expn -= e;
4483     for(i = 0; i < n; i++) {
4484         if (i != 0)
4485             bf_set(T, r);
4486         if (bf_sqrtrem(r, NULL, T) != 0)
4487             return FALSE;
4488     }
4489     r->expn += er;
4490     return TRUE;
4491 }
4492 
4493 /* prec = BF_PREC_INF is accepted for x and y integers and y >= 0 */
bf_pow(bf_t * r,const bf_t * x,const bf_t * y,limb_t prec,bf_flags_t flags)4494 int bf_pow(bf_t *r, const bf_t *x, const bf_t *y, limb_t prec, bf_flags_t flags)
4495 {
4496     bf_context_t *s = r->ctx;
4497     bf_t T_s, *T = &T_s;
4498     bf_t ytmp_s;
4499     BOOL y_is_int, y_is_odd;
4500     int r_sign, ret, rnd_mode;
4501     slimb_t y_emin;
4502 
4503     if (x->len == 0 || y->len == 0) {
4504         if (y->expn == BF_EXP_ZERO) {
4505             /* pow(x, 0) = 1 */
4506             bf_set_ui(r, 1);
4507         } else if (x->expn == BF_EXP_NAN) {
4508             bf_set_nan(r);
4509         } else {
4510             int cmp_x_abs_1;
4511             bf_set_ui(r, 1);
4512             cmp_x_abs_1 = bf_cmpu(x, r);
4513             if (cmp_x_abs_1 == 0 && (flags & BF_POW_JS_QUICKS) &&
4514                 (y->expn >= BF_EXP_INF)) {
4515                 bf_set_nan(r);
4516             } else if (cmp_x_abs_1 == 0 &&
4517                        (!x->sign || y->expn != BF_EXP_NAN)) {
4518                 /* pow(1, y) = 1 even if y = NaN */
4519                 /* pow(-1, +/-inf) = 1 */
4520             } else if (y->expn == BF_EXP_NAN) {
4521                 bf_set_nan(r);
4522             } else if (y->expn == BF_EXP_INF) {
4523                 if (y->sign == (cmp_x_abs_1 > 0)) {
4524                     bf_set_zero(r, 0);
4525                 } else {
4526                     bf_set_inf(r, 0);
4527                 }
4528             } else {
4529                 y_emin = bf_get_exp_min(y);
4530                 y_is_odd = (y_emin == 0);
4531                 if (y->sign == (x->expn == BF_EXP_ZERO)) {
4532                     bf_set_inf(r, y_is_odd & x->sign);
4533                     if (y->sign) {
4534                         /* pow(0, y) with y < 0 */
4535                         return BF_ST_DIVIDE_ZERO;
4536                     }
4537                 } else {
4538                     bf_set_zero(r, y_is_odd & x->sign);
4539                 }
4540             }
4541         }
4542         return 0;
4543     }
4544     bf_init(s, T);
4545     bf_set(T, x);
4546     y_emin = bf_get_exp_min(y);
4547     y_is_int = (y_emin >= 0);
4548     rnd_mode = flags & BF_RND_MASK;
4549     if (x->sign) {
4550         if (!y_is_int) {
4551             bf_set_nan(r);
4552             bf_delete(T);
4553             return BF_ST_INVALID_OP;
4554         }
4555         y_is_odd = (y_emin == 0);
4556         r_sign = y_is_odd;
4557         /* change the directed rounding mode if the sign of the result
4558            is changed */
4559         if (r_sign && (rnd_mode == BF_RNDD || rnd_mode == BF_RNDU))
4560             flags ^= 1;
4561         bf_neg(T);
4562     } else {
4563         r_sign = 0;
4564     }
4565 
4566     bf_set_ui(r, 1);
4567     if (bf_cmp_eq(T, r)) {
4568         /* abs(x) = 1: nothing more to do */
4569         ret = 0;
4570     } else if (y_is_int) {
4571         slimb_t T_bits, e;
4572     int_pow:
4573         T_bits = T->expn - bf_get_exp_min(T);
4574         if (T_bits == 1) {
4575             /* pow(2^b, y) = 2^(b*y) */
4576             bf_mul_si(T, y, T->expn - 1, LIMB_BITS, BF_RNDZ);
4577             bf_get_limb(&e, T, 0);
4578             bf_set_ui(r, 1);
4579             ret = bf_mul_2exp(r, e, prec, flags);
4580         } else if (prec == BF_PREC_INF) {
4581             slimb_t y1;
4582             /* specific case for infinite precision (integer case) */
4583             bf_get_limb(&y1, y, 0);
4584             assert(!y->sign);
4585             /* x must be an integer, so abs(x) >= 2 */
4586             if (y1 >= ((slimb_t)1 << BF_EXP_BITS_MAX)) {
4587                 bf_delete(T);
4588                 return bf_set_overflow(r, 0, BF_PREC_INF, flags);
4589             }
4590             ret = bf_pow_ui(r, T, y1, BF_PREC_INF, BF_RNDZ);
4591         } else {
4592             if (y->expn <= 31) {
4593                 /* small enough power: use exponentiation in all cases */
4594             } else if (y->sign) {
4595                 /* cannot be exact */
4596                 goto general_case;
4597             } else {
4598                 if (rnd_mode == BF_RNDF)
4599                     goto general_case; /* no need to track exact results */
4600                 /* see if the result has a chance to be exact:
4601                    if x=a*2^b (a odd), x^y=a^y*2^(b*y)
4602                    x^y needs a precision of at least floor_log2(a)*y bits
4603                 */
4604                 bf_mul_si(r, y, T_bits - 1, LIMB_BITS, BF_RNDZ);
4605                 bf_get_limb(&e, r, 0);
4606                 if (prec < e)
4607                     goto general_case;
4608             }
4609             ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_int, (void *)y);
4610         }
4611     } else {
4612         if (rnd_mode != BF_RNDF) {
4613             bf_t *y1;
4614             if (y_emin < 0 && check_exact_power2n(r, T, -y_emin)) {
4615                 /* the problem is reduced to a power to an integer */
4616 #if 0
4617                 printf("\nn=%ld\n", -y_emin);
4618                 bf_print_str("T", T);
4619                 bf_print_str("r", r);
4620 #endif
4621                 bf_set(T, r);
4622                 y1 = &ytmp_s;
4623                 y1->tab = y->tab;
4624                 y1->len = y->len;
4625                 y1->sign = y->sign;
4626                 y1->expn = y->expn - y_emin;
4627                 y = y1;
4628                 goto int_pow;
4629             }
4630         }
4631     general_case:
4632         ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_generic, (void *)y);
4633     }
4634     bf_delete(T);
4635     r->sign = r_sign;
4636     return ret;
4637 }
4638 
4639 /* compute sqrt(-2*x-x^2) to get |sin(x)| from cos(x) - 1. */
bf_sqrt_sin(bf_t * r,const bf_t * x,limb_t prec1)4640 static void bf_sqrt_sin(bf_t *r, const bf_t *x, limb_t prec1)
4641 {
4642     bf_context_t *s = r->ctx;
4643     bf_t T_s, *T = &T_s;
4644     bf_init(s, T);
4645     bf_set(T, x);
4646     bf_mul(r, T, T, prec1, BF_RNDN);
4647     bf_mul_2exp(T, 1, BF_PREC_INF, BF_RNDZ);
4648     bf_add(T, T, r, prec1, BF_RNDN);
4649     bf_neg(T);
4650     bf_sqrt(r, T, prec1, BF_RNDF);
4651     bf_delete(T);
4652 }
4653 
/* Compute sin(a) into 's' and/or cos(a) into 'c' (either output
   pointer may be NULL) with approximately 'prec' bits (not correctly
   rounded: used as the worker inside Ziv's rounding loop). Returns 0
   for NaN/zero inputs, BF_ST_INVALID_OP for +/-inf (outputs set to
   NaN) and BF_ST_INEXACT otherwise. 's' and 'c' must not alias 'a'. */
int bf_sincos(bf_t *s, bf_t *c, const bf_t *a, limb_t prec)
{
    bf_context_t *s1 = a->ctx;
    bf_t T_s, *T = &T_s;
    bf_t U_s, *U = &U_s;
    bf_t r_s, *r = &r_s;
    slimb_t K, prec1, i, l, mod, prec2;
    int is_neg;

    assert(c != a && s != a);
    /* special values: NaN, +/-inf and zero */
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            if (c)
                bf_set_nan(c);
            if (s)
                bf_set_nan(s);
            return 0;
        } else if (a->expn == BF_EXP_INF) {
            /* sin/cos of an infinity is an invalid operation */
            if (c)
                bf_set_nan(c);
            if (s)
                bf_set_nan(s);
            return BF_ST_INVALID_OP;
        } else {
            /* a = +/-0: cos = 1, sin = +/-0 */
            if (c)
                bf_set_ui(c, 1);
            if (s)
                bf_set_zero(s, a->sign);
            return 0;
        }
    }

    bf_init(s1, T);
    bf_init(s1, U);
    bf_init(s1, r);

    /* K argument halvings, l Taylor terms */
    /* XXX: precision analysis */
    K = bf_isqrt(prec / 2);
    l = prec / (2 * K) + 1;
    prec1 = prec + 2 * K + l + 8;

    /* after the modulo reduction, -pi/4 <= T <= pi/4 */
    if (a->expn <= -1) {
        /* abs(a) <= 0.25: no modulo reduction needed */
        bf_set(T, a);
        mod = 0;
    } else {
        slimb_t cancel;
        cancel = 0;
        /* T = a mod pi/2; retry with more bits of pi while the
           cancellation in the reduction eats too much precision */
        for(;;) {
            prec2 = prec1 + cancel;
            bf_const_pi(U, prec2, BF_RNDF);
            bf_mul_2exp(U, -1, BF_PREC_INF, BF_RNDZ);
            bf_remquo(&mod, T, a, U, prec2, BF_RNDN);
            //            printf("T.expn=%ld prec2=%ld\n", T->expn, prec2);
            if (mod == 0 || (T->expn != BF_EXP_ZERO &&
                             (T->expn + prec2) >= (prec1 - 1)))
                break;
            /* increase the number of bits until the precision is good enough */
            cancel = bf_max(-T->expn, (cancel + 1) * 3 / 2);
        }
        /* keep only the quadrant of the original argument */
        mod &= 3;
    }

    is_neg = T->sign;

    /* compute cosm1(x) = cos(x) - 1 on the reduced argument */
    /* T = (T / 2^K)^2 */
    bf_mul(T, T, T, prec1, BF_RNDN);
    bf_mul_2exp(T, -2 * K, BF_PREC_INF, BF_RNDZ);

    /* Taylor expansion, evaluated in nested (Horner-like) form on x^2:
       -x^2/2 + x^4/4! - x^6/6! + ...
    */
    bf_set_ui(r, 1);
    for(i = l ; i >= 1; i--) {
        /* U = x^2 / ((2*i-1) * (2*i)) */
        bf_set_ui(U, 2 * i - 1);
        bf_mul_ui(U, U, 2 * i, BF_PREC_INF, BF_RNDZ);
        bf_div(U, T, U, prec1, BF_RNDN);
        bf_mul(r, r, U, prec1, BF_RNDN);
        bf_neg(r);
        if (i != 1)
            bf_add_si(r, r, 1, prec1, BF_RNDN);
    }
    bf_delete(U);

    /* undo argument reduction (K doublings):
       cosm1(2*x)= 2*(2*cosm1(x)+cosm1(x)^2)
    */
    for(i = 0; i < K; i++) {
        bf_mul(T, r, r, prec1, BF_RNDN);
        bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
        bf_add(r, r, T, prec1, BF_RNDN);
        bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
    }
    bf_delete(T);

    /* r now holds cosm1 of the reduced argument; rebuild sin/cos in
       the correct quadrant: bit 0 of mod swaps sin/cos, bit 1 flips
       the sign */
    if (c) {
        if ((mod & 1) == 0) {
            bf_add_si(c, r, 1, prec1, BF_RNDN);
        } else {
            bf_sqrt_sin(c, r, prec1);
            c->sign = is_neg ^ 1;
        }
        c->sign ^= mod >> 1;
    }
    if (s) {
        if ((mod & 1) == 0) {
            bf_sqrt_sin(s, r, prec1);
            s->sign = is_neg;
        } else {
            bf_add_si(s, r, 1, prec1, BF_RNDN);
        }
        s->sign ^= mod >> 1;
    }
    bf_delete(r);
    return BF_ST_INEXACT;
}
4771 
/* bf_ziv_rounding() callback: cosine only (sine output disabled). */
static int bf_cos_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    int ret;
    (void)opaque;
    ret = bf_sincos(NULL, r, a, prec);
    return ret;
}
4776 
/* r = cos(a), rounded to 'prec' bits according to 'flags'. */
int bf_cos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    int ret;
    ret = bf_ziv_rounding(r, a, prec, flags, bf_cos_internal, NULL);
    return ret;
}
4781 
/* bf_ziv_rounding() callback: sine only (cosine output disabled). */
static int bf_sin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    int ret;
    (void)opaque;
    ret = bf_sincos(r, NULL, a, prec);
    return ret;
}
4786 
/* r = sin(a), rounded to 'prec' bits according to 'flags'. */
int bf_sin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    int ret;
    ret = bf_ziv_rounding(r, a, prec, flags, bf_sin_internal, NULL);
    return ret;
}
4791 
/* bf_ziv_rounding() callback: r = tan(a), computed as sin(a)/cos(a). */
static int bf_tan_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_context_t *s = r->ctx;
    bf_t T_s, *T = &T_s;
    limb_t prec1;

    /* special values: NaN, +/-inf and zero */
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF) {
            /* tan of an infinity is an invalid operation */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            /* tan(+/-0) = +/-0 */
            bf_set_zero(r, a->sign);
            return 0;
        }
    }

    /* XXX: precision analysis */
    prec1 = prec + 8;
    bf_init(s, T);
    /* r = sin(a), T = cos(a) */
    bf_sincos(r, T, a, prec1);
    bf_div(r, r, T, prec1, BF_RNDF);
    bf_delete(T);
    return BF_ST_INEXACT;
}
4819 
/* r = tan(a), rounded to 'prec' bits according to 'flags'. */
int bf_tan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    int ret;
    ret = bf_ziv_rounding(r, a, prec, flags, bf_tan_internal, NULL);
    return ret;
}
4824 
4825 /* if add_pi2 is true, add pi/2 to the result (used for acos(x) to
4826    avoid cancellation) */
/* bf_ziv_rounding() callback: r = atan(a), plus pi/2 when 'opaque' is
   TRUE (used by acos to avoid cancellation). The reduction
   atan(x) = 2*atan(x/(1 + sqrt(1 + x^2))) is applied K times, then a
   Taylor series of l terms is evaluated. */
static int bf_atan_internal(bf_t *r, const bf_t *a, limb_t prec,
                            void *opaque)
{
    bf_context_t *s = r->ctx;
    BOOL add_pi2 = (BOOL)(intptr_t)opaque;
    bf_t T_s, *T = &T_s;
    bf_t U_s, *U = &U_s;
    bf_t V_s, *V = &V_s;
    bf_t X2_s, *X2 = &X2_s;
    int cmp_1;
    slimb_t prec1, i, K, l;

    /* special values: NaN, +/-inf and zero */
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else {
            /* atan(+/-inf) = +/-pi/2, atan(+/-0) = +/-0 */
            if (a->expn == BF_EXP_INF)
                i = 1 - 2 * a->sign;
            else
                i = 0;
            i += add_pi2;
            /* return i*(pi/2) with -1 <= i <= 2 */
            if (i == 0) {
                bf_set_zero(r, add_pi2 ? 0 : a->sign);
                return 0;
            } else {
                /* PI or PI/2 */
                bf_const_pi(r, prec, BF_RNDF);
                if (i != 2)
                    bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
                r->sign = (i < 0);
                return BF_ST_INEXACT;
            }
        }
    }

    bf_init(s, T);
    bf_set_ui(T, 1);
    cmp_1 = bf_cmpu(a, T);
    if (cmp_1 == 0 && !add_pi2) {
        /* short cut: abs(a) == 1 -> +/-pi/4 */
        bf_const_pi(r, prec, BF_RNDF);
        bf_mul_2exp(r, -2, BF_PREC_INF, BF_RNDZ);
        r->sign = a->sign;
        bf_delete(T);
        return BF_ST_INEXACT;
    }

    /* K reduction steps, l Taylor terms */
    /* XXX: precision analysis */
    K = bf_isqrt((prec + 1) / 2);
    l = prec / (2 * K) + 1;
    prec1 = prec + K + 2 * l + 32;
    //    printf("prec=%ld K=%ld l=%ld prec1=%ld\n", prec, K, l, prec1);

    /* invert arguments larger than 1: atan(x) = sign(x)*pi/2 - atan(1/x) */
    if (cmp_1 > 0) {
        bf_set_ui(T, 1);
        bf_div(T, T, a, prec1, BF_RNDN);
    } else {
        bf_set(T, a);
    }

    /* abs(T) <= 1 */

    /* argument reduction */

    bf_init(s, U);
    bf_init(s, V);
    bf_init(s, X2);
    for(i = 0; i < K; i++) {
        /* T = T / (1 + sqrt(1 + T^2)) */
        bf_mul(U, T, T, prec1, BF_RNDN);
        bf_add_si(U, U, 1, prec1, BF_RNDN);
        bf_sqrt(V, U, prec1, BF_RNDN);
        bf_add_si(V, V, 1, prec1, BF_RNDN);
        bf_div(T, T, V, prec1, BF_RNDN);
    }

    /* Taylor series, evaluated in nested form on X2 = T^2:
       x - x^3/3 + ... + (-1)^ l * y^(2*l + 1) / (2*l+1)
    */
    bf_mul(X2, T, T, prec1, BF_RNDN);
    bf_set_ui(r, 0);
    for(i = l; i >= 1; i--) {
        bf_set_si(U, 1);
        bf_set_ui(V, 2 * i + 1);
        bf_div(U, U, V, prec1, BF_RNDN);
        bf_neg(r);
        bf_add(r, r, U, prec1, BF_RNDN);
        bf_mul(r, r, X2, prec1, BF_RNDN);
    }
    bf_neg(r);
    bf_add_si(r, r, 1, prec1, BF_RNDN);
    bf_mul(r, r, T, prec1, BF_RNDN);

    /* undo the argument reduction: each step halved the angle */
    bf_mul_2exp(r, K, BF_PREC_INF, BF_RNDZ);

    bf_delete(U);
    bf_delete(V);
    bf_delete(X2);

    /* i = number of pi/2 to add at the end, -1 <= i <= 2 */
    i = add_pi2;
    if (cmp_1 > 0) {
        /* undo the inversion : r = sign(a)*PI/2 - r */
        bf_neg(r);
        i += 1 - 2 * a->sign;
    }
    /* add i*(pi/2) with -1 <= i <= 2 */
    if (i != 0) {
        bf_const_pi(T, prec1, BF_RNDF);
        if (i != 2)
            bf_mul_2exp(T, -1, BF_PREC_INF, BF_RNDZ);
        T->sign = (i < 0);
        bf_add(r, T, r, prec1, BF_RNDN);
    }

    bf_delete(T);
    return BF_ST_INEXACT;
}
4947 
/* r = atan(a), rounded to 'prec' bits according to 'flags'. */
int bf_atan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    /* add_pi2 = FALSE: plain arc tangent */
    return bf_ziv_rounding(r, a, prec, flags, bf_atan_internal,
                           (void *)(intptr_t)FALSE);
}
4952 
/* bf_ziv_rounding() callback for bf_atan2(): r = atan2(y, x), with
   'x' passed through the opaque pointer. */
static int bf_atan2_internal(bf_t *r, const bf_t *y, limb_t prec, void *opaque)
{
    bf_context_t *s = r->ctx;
    const bf_t *x = opaque;
    bf_t T_s, *T = &T_s;
    limb_t prec1;
    int ret;

    if (y->expn == BF_EXP_NAN || x->expn == BF_EXP_NAN) {
        bf_set_nan(r);
        return 0;
    }

    /* compute atan(y/x) assuming inf/inf = 1 and 0/0 = 0 */
    bf_init(s, T);
    prec1 = prec + 32;
    if (y->expn == BF_EXP_INF && x->expn == BF_EXP_INF) {
        bf_set_ui(T, 1);
        T->sign = y->sign ^ x->sign;
    } else if (y->expn == BF_EXP_ZERO && x->expn == BF_EXP_ZERO) {
        bf_set_zero(T, y->sign ^ x->sign);
    } else {
        bf_div(T, y, x, prec1, BF_RNDF);
    }
    ret = bf_atan(r, T, prec1, BF_RNDF);

    if (x->sign) {
        /* if x < 0 (it includes -0), return sign(y)*pi + atan(y/x) */
        bf_const_pi(T, prec1, BF_RNDF);
        T->sign = y->sign;
        bf_add(r, r, T, prec1, BF_RNDN);
        ret |= BF_ST_INEXACT;
    }

    bf_delete(T);
    return ret;
}
4990 
/* r = atan2(y, x), rounded to 'prec' bits according to 'flags'. */
int bf_atan2(bf_t *r, const bf_t *y, const bf_t *x,
             limb_t prec, bf_flags_t flags)
{
    void *opaque = (void *)x; /* second operand travels to the callback */
    return bf_ziv_rounding(r, y, prec, flags, bf_atan2_internal, opaque);
}
4996 
/* bf_ziv_rounding() callback shared by asin (opaque = FALSE) and acos
   (opaque = TRUE):
     asin(x) = atan(x / sqrt(1 - x^2))
     acos(x) = pi/2 - asin(x)  (the pi/2 is folded into atan_internal) */
static int bf_asin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_context_t *s = r->ctx;
    BOOL is_acos = (BOOL)(intptr_t)opaque;
    bf_t T_s, *T = &T_s;
    limb_t prec1, prec2;
    int res;

    /* special values: NaN, +/-inf and zero */
    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF) {
            /* asin/acos of an infinity is an invalid operation */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            /* asin(+/-0) = +/-0, acos(0) = pi/2 */
            if (is_acos) {
                bf_const_pi(r, prec, BF_RNDF);
                bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
                return BF_ST_INEXACT;
            } else {
                bf_set_zero(r, a->sign);
                return 0;
            }
        }
    }
    bf_init(s, T);
    bf_set_ui(T, 1);
    res = bf_cmpu(a, T);
    if (res > 0) {
        /* abs(a) > 1: domain error */
        bf_delete(T);
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    } else if (res == 0 && a->sign == 0 && is_acos) {
        /* acos(1) = 0 exactly */
        bf_set_zero(r, 0);
        bf_delete(T);
        return 0;
    }

    /* asin(x) = atan(x/sqrt(1-x^2))
       acos(x) = pi/2 - asin(x) */
    prec1 = prec + 8;
    /* increase the precision in x^2 to compensate the cancellation in
       (1-x^2) if x is close to 1 */
    /* XXX: use less precision when possible */
    if (a->expn >= 0)
        prec2 = BF_PREC_INF;
    else
        prec2 = prec1;
    /* T = 1 - a^2 */
    bf_mul(T, a, a, prec2, BF_RNDN);
    bf_neg(T);
    bf_add_si(T, T, 1, prec2, BF_RNDN);

    /* T = a / sqrt(1 - a^2), negated for acos (atan(-x) = -atan(x)) */
    bf_sqrt(r, T, prec1, BF_RNDN);
    bf_div(T, a, r, prec1, BF_RNDN);
    if (is_acos)
        bf_neg(T);
    bf_atan_internal(r, T, prec1, (void *)(intptr_t)is_acos);
    bf_delete(T);
    return BF_ST_INEXACT;
}
5058 
/* r = asin(a), rounded to 'prec' bits according to 'flags'. */
int bf_asin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    /* is_acos = FALSE */
    return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal,
                           (void *)(intptr_t)FALSE);
}
5063 
/* r = acos(a), rounded to 'prec' bits according to 'flags'. */
int bf_acos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    /* is_acos = TRUE */
    return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal,
                           (void *)(intptr_t)TRUE);
}
5068 
5069 /***************************************************************/
5070 /* decimal floating point numbers */
5071 
5072 #ifdef USE_BF_DEC
5073 
/* double-limb helpers for the decimal routines: (r1:r0) and (a1:a0)
   denote two-limb quantities, high limb first. */

/* (r1:r0) += (a1:a0), carry detected via unsigned wrap-around */
#define adddq(r1, r0, a1, a0)                   \
    do {                                        \
        limb_t __t = r0;                        \
        r0 += (a0);                             \
        r1 += (a1) + (r0 < __t);                \
    } while (0)

/* (r1:r0) -= (a1:a0), borrow detected via unsigned wrap-around */
#define subdq(r1, r0, a1, a0)                   \
    do {                                        \
        limb_t __t = r0;                        \
        r0 -= (a0);                             \
        r1 -= (a1) + (r0 > __t);                \
    } while (0)

#if LIMB_BITS == 64

/* Note: we assume __int128 is available */
/* (r1:r0) = a * b (full 64x64 -> 128 bit product) */
#define muldq(r1, r0, a, b)                     \
    do {                                        \
        unsigned __int128 __t;                          \
        __t = (unsigned __int128)(a) * (unsigned __int128)(b);  \
        r0 = __t;                               \
        r1 = __t >> 64;                         \
    } while (0)

/* q = (a1:a0) / b, r = (a1:a0) % b */
#define divdq(q, r, a1, a0, b)                  \
    do {                                        \
        unsigned __int128 __t;                  \
        limb_t __b = (b);                       \
        __t = ((unsigned __int128)(a1) << 64) | (a0);   \
        q = __t / __b;                                  \
        r = __t % __b;                                  \
    } while (0)

#else

/* 32-bit limbs: the double-limb fits in a uint64_t */
#define muldq(r1, r0, a, b)                     \
    do {                                        \
        uint64_t __t;                          \
        __t = (uint64_t)(a) * (uint64_t)(b);  \
        r0 = __t;                               \
        r1 = __t >> 32;                         \
    } while (0)

#define divdq(q, r, a1, a0, b)                  \
    do {                                        \
        uint64_t __t;                  \
        limb_t __b = (b);                       \
        __t = ((uint64_t)(a1) << 32) | (a0);   \
        q = __t / __b;                                  \
        r = __t % __b;                                  \
    } while (0)

#endif /* LIMB_BITS != 64 */

#if 0 //unused
/* return the low limb of (high:low) >> shift, 0 <= shift < LIMB_BITS */
static inline limb_t shrd(limb_t low, limb_t high, long shift)
{
    if (shift != 0)
        low = (low >> shift) | (high << (LIMB_BITS - shift));
    return low;
}
#endif
5137 
/* Return the high limb of the double-limb (a1:a0) shifted left by
   'shift' bits; 'shift' must be in [0, LIMB_BITS - 1]. The zero case
   is special-cased because a0 >> LIMB_BITS would be undefined. */
static inline limb_t shld(limb_t a1, limb_t a0, long shift)
{
    if (shift == 0)
        return a1;
    return (a1 << shift) | (a0 >> (LIMB_BITS - shift));
}
5145 
#if LIMB_DIGITS == 19

/* WARNING: hardcoded for b = 1e19. It is assumed that:
   0 <= a1 < 2^63 */
/* q = (a1:a0) / 1e19 and r = (a1:a0) % 1e19, using a multiplication
   by the fixed-point reciprocal 17014118346046923173
   = floor(2^127 / 1e19) followed by a small correction. */
#define divdq_base(q, r, a1, a0)\
do {\
    uint64_t __a0, __a1, __t0, __t1, __b = BF_DEC_BASE; \
    __a0 = a0;\
    __a1 = a1;\
    __t0 = __a1;\
    __t0 = shld(__t0, __a0, 1);\
    muldq(q, __t1, __t0, UINT64_C(17014118346046923173)); \
    muldq(__t1, __t0, q, __b);\
    subdq(__a1, __a0, __t1, __t0);\
    subdq(__a1, __a0, 1, __b * 2);    \
    __t0 = (slimb_t)__a1 >> 1; \
    q += 2 + __t0;\
    adddq(__a1, __a0, 0, __b & __t0);\
    q += __a1;                  \
    __a0 += __b & __a1;           \
    r = __a0;\
} while(0)

#elif LIMB_DIGITS == 9

/* WARNING: hardcoded for b = 1e9. It is assumed that:
   0 <= a1 < 2^29 */
/* same idea with the reciprocal 2305843009 = floor(2^61 / 1e9) and a
   single conditional correction step */
#define divdq_base(q, r, a1, a0)\
do {\
    uint32_t __t0, __t1, __b = BF_DEC_BASE; \
    __t0 = a1;\
    __t1 = a0;\
    __t0 = (__t0 << 3) | (__t1 >> (32 - 3));    \
    muldq(q, __t1, __t0, 2305843009U);\
    r = a0 - q * __b;\
    __t1 = (r >= __b);\
    q += __t1;\
    if (__t1)\
        r -= __b;\
} while(0)

#endif
5188 
5189 /* fast integer division by a fixed constant */
5190 
/* precomputed state for dividing by an invariant divisor with
   fast_udiv()/fast_udivrem(); filled by fast_udiv_init() */
typedef struct FastDivData {
    limb_t d; /* divisor (only user visible field) */
    limb_t m1; /* multiplier */
    int shift1; /* pre-shift applied to (a - hi), 0 or 1 */
    int shift2; /* final shift */
} FastDivData;
5197 
5198 #if 1
/* From "Division by Invariant Integers using Multiplication" by
   Torbjorn Granlund and Peter L. Montgomery */
5201 /* d must be != 0 */
static inline void fast_udiv_init(FastDivData *s, limb_t d)
{
    limb_t quot, rem;
    int ceil_log2;

    s->d = d;
    /* ceil(log2(d)); NOTE(review): hard-coded 64/clz64 — assumed
       consistent with the limb width in use, confirm for 32-bit limbs */
    if (d == 1)
        ceil_log2 = 0;
    else
        ceil_log2 = 64 - clz64(d - 1);
    /* m1 = floor((2^l - d) * 2^64 / d) + 1 */
    divdq(quot, rem, ((limb_t)1 << ceil_log2) - d, 0, d);
    (void)rem;
    s->m1 = quot + 1;
    s->shift1 = (ceil_log2 > 1) ? 1 : ceil_log2;     /* min(l, 1) */
    s->shift2 = (ceil_log2 > 0) ? ceil_log2 - 1 : 0; /* max(l - 1, 0) */
}
5223 
/* return a / s->d using the multiplier precomputed by
   fast_udiv_init(): one high multiplication, one add and two shifts
   instead of a hardware divide (Granlund-Montgomery scheme). */
static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
{
    limb_t t0, t1;
    /* t1 = floor(m1 * a / 2^LIMB_BITS) */
    muldq(t1, t0, s->m1, a);
    t0 = (a - t1) >> s->shift1;
    return (t1 + t0) >> s->shift2;
}
5231 
#if 0 //unused
static inline limb_t fast_urem(limb_t a, const FastDivData *s)
{
    limb_t q;
    q = fast_udiv(a, s);
    return a - q * s->d;
}
#endif

/* q = a / (s)->d and r = a % (s)->d in one step; 'a' must have no
   side effects (it is evaluated twice) */
#define fast_udivrem(q, r, a, s) q = fast_udiv(a, s), r = a - q * (s)->d
5242 
5243 #else
5244 
/* reference implementation using plain hardware division (selected
   when the '#if 1' above is toggled off) */
static inline void fast_udiv_init(FastDivData *s, limb_t d)
{
    s->d = d;
}
static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
{
    return a / s->d;
}
static inline limb_t fast_urem(limb_t a, const FastDivData *s)
{
    return a % s->d;
}

#define fast_udivrem(q, r, a, s) q = a / (s)->d, r = a % (s)->d
5259 
5260 #endif
5261 
/* mp_pow_dec[i] contains 10^i for 0 <= i <= LIMB_DIGITS (filled by
   mp_pow_init()) */
/* XXX: make it const */
limb_t mp_pow_dec[LIMB_DIGITS + 1];
/* fast-division state for dividing by 10^i (filled by mp_pow_init()) */
static FastDivData mp_pow_div[LIMB_DIGITS + 1];
5266 
mp_pow_init(void)5267 static void mp_pow_init(void)
5268 {
5269     limb_t a;
5270     int i;
5271 
5272     a = 1;
5273     for(i = 0; i <= LIMB_DIGITS; i++) {
5274         mp_pow_dec[i] = a;
5275         fast_udiv_init(&mp_pow_div[i], a);
5276         a = a * 10;
5277     }
5278 }
5279 
/* res[] = op1[] + op2[] + carry, limbs in base BF_DEC_BASE.
   Returns the outgoing carry (0 or 1). */
limb_t mp_add_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
                  mp_size_t n, limb_t carry)
{
    limb_t base = BF_DEC_BASE;
    limb_t cy = carry;
    mp_size_t i;

    for (i = 0; i < n; i++) {
        /* XXX: reuse the trick in add_mod */
        limb_t u = op1[i];
        limb_t sum = u + op2[i] + cy - base;
        /* unsigned wrap-around: sum <= u iff the true sum reached base */
        cy = (sum <= u);
        if (!cy)
            sum += base;
        res[i] = sum;
    }
    return cy;
}
5299 
/* tab[] += b in base BF_DEC_BASE (b presumably < BF_DEC_BASE);
   propagation stops as soon as the carry dies. Returns the final
   carry. */
limb_t mp_add_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
{
    limb_t base = BF_DEC_BASE;
    limb_t cy = b;
    mp_size_t i;

    for (i = 0; i < n; i++) {
        limb_t u = tab[i];
        limb_t sum = u + cy - base;
        /* unsigned wrap-around: sum <= u iff the true sum reached base */
        cy = (sum <= u);
        if (!cy)
            sum += base;
        tab[i] = sum;
        if (cy == 0)
            break;
    }
    return cy;
}
5319 
/* res[] = op1[] - op2[] - carry, limbs in base BF_DEC_BASE.
   Returns the outgoing borrow (0 or 1). */
limb_t mp_sub_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
                  mp_size_t n, limb_t carry)
{
    limb_t base = BF_DEC_BASE;
    limb_t borrow = carry;
    mp_size_t i;

    for (i = 0; i < n; i++) {
        limb_t u = op1[i];
        limb_t diff = u - op2[i] - borrow;
        /* unsigned wrap-around means we went below zero */
        borrow = (diff > u);
        if (borrow)
            diff += base;
        res[i] = diff;
    }
    return borrow;
}
5338 
/* tab[] -= b in base BF_DEC_BASE (b presumably < BF_DEC_BASE);
   propagation stops as soon as the borrow dies. Returns the final
   borrow. */
limb_t mp_sub_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
{
    limb_t base = BF_DEC_BASE;
    limb_t borrow = b;
    mp_size_t i;

    for (i = 0; i < n; i++) {
        limb_t u = tab[i];
        limb_t diff = u - borrow;
        /* unsigned wrap-around means we went below zero */
        borrow = (diff > u);
        if (borrow)
            diff += base;
        tab[i] = diff;
        if (borrow == 0)
            break;
    }
    return borrow;
}
5358 
5359 /* taba[] = taba[] * b + l. 0 <= b, l <= base - 1. Return the high carry */
limb_t mp_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
                   limb_t b, limb_t l)
{
    mp_size_t i;
    limb_t t0, t1, r;

    for(i = 0; i < n; i++) {
        /* (t1:t0) = taba[i] * b + incoming carry l */
        muldq(t1, t0, taba[i], b);
        adddq(t1, t0, 0, l);
        /* split the double-limb into next carry (l) and digit (r) */
        divdq_base(l, r, t1, t0);
        tabr[i] = r;
    }
    return l;
}
5374 
5375 /* tabr[] += taba[] * b. 0 <= b <= base - 1. Return the value to add
5376    to the high word */
limb_t mp_add_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
                       limb_t b)
{
    mp_size_t i;
    limb_t l, t0, t1, r;

    l = 0;
    for(i = 0; i < n; i++) {
        /* (t1:t0) = taba[i] * b + carry + tabr[i] */
        muldq(t1, t0, taba[i], b);
        adddq(t1, t0, 0, l);
        adddq(t1, t0, 0, tabr[i]);
        /* split the double-limb into next carry (l) and digit (r) */
        divdq_base(l, r, t1, t0);
        tabr[i] = r;
    }
    return l;
}
5393 
/* tabr[] -= taba[] * b. 0 <= b <= base - 1. Return the value to
   subtract from the high word. */
limb_t mp_sub_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
                       limb_t b)
{
    limb_t base = BF_DEC_BASE;
    mp_size_t i;
    limb_t l, t0, t1, r, a, v, c;

    /* XXX: optimize */
    l = 0;
    for(i = 0; i < n; i++) {
        /* (t1:t0) = taba[i] * b + previous borrow l */
        muldq(t1, t0, taba[i], b);
        adddq(t1, t0, 0, l);
        divdq_base(l, r, t1, t0);
        /* tabr[i] -= r, with borrow detection via unsigned wrap */
        v = tabr[i];
        a = v - r;
        c = a > v;
        if (c)
            a += base;
        /* never bigger than base because r = 0 when l = base - 1 */
        l += c;
        tabr[i] = a;
    }
    return l;
}
5420 
5421 /* size of the result : op1_size + op2_size. */
/* Schoolbook decimal multiplication: result[] = op1[] * op2[].
   'result' must provide op1_size + op2_size limbs and must not
   overlap the operands. */
void mp_mul_basecase_dec(limb_t *result,
                         const limb_t *op1, mp_size_t op1_size,
                         const limb_t *op2, mp_size_t op2_size)
{
    mp_size_t j;

    /* the first partial product initializes the low limbs */
    result[op1_size] = mp_mul1_dec(result, op1, op1_size, op2[0], 0);
    /* accumulate the remaining shifted partial products */
    for (j = 1; j < op2_size; j++)
        result[j + op1_size] = mp_add_mul1_dec(result + j, op1,
                                               op1_size, op2[j]);
}
5436 
5437 /* taba[] = (taba[] + r*base^na) / b. 0 <= b < base. 0 <= r <
5438    b. Return the remainder. */
limb_t mp_div1_dec(limb_t *tabr, const limb_t *taba, mp_size_t na,
                   limb_t b, limb_t r)
{
    limb_t base = BF_DEC_BASE;
    mp_size_t i;
    limb_t t0, t1, q;
    int shift;

#if (BF_DEC_BASE % 2) == 0
    /* fast path for division by 2: halve each limb and carry
       base/2 down when the limb was odd */
    if (b == 2) {
        limb_t base_div2;
        /* Note: only works if base is even */
        base_div2 = base >> 1;
        if (r)
            r = base_div2;
        for(i = na - 1; i >= 0; i--) {
            t0 = taba[i];
            tabr[i] = (t0 >> 1) + r;
            r = 0;
            if (t0 & 1)
                r = base_div2;
        }
        if (r)
            r = 1;
    } else
#endif
    /* for long operands, amortize a reciprocal of b over the loop */
    if (na >= UDIV1NORM_THRESHOLD) {
        shift = clz(b);
        if (shift == 0) {
            /* normalized case: b >= 2^(LIMB_BITS-1) */
            limb_t b_inv;
            b_inv = udiv1norm_init(b);
            for(i = na - 1; i >= 0; i--) {
                /* (t1:t0) = r * base + taba[i] */
                muldq(t1, t0, r, base);
                adddq(t1, t0, 0, taba[i]);
                q = udiv1norm(&r, t1, t0, b, b_inv);
                tabr[i] = q;
            }
        } else {
            /* pre-shift b (and each partial dividend) so that the
               normalized reciprocal can be used */
            limb_t b_inv;
            b <<= shift;
            b_inv = udiv1norm_init(b);
            for(i = na - 1; i >= 0; i--) {
                muldq(t1, t0, r, base);
                adddq(t1, t0, 0, taba[i]);
                t1 = (t1 << shift) | (t0 >> (LIMB_BITS - shift));
                t0 <<= shift;
                q = udiv1norm(&r, t1, t0, b, b_inv);
                r >>= shift;
                tabr[i] = q;
            }
        }
    } else {
        /* short operands: plain double-limb division */
        for(i = na - 1; i >= 0; i--) {
            muldq(t1, t0, r, base);
            adddq(t1, t0, 0, taba[i]);
            divdq(q, r, t1, t0, b);
            tabr[i] = q;
        }
    }
    return r;
}
5501 
mp_print_str_dec(const char * str,const limb_t * tab,slimb_t n)5502 static __maybe_unused void mp_print_str_dec(const char *str,
5503                                        const limb_t *tab, slimb_t n)
5504 {
5505     slimb_t i;
5506     printf("%s=", str);
5507     for(i = n - 1; i >= 0; i--) {
5508         if (i != n - 1)
5509             printf("_");
5510         printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5511     }
5512     printf("\n");
5513 }
5514 
mp_print_str_h_dec(const char * str,const limb_t * tab,slimb_t n,limb_t high)5515 static __maybe_unused void mp_print_str_h_dec(const char *str,
5516                                               const limb_t *tab, slimb_t n,
5517                                               limb_t high)
5518 {
5519     slimb_t i;
5520     printf("%s=", str);
5521     printf("%0*" PRIu_LIMB, LIMB_DIGITS, high);
5522     for(i = n - 1; i >= 0; i--) {
5523         printf("_");
5524         printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5525     }
5526     printf("\n");
5527 }
5528 
5529 //#define DEBUG_DIV_SLOW
5530 
5531 #define DIV_STATIC_ALLOC_LEN 16
5532 
5533 /* return q = a / b and r = a % b.
5534 
5535    taba[na] must be allocated if tabb1[nb - 1] < B / 2.  tabb1[nb - 1]
5536    must be != zero. na must be >= nb. 's' can be NULL if tabb1[nb - 1]
5537    >= B / 2.
5538 
   The remainder is returned in taba and contains nb limbs. tabq
   contains na - nb + 1 limbs. No overlap is permitted.
5541 
5542    Running time of the standard method: (na - nb + 1) * nb
5543    Return 0 if OK, -1 if memory alloc error
5544 */
5545 /* XXX: optimize */
static int mp_div_dec(bf_context_t *s, limb_t *tabq,
                      limb_t *taba, mp_size_t na,
                      const limb_t *tabb1, mp_size_t nb)
{
    limb_t base = BF_DEC_BASE;
    limb_t r, mult, t0, t1, a, c, q, v, *tabb;
    mp_size_t i, j;
    limb_t static_tabb[DIV_STATIC_ALLOC_LEN];

#ifdef DEBUG_DIV_SLOW
    mp_print_str_dec("a", taba, na);
    mp_print_str_dec("b", tabb1, nb);
#endif

    /* normalize tabb */
    r = tabb1[nb - 1];
    assert(r != 0);
    i = na - nb;
    if (r >= BF_DEC_BASE / 2) {
        /* divisor already normalized: handle the top quotient digit
           (0 or 1) by direct limb-wise comparison */
        mult = 1;
        tabb = (limb_t *)tabb1;
        q = 1;
        for(j = nb - 1; j >= 0; j--) {
            if (taba[i + j] != tabb[j]) {
                if (taba[i + j] < tabb[j])
                    q = 0;
                break;
            }
        }
        tabq[i] = q;
        if (q) {
            mp_sub_dec(taba + i, taba + i, tabb, nb, 0);
        }
        i--;
    } else {
        /* scale both operands so that the top divisor limb becomes
           >= base/2 (needed for a good quotient digit estimate) */
        mult = base / (r + 1);
        if (likely(nb <= DIV_STATIC_ALLOC_LEN)) {
            tabb = static_tabb;
        } else {
            tabb = bf_malloc(s, sizeof(limb_t) * nb);
            if (!tabb)
                return -1;
        }
        mp_mul1_dec(tabb, tabb1, nb, mult, 0);
        taba[na] = mp_mul1_dec(taba, taba, na, mult, 0);
    }

#ifdef DEBUG_DIV_SLOW
    printf("mult=" FMT_LIMB "\n", mult);
    mp_print_str_dec("a_norm", taba, na + 1);
    mp_print_str_dec("b_norm", tabb, nb);
#endif

    for(; i >= 0; i--) {
        /* estimate the quotient digit from the two top limbs of the
           partial remainder and the top divisor limb */
        if (unlikely(taba[i + nb] >= tabb[nb - 1])) {
            /* XXX: check if it is really possible */
            q = base - 1;
        } else {
            muldq(t1, t0, taba[i + nb], base);
            adddq(t1, t0, 0, taba[i + nb - 1]);
            divdq(q, r, t1, t0, tabb[nb - 1]);
        }
        //        printf("i=%d q1=%ld\n", i, q);

        /* subtract q * divisor from the partial remainder */
        r = mp_sub_mul1_dec(taba + i, tabb, nb, q);
        //        mp_dump("r1", taba + i, nb, bd);
        //        printf("r2=%ld\n", r);

        v = taba[i + nb];
        a = v - r;
        c = a > v;
        if (c)
            a += base;
        taba[i + nb] = a;

        if (c != 0) {
            /* negative result: the estimate was too large, add the
               divisor back until the remainder becomes non-negative */
            for(;;) {
                q--;
                c = mp_add_dec(taba + i, taba + i, tabb, nb, 0);
                /* propagate carry and test if positive result */
                if (c != 0) {
                    if (++taba[i + nb] == base) {
                        break;
                    }
                }
            }
        }
        tabq[i] = q;
    }

#ifdef DEBUG_DIV_SLOW
    mp_print_str_dec("q", tabq, na - nb + 1);
    mp_print_str_dec("r", taba, nb);
#endif

    /* remove the normalization */
    if (mult != 1) {
        mp_div1_dec(taba, taba, nb, mult, 0);
        if (unlikely(tabb != static_tabb))
            bf_free(s, tabb);
    }
    return 0;
}
5650 
5651 /* divide by 10^shift */
static limb_t mp_shr_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
                         limb_t shift, limb_t high)
{
    limb_t carry, q, r;
    mp_size_t i;

    assert(shift >= 1 && shift < LIMB_DIGITS);
    carry = high; /* limb entering from above */
    for (i = n - 1; i >= 0; i--) {
        /* split tab[i]: q = high digits kept, r = 'shift' low digits
           passed down to the next limb */
        fast_udivrem(q, r, tab[i], &mp_pow_div[shift]);
        tab_r[i] = q + carry * mp_pow_dec[LIMB_DIGITS - shift];
        carry = r;
    }
    /* digits shifted out at the bottom */
    return carry;
}
5668 
5669 /* multiply by 10^shift */
mp_shl_dec(limb_t * tab_r,const limb_t * tab,mp_size_t n,limb_t shift,limb_t low)5670 static limb_t mp_shl_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5671                          limb_t shift, limb_t low)
5672 {
5673     mp_size_t i;
5674     limb_t l, a, q, r;
5675 
5676     assert(shift >= 1 && shift < LIMB_DIGITS);
5677     l = low;
5678     for(i = 0; i < n; i++) {
5679         a = tab[i];
5680         fast_udivrem(q, r, a, &mp_pow_div[LIMB_DIGITS - shift]);
5681         tab_r[i] = r * mp_pow_dec[shift] + l;
5682         l = q;
5683     }
5684     return l;
5685 }
5686 
/* Square root of a 2 limb decimal number (base BF_DEC_BASE, least
   significant limb first). The root s = floor(sqrt(a)) is stored in
   tabs[0]; the low limb of the remainder a - s^2 is stored in taba[0]
   and its high limb is returned. */
static limb_t mp_sqrtrem2_dec(limb_t *tabs, limb_t *taba)
{
    int k;
    dlimb_t a, b, r;
    limb_t taba1[2], s, r0, r1;

    /* convert to binary and normalize */
    a = (dlimb_t)taba[1] * BF_DEC_BASE + taba[0];
    /* shift by an even amount so the binary root is shifted by exactly k/2 */
    k = clz(a >> LIMB_BITS) & ~1;
    b = a << k;
    taba1[0] = b;
    taba1[1] = b >> LIMB_BITS;
    mp_sqrtrem2(&s, taba1);
    s >>= (k >> 1);
    /* convert the remainder back to decimal */
    r = a - (dlimb_t)s * (dlimb_t)s;
    divdq_base(r1, r0, r >> LIMB_BITS, r);
    taba[0] = r0;
    tabs[0] = s;
    return r1;
}
5708 
5709 //#define DEBUG_SQRTREM_DEC
5710 
/* Recursive square root on 2*n decimal limbs: the root of the high
   half is computed first, then refined by dividing the middle part by
   that partial root. tabs receives the n limb root; the remainder is
   left in the low limbs of taba and its high limb is returned.
   tmp_buf must contain (n / 2 + 1 limbs) */
static limb_t mp_sqrtrem_rec_dec(limb_t *tabs, limb_t *taba, limb_t n,
                                 limb_t *tmp_buf)
{
    limb_t l, h, rh, ql, qh, c, i;

    /* base case: 2 limb operand */
    if (n == 1)
        return mp_sqrtrem2_dec(tabs, taba);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("a", taba, 2 * n);
#endif
    l = n / 2;
    h = n - l;
    /* root of the upper 2*h limbs; qh is the remainder's high limb */
    qh = mp_sqrtrem_rec_dec(tabs + l, taba + 2 * l, h, tmp_buf);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("s1", tabs + l, h);
    mp_print_str_h_dec("r1", taba + 2 * l, h, qh);
    mp_print_str_h_dec("r2", taba + l, n, qh);
#endif

    /* the remainder is in taba + 2 * l. Its high bit is in qh */
    if (qh) {
        mp_sub_dec(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
    }
    /* instead of dividing by 2*s, divide by s (which is normalized)
       and update q and r */
    mp_div_dec(NULL, tmp_buf, taba + l, n, tabs + l, h);
    qh += tmp_buf[l];
    for(i = 0; i < l; i++)
        tabs[i] = tmp_buf[i];
    /* halve the quotient; ql is the shifted-out half digit */
    ql = mp_div1_dec(tabs, tabs, l, 2, qh & 1);
    qh = qh >> 1; /* 0 or 1 */
    if (ql)
        rh = mp_add_dec(taba + l, taba + l, tabs + l, h, 0);
    else
        rh = 0;
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_h_dec("q", tabs, l, qh);
    mp_print_str_h_dec("u", taba + l, h, rh);
#endif

    mp_add_ui_dec(tabs + l, qh, h);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("s2", tabs, n);
#endif

    /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
    /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
    if (qh) {
        c = qh;
    } else {
        mp_mul_basecase_dec(taba + n, tabs, l, tabs, l);
        c = mp_sub_dec(taba, taba, taba + n, 2 * l, 0);
    }
    rh -= mp_sub_ui_dec(taba + 2 * l, c, n - 2 * l);
    /* the estimated root may be one too large: fix the root and the
       remainder (r += 2*s + 1 after s -= 1) */
    if ((slimb_t)rh < 0) {
        mp_sub_ui_dec(tabs, 1, n);
        rh += mp_add_mul1_dec(taba, tabs, n, 2);
        rh += mp_add_ui_dec(taba, 1, n);
    }
    return rh;
}
5773 
5774 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= B/4. Return (s,
5775    r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2 * s. tabs has n
5776    limbs. r is returned in the lower n limbs of taba. Its r[n] is the
5777    returned value of the function. */
mp_sqrtrem_dec(bf_context_t * s,limb_t * tabs,limb_t * taba,limb_t n)5778 int mp_sqrtrem_dec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
5779 {
5780     limb_t tmp_buf1[8];
5781     limb_t *tmp_buf;
5782     mp_size_t n2;
5783     n2 = n / 2 + 1;
5784     if (n2 <= countof(tmp_buf1)) {
5785         tmp_buf = tmp_buf1;
5786     } else {
5787         tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
5788         if (!tmp_buf)
5789             return -1;
5790     }
5791     taba[n] = mp_sqrtrem_rec_dec(tabs, taba, n, tmp_buf);
5792     if (tmp_buf != tmp_buf1)
5793         bf_free(s, tmp_buf);
5794     return 0;
5795 }
5796 
5797 /* return the number of leading zero digits, from 0 to LIMB_DIGITS */
clz_dec(limb_t a)5798 static int clz_dec(limb_t a)
5799 {
5800     if (a == 0)
5801         return LIMB_DIGITS;
5802     switch(LIMB_BITS - 1 - clz(a)) {
5803     case 0: /* 1-1 */
5804         return LIMB_DIGITS - 1;
5805     case 1: /* 2-3 */
5806         return LIMB_DIGITS - 1;
5807     case 2: /* 4-7 */
5808         return LIMB_DIGITS - 1;
5809     case 3: /* 8-15 */
5810         if (a < 10)
5811             return LIMB_DIGITS - 1;
5812         else
5813             return LIMB_DIGITS - 2;
5814     case 4: /* 16-31 */
5815         return LIMB_DIGITS - 2;
5816     case 5: /* 32-63 */
5817         return LIMB_DIGITS - 2;
5818     case 6: /* 64-127 */
5819         if (a < 100)
5820             return LIMB_DIGITS - 2;
5821         else
5822             return LIMB_DIGITS - 3;
5823     case 7: /* 128-255 */
5824         return LIMB_DIGITS - 3;
5825     case 8: /* 256-511 */
5826         return LIMB_DIGITS - 3;
5827     case 9: /* 512-1023 */
5828         if (a < 1000)
5829             return LIMB_DIGITS - 3;
5830         else
5831             return LIMB_DIGITS - 4;
5832     case 10: /* 1024-2047 */
5833         return LIMB_DIGITS - 4;
5834     case 11: /* 2048-4095 */
5835         return LIMB_DIGITS - 4;
5836     case 12: /* 4096-8191 */
5837         return LIMB_DIGITS - 4;
5838     case 13: /* 8192-16383 */
5839         if (a < 10000)
5840             return LIMB_DIGITS - 4;
5841         else
5842             return LIMB_DIGITS - 5;
5843     case 14: /* 16384-32767 */
5844         return LIMB_DIGITS - 5;
5845     case 15: /* 32768-65535 */
5846         return LIMB_DIGITS - 5;
5847     case 16: /* 65536-131071 */
5848         if (a < 100000)
5849             return LIMB_DIGITS - 5;
5850         else
5851             return LIMB_DIGITS - 6;
5852     case 17: /* 131072-262143 */
5853         return LIMB_DIGITS - 6;
5854     case 18: /* 262144-524287 */
5855         return LIMB_DIGITS - 6;
5856     case 19: /* 524288-1048575 */
5857         if (a < 1000000)
5858             return LIMB_DIGITS - 6;
5859         else
5860             return LIMB_DIGITS - 7;
5861     case 20: /* 1048576-2097151 */
5862         return LIMB_DIGITS - 7;
5863     case 21: /* 2097152-4194303 */
5864         return LIMB_DIGITS - 7;
5865     case 22: /* 4194304-8388607 */
5866         return LIMB_DIGITS - 7;
5867     case 23: /* 8388608-16777215 */
5868         if (a < 10000000)
5869             return LIMB_DIGITS - 7;
5870         else
5871             return LIMB_DIGITS - 8;
5872     case 24: /* 16777216-33554431 */
5873         return LIMB_DIGITS - 8;
5874     case 25: /* 33554432-67108863 */
5875         return LIMB_DIGITS - 8;
5876     case 26: /* 67108864-134217727 */
5877         if (a < 100000000)
5878             return LIMB_DIGITS - 8;
5879         else
5880             return LIMB_DIGITS - 9;
5881 #if LIMB_BITS == 64
5882     case 27: /* 134217728-268435455 */
5883         return LIMB_DIGITS - 9;
5884     case 28: /* 268435456-536870911 */
5885         return LIMB_DIGITS - 9;
5886     case 29: /* 536870912-1073741823 */
5887         if (a < 1000000000)
5888             return LIMB_DIGITS - 9;
5889         else
5890             return LIMB_DIGITS - 10;
5891     case 30: /* 1073741824-2147483647 */
5892         return LIMB_DIGITS - 10;
5893     case 31: /* 2147483648-4294967295 */
5894         return LIMB_DIGITS - 10;
5895     case 32: /* 4294967296-8589934591 */
5896         return LIMB_DIGITS - 10;
5897     case 33: /* 8589934592-17179869183 */
5898         if (a < 10000000000)
5899             return LIMB_DIGITS - 10;
5900         else
5901             return LIMB_DIGITS - 11;
5902     case 34: /* 17179869184-34359738367 */
5903         return LIMB_DIGITS - 11;
5904     case 35: /* 34359738368-68719476735 */
5905         return LIMB_DIGITS - 11;
5906     case 36: /* 68719476736-137438953471 */
5907         if (a < 100000000000)
5908             return LIMB_DIGITS - 11;
5909         else
5910             return LIMB_DIGITS - 12;
5911     case 37: /* 137438953472-274877906943 */
5912         return LIMB_DIGITS - 12;
5913     case 38: /* 274877906944-549755813887 */
5914         return LIMB_DIGITS - 12;
5915     case 39: /* 549755813888-1099511627775 */
5916         if (a < 1000000000000)
5917             return LIMB_DIGITS - 12;
5918         else
5919             return LIMB_DIGITS - 13;
5920     case 40: /* 1099511627776-2199023255551 */
5921         return LIMB_DIGITS - 13;
5922     case 41: /* 2199023255552-4398046511103 */
5923         return LIMB_DIGITS - 13;
5924     case 42: /* 4398046511104-8796093022207 */
5925         return LIMB_DIGITS - 13;
5926     case 43: /* 8796093022208-17592186044415 */
5927         if (a < 10000000000000)
5928             return LIMB_DIGITS - 13;
5929         else
5930             return LIMB_DIGITS - 14;
5931     case 44: /* 17592186044416-35184372088831 */
5932         return LIMB_DIGITS - 14;
5933     case 45: /* 35184372088832-70368744177663 */
5934         return LIMB_DIGITS - 14;
5935     case 46: /* 70368744177664-140737488355327 */
5936         if (a < 100000000000000)
5937             return LIMB_DIGITS - 14;
5938         else
5939             return LIMB_DIGITS - 15;
5940     case 47: /* 140737488355328-281474976710655 */
5941         return LIMB_DIGITS - 15;
5942     case 48: /* 281474976710656-562949953421311 */
5943         return LIMB_DIGITS - 15;
5944     case 49: /* 562949953421312-1125899906842623 */
5945         if (a < 1000000000000000)
5946             return LIMB_DIGITS - 15;
5947         else
5948             return LIMB_DIGITS - 16;
5949     case 50: /* 1125899906842624-2251799813685247 */
5950         return LIMB_DIGITS - 16;
5951     case 51: /* 2251799813685248-4503599627370495 */
5952         return LIMB_DIGITS - 16;
5953     case 52: /* 4503599627370496-9007199254740991 */
5954         return LIMB_DIGITS - 16;
5955     case 53: /* 9007199254740992-18014398509481983 */
5956         if (a < 10000000000000000)
5957             return LIMB_DIGITS - 16;
5958         else
5959             return LIMB_DIGITS - 17;
5960     case 54: /* 18014398509481984-36028797018963967 */
5961         return LIMB_DIGITS - 17;
5962     case 55: /* 36028797018963968-72057594037927935 */
5963         return LIMB_DIGITS - 17;
5964     case 56: /* 72057594037927936-144115188075855871 */
5965         if (a < 100000000000000000)
5966             return LIMB_DIGITS - 17;
5967         else
5968             return LIMB_DIGITS - 18;
5969     case 57: /* 144115188075855872-288230376151711743 */
5970         return LIMB_DIGITS - 18;
5971     case 58: /* 288230376151711744-576460752303423487 */
5972         return LIMB_DIGITS - 18;
5973     case 59: /* 576460752303423488-1152921504606846975 */
5974         if (a < 1000000000000000000)
5975             return LIMB_DIGITS - 18;
5976         else
5977             return LIMB_DIGITS - 19;
5978 #endif
5979     default:
5980         return 0;
5981     }
5982 }
5983 
5984 /* for debugging */
bfdec_print_str(const char * str,const bfdec_t * a)5985 void bfdec_print_str(const char *str, const bfdec_t *a)
5986 {
5987     slimb_t i;
5988     printf("%s=", str);
5989 
5990     if (a->expn == BF_EXP_NAN) {
5991         printf("NaN");
5992     } else {
5993         if (a->sign)
5994             putchar('-');
5995         if (a->expn == BF_EXP_ZERO) {
5996             putchar('0');
5997         } else if (a->expn == BF_EXP_INF) {
5998             printf("Inf");
5999         } else {
6000             printf("0.");
6001             for(i = a->len - 1; i >= 0; i--)
6002                 printf("%0*" PRIu_LIMB, LIMB_DIGITS, a->tab[i]);
6003             printf("e%" PRId_LIMB, a->expn);
6004         }
6005     }
6006     printf("\n");
6007 }
6008 
6009 /* return != 0 if one digit between 0 and bit_pos inclusive is not zero. */
scan_digit_nz(const bfdec_t * r,slimb_t bit_pos)6010 static inline limb_t scan_digit_nz(const bfdec_t *r, slimb_t bit_pos)
6011 {
6012     slimb_t pos;
6013     limb_t v, q;
6014     int shift;
6015 
6016     if (bit_pos < 0)
6017         return 0;
6018     pos = (limb_t)bit_pos / LIMB_DIGITS;
6019     shift = (limb_t)bit_pos % LIMB_DIGITS;
6020     fast_udivrem(q, v, r->tab[pos], &mp_pow_div[shift + 1]);
6021     (void)q;
6022     if (v != 0)
6023         return 1;
6024     pos--;
6025     while (pos >= 0) {
6026         if (r->tab[pos] != 0)
6027             return 1;
6028         pos--;
6029     }
6030     return 0;
6031 }
6032 
/* return the decimal digit at position 'pos' (0 = least significant);
   positions outside the 'len' limbs read as zero */
static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t limb_idx;
    int digit_idx;

    limb_idx = floor_div(pos, LIMB_DIGITS);
    if (limb_idx < 0 || limb_idx >= len)
        return 0;
    /* index of the digit inside its limb */
    digit_idx = pos - limb_idx * LIMB_DIGITS;
    return fast_udiv(tab[limb_idx], &mp_pow_div[digit_idx]) % 10;
}
6043 
#if 0
/* return LIMB_DIGITS digits starting at digit 'pos' (the value may
   span two limbs); out of range limbs read as zero. Currently unused. */
static limb_t get_digits(const limb_t *tab, limb_t len, slimb_t pos)
{
    limb_t a0, a1;
    int shift;
    slimb_t i;

    i = floor_div(pos, LIMB_DIGITS);
    shift = pos - i * LIMB_DIGITS;
    if (i >= 0 && i < len)
        a0 = tab[i];
    else
        a0 = 0;
    if (shift == 0) {
        return a0;
    } else {
        i++;
        if (i >= 0 && i < len)
            a1 = tab[i];
        else
            a1 = 0;
        /* combine the top digits of a0 with the bottom digits of a1 */
        return fast_udiv(a0, &mp_pow_div[shift]) +
            fast_urem(a1, &mp_pow_div[LIMB_DIGITS - shift]) *
            mp_pow_dec[shift];
    }
}
#endif
6071 
/* return the addend for rounding. Note that prec can be <= 0 for bf_rint().
   'l' is the number of limbs of 'r' to consider. 'digit1' is the first
   discarded digit (at position 'prec') and 'digit0' is the sticky
   flag: non zero if any digit below 'digit1' is non zero. *pret is
   or-ed with BF_ST_INEXACT if digits are discarded. */
static int bfdec_get_rnd_add(int *pret, const bfdec_t *r, limb_t l,
                             slimb_t prec, int rnd_mode)
{
    int add_one, inexact;
    limb_t digit1, digit0;

    //    bfdec_print_str("get_rnd_add", r);
    if (rnd_mode == BF_RNDF) {
        digit0 = 1; /* faithful rounding does not honor the INEXACT flag */
    } else {
        /* starting limb for bit 'prec + 1' */
        digit0 = scan_digit_nz(r, l * LIMB_DIGITS - 1 - bf_max(0, prec + 1));
    }

    /* get the digit at 'prec' */
    digit1 = get_digit(r->tab, l, l * LIMB_DIGITS - 1 - prec);
    inexact = (digit1 | digit0) != 0;

    add_one = 0;
    switch(rnd_mode) {
    case BF_RNDZ:
        /* truncate: never round up */
        break;
    case BF_RNDN:
        /* round to nearest, ties to even */
        if (digit1 == 5) {
            if (digit0) {
                add_one = 1;
            } else {
                /* round to even */
                add_one =
                    get_digit(r->tab, l, l * LIMB_DIGITS - 1 - (prec - 1)) & 1;
            }
        } else if (digit1 > 5) {
            add_one = 1;
        }
        break;
    case BF_RNDD:
    case BF_RNDU:
        /* round toward -infinity / +infinity: increase the magnitude
           only when the sign matches the rounding direction */
        if (r->sign == (rnd_mode == BF_RNDD))
            add_one = inexact;
        break;
    case BF_RNDNA:
    case BF_RNDF:
        /* round to nearest, ties away from zero (RNDF forced digit0=1 above) */
        add_one = (digit1 >= 5);
        break;
    case BF_RNDNU:
        /* round to nearest, ties toward +infinity */
        if (digit1 >= 5) {
            if (r->sign)
                add_one = (digit0 != 0);
            else
                add_one = 1;
        }
        break;
    default:
        /* invalid rounding mode */
        abort();
    }

    if (inexact)
        *pret |= BF_ST_INEXACT;
    return add_one;
}
6133 
/* round to prec1 bits assuming 'r' is non zero and finite. 'r' is
   assumed to have length 'l' (1 <= l <= r->len). prec1 can be
   BF_PREC_INF. BF_FLAG_SUBNORMAL is not supported. Cannot fail with
   BF_ST_MEM_ERROR.
 */
static int __bfdec_round(bfdec_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
{
    int shift, add_one, rnd_mode, ret;
    slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;

    /* e_min and e_max are computed to match the IEEE 754 conventions */
    /* XXX: does not matter for decimal numbers */
    e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    e_min = -e_range + 3;
    e_max = e_range;

    if (flags & BF_FLAG_RADPNT_PREC) {
        /* 'prec' is the precision after the decimal point */
        if (prec1 != BF_PREC_INF)
            prec = r->expn + prec1;
        else
            prec = prec1;
    } else {
        prec = prec1;
    }

    /* round to prec bits */
    rnd_mode = flags & BF_RND_MASK;
    ret = 0;
    add_one = bfdec_get_rnd_add(&ret, r, l, prec, rnd_mode);

    if (prec <= 0) {
        /* no significant digit survives the rounding position */
        if (add_one) {
            /* the value rounds up to a single leading '1' digit */
            bfdec_resize(r, 1); /* cannot fail because r is non zero */
            r->tab[0] = BF_DEC_BASE / 10;
            r->expn += 1 - prec;
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        } else {
            goto underflow;
        }
    } else if (add_one) {
        limb_t carry;

        /* add one starting at digit 'prec - 1' */
        bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
        pos = bit_pos / LIMB_DIGITS;
        carry = mp_pow_dec[bit_pos % LIMB_DIGITS];
        carry = mp_add_ui_dec(r->tab + pos, carry, l - pos);
        if (carry) {
            /* shift right by one digit */
            mp_shr_dec(r->tab + pos, r->tab + pos, l - pos, 1, 1);
            r->expn++;
        }
    }

    /* check underflow */
    if (unlikely(r->expn < e_min)) {
    underflow:
        bfdec_set_zero(r, r->sign);
        ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
        return ret;
    }

    /* check overflow */
    if (unlikely(r->expn > e_max)) {
        bfdec_set_inf(r, r->sign);
        ret |= BF_ST_OVERFLOW | BF_ST_INEXACT;
        return ret;
    }

    /* keep the bits starting at 'prec - 1' */
    bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
    i = floor_div(bit_pos, LIMB_DIGITS);
    if (i >= 0) {
        /* clear the discarded digits inside the boundary limb */
        shift = smod(bit_pos, LIMB_DIGITS);
        if (shift != 0) {
            r->tab[i] = fast_udiv(r->tab[i], &mp_pow_div[shift]) *
                mp_pow_dec[shift];
        }
    } else {
        i = 0;
    }
    /* remove trailing zeros */
    while (r->tab[i] == 0)
        i++;
    if (i > 0) {
        /* drop the all-zero low limbs */
        l -= i;
        memmove(r->tab, r->tab + i, l * sizeof(limb_t));
    }
    bfdec_resize(r, l); /* cannot fail */
    return ret;
}
6227 
6228 /* Cannot fail with BF_ST_MEM_ERROR. */
bfdec_round(bfdec_t * r,limb_t prec,bf_flags_t flags)6229 int bfdec_round(bfdec_t *r, limb_t prec, bf_flags_t flags)
6230 {
6231     if (r->len == 0)
6232         return 0;
6233     return __bfdec_round(r, prec, flags, r->len);
6234 }
6235 
6236 /* 'r' must be a finite number. Cannot fail with BF_ST_MEM_ERROR.  */
bfdec_normalize_and_round(bfdec_t * r,limb_t prec1,bf_flags_t flags)6237 int bfdec_normalize_and_round(bfdec_t *r, limb_t prec1, bf_flags_t flags)
6238 {
6239     limb_t l, v;
6240     int shift, ret;
6241 
6242     //    bfdec_print_str("bf_renorm", r);
6243     l = r->len;
6244     while (l > 0 && r->tab[l - 1] == 0)
6245         l--;
6246     if (l == 0) {
6247         /* zero */
6248         r->expn = BF_EXP_ZERO;
6249         bfdec_resize(r, 0); /* cannot fail */
6250         ret = 0;
6251     } else {
6252         r->expn -= (r->len - l) * LIMB_DIGITS;
6253         /* shift to have the MSB set to '1' */
6254         v = r->tab[l - 1];
6255         shift = clz_dec(v);
6256         if (shift != 0) {
6257             mp_shl_dec(r->tab, r->tab, l, shift, 0);
6258             r->expn -= shift;
6259         }
6260         ret = __bfdec_round(r, prec1, flags, l);
6261     }
6262     //    bf_print_str("r_final", r);
6263     return ret;
6264 }
6265 
/* Set 'r' to the unsigned integer 'v', stored as base BF_DEC_BASE
   limbs (least significant first), then normalize.
   Returns 0 on success or BF_ST_MEM_ERROR. */
int bfdec_set_ui(bfdec_t *r, uint64_t v)
{
#if LIMB_BITS == 32
    /* compute the threshold in 64 bit arithmetic: with 32 bit limbs
       the product BF_DEC_BASE * BF_DEC_BASE would wrap around */
    if (v >= (uint64_t)BF_DEC_BASE * BF_DEC_BASE) {
        if (bfdec_resize(r, 3))
            goto fail;
        r->tab[0] = v % BF_DEC_BASE;
        v /= BF_DEC_BASE;
        r->tab[1] = v % BF_DEC_BASE;
        r->tab[2] = v / BF_DEC_BASE;
        r->expn = 3 * LIMB_DIGITS;
    } else
#endif
    if (v >= BF_DEC_BASE) {
        if (bfdec_resize(r, 2))
            goto fail;
        r->tab[0] = v % BF_DEC_BASE;
        r->tab[1] = v / BF_DEC_BASE;
        r->expn = 2 * LIMB_DIGITS;
    } else {
        if (bfdec_resize(r, 1))
            goto fail;
        r->tab[0] = v;
        r->expn = LIMB_DIGITS;
    }
    r->sign = 0;
    /* normalization strips leading zero digits; the value is exact */
    return bfdec_normalize_and_round(r, BF_PREC_INF, 0);
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6297 
/* Set 'r' to the signed integer 'v'.
   Returns 0 on success or BF_ST_MEM_ERROR. */
int bfdec_set_si(bfdec_t *r, int64_t v)
{
    int ret;
    if (v < 0) {
        /* negate in unsigned arithmetic: '-v' would be undefined
           behavior (signed overflow) for v == INT64_MIN */
        ret = bfdec_set_ui(r, -(uint64_t)v);
        r->sign = 1;
    } else {
        ret = bfdec_set_ui(r, v);
    }
    return ret;
}
6309 
/* Common implementation of decimal addition and subtraction:
   r = a + b, or r = a - b when b_neg is 1. Special operands (zero,
   Inf, NaN) are dispatched first; otherwise 'b' is aligned on 'a' and
   added/subtracted limb by limb before renormalization. */
static int bfdec_add_internal(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec, bf_flags_t flags, int b_neg)
{
    bf_context_t *s = r->ctx;
    int is_sub, cmp_res, a_sign, b_sign, ret;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    is_sub = a_sign ^ b_sign;
    /* compare absolute values and swap so that |a| >= |b| */
    cmp_res = bfdec_cmpu(a, b);
    if (cmp_res < 0) {
        const bfdec_t *tmp;
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* zero result */
        bfdec_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        /* at least one special operand (zero, Inf or NaN) */
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bfdec_set_nan(r);
                ret = 0;
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bfdec_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bfdec_set_inf(r, a_sign);
            }
        } else {
            /* at least one zero and not subtract */
            if (bfdec_set(r, a))
                return BF_ST_MEM_ERROR;
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_offset, i, r_len;
        limb_t carry;
        limb_t *b1_tab;
        int b_shift;
        mp_size_t b1_len;

        /* exponent difference: 'b' is shifted right by 'd' digits */
        d = a->expn - b->expn;

        /* XXX: not efficient in time and memory if the precision is
           not infinite */
        r_len = bf_max(a->len, b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
        if (bfdec_resize(r, r_len))
            goto fail;
        r->sign = a_sign;
        r->expn = a->expn;

        /* copy 'a' into the top limbs of 'r', zero padding below */
        a_offset = r_len - a->len;
        for(i = 0; i < a_offset; i++)
            r->tab[i] = 0;
        for(i = 0; i < a->len; i++)
            r->tab[a_offset + i] = a->tab[i];

        /* align 'b' on a limb boundary: shift right by d % LIMB_DIGITS */
        b_shift = d % LIMB_DIGITS;
        if (b_shift == 0) {
            b1_len = b->len;
            b1_tab = (limb_t *)b->tab;
        } else {
            b1_len = b->len + 1;
            b1_tab = bf_malloc(s, sizeof(limb_t) * b1_len);
            if (!b1_tab)
                goto fail;
            /* the digits shifted out of the bottom of 'b' become the
               top digits of the new low limb */
            b1_tab[0] = mp_shr_dec(b1_tab + 1, b->tab, b->len, b_shift, 0) *
                mp_pow_dec[LIMB_DIGITS - b_shift];
        }
        b_offset = r_len - (b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);

        if (is_sub) {
            carry = mp_sub_dec(r->tab + b_offset, r->tab + b_offset,
                               b1_tab, b1_len, 0);
            if (carry != 0) {
                /* propagate the borrow into the upper limbs; it must
                   vanish because |a| >= |b| */
                carry = mp_sub_ui_dec(r->tab + b_offset + b1_len, carry,
                                      r_len - (b_offset + b1_len));
                assert(carry == 0);
            }
        } else {
            carry = mp_add_dec(r->tab + b_offset, r->tab + b_offset,
                               b1_tab, b1_len, 0);
            if (carry != 0) {
                carry = mp_add_ui_dec(r->tab + b_offset + b1_len, carry,
                                      r_len - (b_offset + b1_len));
            }
            if (carry != 0) {
                /* the sum overflowed: add a new most significant limb */
                if (bfdec_resize(r, r_len + 1)) {
                    if (b_shift != 0)
                        bf_free(s, b1_tab);
                    goto fail;
                }
                r->tab[r_len] = 1;
                r->expn += LIMB_DIGITS;
            }
        }
        if (b_shift != 0)
            bf_free(s, b1_tab);
    renorm:
        ret = bfdec_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6424 
/* addition wrapper with the bf_op2_func_t-compatible signature */
static int __bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
                     bf_flags_t flags)
{
    return bfdec_add_internal(r, a, b, prec, flags, 0);
}
6430 
/* subtraction wrapper with the bf_op2_func_t-compatible signature */
static int __bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
                     bf_flags_t flags)
{
    return bfdec_add_internal(r, a, b, prec, flags, 1);
}
6436 
/* r = a + b rounded to 'prec' digits. bf_op2() dispatches the call;
   the casts rely on bf_t and bfdec_t sharing a compatible layout
   (NOTE(review): assumed from the existing casts — verify in libbf.h) */
int bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
                  (bf_op2_func_t *)__bfdec_add);
}
6443 
/* r = a - b rounded to 'prec' digits. bf_op2() dispatches the call;
   the casts rely on bf_t and bfdec_t sharing a compatible layout
   (NOTE(review): assumed from the existing casts — verify in libbf.h) */
int bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
                  (bf_op2_func_t *)__bfdec_sub);
}
6450 
/* r = a * b rounded to 'prec' digits. Handles the special operands
   (zero, Inf, NaN) and the aliasing of 'r' with 'a' or 'b'. */
int bfdec_mul(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    int ret, r_sign;

    /* make 'b' the shorter operand */
    if (a->len < b->len) {
        const bfdec_t *tmp = a;
        a = b;
        b = tmp;
    }
    r_sign = a->sign ^ b->sign;
    /* here b->len <= a->len */
    if (b->len == 0) {
        /* at least one special operand (zero, Inf or NaN) */
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            ret = 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
            if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
                (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
                /* 0 * Inf is an invalid operation */
                bfdec_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bfdec_set_inf(r, r_sign);
                ret = 0;
            }
        } else {
            bfdec_set_zero(r, r_sign);
            ret = 0;
        }
    } else {
        bfdec_t tmp, *r1 = NULL;
        limb_t a_len, b_len;
        limb_t *a_tab, *b_tab;

        /* save the operand descriptors: 'r' may alias 'a' or 'b' */
        a_len = a->len;
        b_len = b->len;
        a_tab = a->tab;
        b_tab = b->tab;

        if (r == a || r == b) {
            /* compute into a temporary, moved into 'r' at the end */
            bfdec_init(r->ctx, &tmp);
            r1 = r;
            r = &tmp;
        }
        if (bfdec_resize(r, a_len + b_len)) {
            bfdec_set_nan(r);
            ret = BF_ST_MEM_ERROR;
            goto done;
        }
        mp_mul_basecase_dec(r->tab, a_tab, a_len, b_tab, b_len);
        r->sign = r_sign;
        r->expn = a->expn + b->expn;
        ret = bfdec_normalize_and_round(r, prec, flags);
    done:
        if (r == &tmp)
            bfdec_move(r1, &tmp);
    }
    return ret;
}
6510 
/* r = a * b1 where b1 is a signed integer, rounded to 'prec' digits */
int bfdec_mul_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
                 bf_flags_t flags)
{
    bfdec_t tmp;
    int ret;

    /* promote the integer operand to a bfdec_t, then reuse bfdec_mul */
    bfdec_init(r->ctx, &tmp);
    ret = bfdec_set_si(&tmp, b1);
    ret |= bfdec_mul(r, a, &tmp, prec, flags);
    bfdec_delete(&tmp);
    return ret;
}
6522 
/* r = a + b1 where b1 is a signed integer, rounded to 'prec' digits */
int bfdec_add_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
                 bf_flags_t flags)
{
    bfdec_t tmp;
    int ret;

    /* promote the integer operand to a bfdec_t, then reuse bfdec_add */
    bfdec_init(r->ctx, &tmp);
    ret = bfdec_set_si(&tmp, b1);
    ret |= bfdec_add(r, a, &tmp, prec, flags);
    bfdec_delete(&tmp);
    return ret;
}
6535 
/* Decimal division core: r = a / b rounded to 'prec' according to
   'flags'. Special values (NaN, infinities, zeros) are handled first;
   otherwise an integer division of the zero-extended mantissa of 'a'
   by the mantissa of 'b' is performed and the result rounded. */
static int __bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
                       limb_t prec, bf_flags_t flags)
{
    int ret, r_sign;
    limb_t n, nb, precl;

    r_sign = a->sign ^ b->sign;
    /* special values */
    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            /* Inf / Inf is an invalid operation */
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else if (a->expn == BF_EXP_INF) {
            bfdec_set_inf(r, r_sign);
            return 0;
        } else {
            /* finite / Inf = 0 */
            bfdec_set_zero(r, r_sign);
            return 0;
        }
    } else if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            /* 0 / 0 is an invalid operation */
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bfdec_set_zero(r, r_sign);
            return 0;
        }
    } else if (b->expn == BF_EXP_ZERO) {
        /* non-zero finite / 0: signal division by zero */
        bfdec_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* select the number of limbs of the quotient to compute */
    nb = b->len;
    if (prec == BF_PREC_INF) {
        /* infinite precision: return BF_ST_INVALID_OP if not an exact
           result */
        /* XXX: check */
        precl = nb + 1;
    } else if (flags & BF_FLAG_RADPNT_PREC) {
        /* number of digits after the decimal point */
        /* XXX: check (2 extra digits for rounding + 2 digits) */
        precl = (bf_max(a->expn - b->expn, 0) + 2 +
                 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
    } else {
        /* number of limbs of the quotient (2 extra digits for rounding) */
         precl = (prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
    }
    n = bf_max(a->len, precl);

    {
        limb_t *taba, na, i;
        slimb_t d;

        /* build the dividend zero-extended to na limbs so that the
           integer quotient has n + 1 limbs */
        na = n + nb;
        taba = bf_malloc(r->ctx, (na + 1) * sizeof(limb_t));
        if (!taba)
            goto fail;
        d = na - a->len;
        memset(taba, 0, d * sizeof(limb_t));
        memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
        if (bfdec_resize(r, n + 1))
            goto fail1;
        if (mp_div_dec(r->ctx, r->tab, taba, na, b->tab, nb)) {
        fail1:
            bf_free(r->ctx, taba);
            goto fail;
        }
        /* see if non zero remainder */
        for(i = 0; i < nb; i++) {
            if (taba[i] != 0)
                break;
        }
        bf_free(r->ctx, taba);
        if (i != nb) {
            if (prec == BF_PREC_INF) {
                /* inexact result while infinite precision was requested */
                bfdec_set_nan(r);
                return BF_ST_INVALID_OP;
            } else {
                /* set a sticky digit so that the rounding is correct */
                r->tab[0] |= 1;
            }
        }
        r->expn = a->expn - b->expn + LIMB_DIGITS;
        r->sign = r_sign;
        ret = bfdec_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6628 
/* Decimal division: r = a / b. Dispatches through bf_op2() to
   __bfdec_div() — presumably bf_op2() handles 'r' aliasing 'a' or 'b'
   with a temporary; TODO(review): confirm against bf_op2(). */
int bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
                  (bf_op2_func_t *)__bfdec_div);
}
6635 
6636 /* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
6637    integer defined as floor(a/b) and r = a - q * b. */
bfdec_tdivremu(bf_context_t * s,bfdec_t * q,bfdec_t * r,const bfdec_t * a,const bfdec_t * b)6638 static void bfdec_tdivremu(bf_context_t *s, bfdec_t *q, bfdec_t *r,
6639                            const bfdec_t *a, const bfdec_t *b)
6640 {
6641     if (bfdec_cmpu(a, b) < 0) {
6642         bfdec_set_ui(q, 0);
6643         bfdec_set(r, a);
6644     } else {
6645         bfdec_div(q, a, b, 0, BF_RNDZ | BF_FLAG_RADPNT_PREC);
6646         bfdec_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
6647         bfdec_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
6648     }
6649 }
6650 
/* division and remainder.

   rnd_mode is the rounding mode for the quotient. The additional
   rounding mode BF_DIVREM_EUCLIDIAN is supported.

   'q' is an integer. 'r' is rounded with prec and flags (prec can be
   BF_PREC_INF).
*/
int bfdec_divrem(bfdec_t *q, bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
                 limb_t prec, bf_flags_t flags, int rnd_mode)
{
    bf_context_t *s = q->ctx;
    bfdec_t a1_s, *a1 = &a1_s;
    bfdec_t b1_s, *b1 = &b1_s;
    bfdec_t r1_s, *r1 = &r1_s;
    int q_sign, res;
    BOOL is_ceil, is_rndn;

    assert(q != a && q != b);
    assert(r != a && r != b);
    assert(q != r);

    /* special values: NaN operands, infinite 'a' or zero 'b' */
    if (a->len == 0 || b->len == 0) {
        bfdec_set_zero(q, 0);
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            /* a == 0 or b == Inf: the remainder is 'a' itself */
            bfdec_set(r, a);
            return bfdec_round(r, prec, flags);
        }
    }

    q_sign = a->sign ^ b->sign;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA ||
               rnd_mode == BF_RNDNU);
    /* is_ceil tells whether the magnitude of the (unsigned) quotient
       computed below must be incremented by one */
    switch(rnd_mode) {
    default:
    case BF_RNDZ:
    case BF_RNDN:
    case BF_RNDNA:
        is_ceil = FALSE;
        break;
    case BF_RNDD:
        is_ceil = q_sign;
        break;
    case BF_RNDU:
        is_ceil = q_sign ^ 1;
        break;
    case BF_DIVREM_EUCLIDIAN:
        /* euclidian division: the remainder must end up non negative */
        is_ceil = a->sign;
        break;
    case BF_RNDNU:
        /* XXX: unsupported yet */
        abort();
    }

    /* unsigned shallow copies of 'a' and 'b' (the mantissa arrays are
       shared, only the sign is forced to 0) */
    a1->expn = a->expn;
    a1->tab = a->tab;
    a1->len = a->len;
    a1->sign = 0;

    b1->expn = b->expn;
    b1->tab = b->tab;
    b1->len = b->len;
    b1->sign = 0;

    //    bfdec_print_str("a1", a1);
    //    bfdec_print_str("b1", b1);
    /* XXX: could improve to avoid having a large 'q' */
    bfdec_tdivremu(s, q, r, a1, b1);
    if (bfdec_is_nan(q) || bfdec_is_nan(r))
        goto fail;
    //    bfdec_print_str("q", q);
    //    bfdec_print_str("r", r);

    if (r->len != 0) {
        if (is_rndn) {
            /* round to nearest: compare 2 * |r| with |b|; on a tie,
               round away (RNDNA) or to an even quotient (RNDN) */
            bfdec_init(s, r1);
            if (bfdec_set(r1, r))
                goto fail;
            if (bfdec_mul_si(r1, r1, 2, BF_PREC_INF, BF_RNDZ)) {
                bfdec_delete(r1);
                goto fail;
            }
            res = bfdec_cmpu(r1, b);
            bfdec_delete(r1);
            if (res > 0 ||
                (res == 0 &&
                 (rnd_mode == BF_RNDNA ||
                  (get_digit(q->tab, q->len, q->len * LIMB_DIGITS - q->expn) & 1) != 0))) {
                goto do_sub_r;
            }
        } else if (is_ceil) {
        do_sub_r:
            /* increment |q| and subtract |b| from the remainder */
            res = bfdec_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
            res |= bfdec_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
            if (res & BF_ST_MEM_ERROR)
                goto fail;
        }
    }

    /* restore the signs: the remainder follows the sign of 'a' */
    r->sign ^= a->sign;
    q->sign = q_sign;
    return bfdec_round(r, prec, flags);
 fail:
    bfdec_set_nan(q);
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6764 
/* Decimal fmod: r = a - trunc(a/b) * b (quotient rounded toward zero).
   The quotient itself is discarded. */
int bfdec_fmod(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
               bf_flags_t flags)
{
    bfdec_t quot;
    int status;

    bfdec_init(r->ctx, &quot);
    status = bfdec_divrem(&quot, r, a, b, prec, flags, BF_RNDZ);
    bfdec_delete(&quot);
    return status;
}
6776 
/* convert to integer (infinite precision) */
int bfdec_rint(bfdec_t *r, int rnd_mode)
{
    /* zero digits after the radix point == integral value */
    return bfdec_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
}
6782 
/* Decimal square root: r = sqrt(a) rounded to 'prec'. sqrt of a
   negative number (other than -0) and prec == BF_PREC_INF are invalid
   operations. 'r' must not alias 'a'. */
int bfdec_sqrt(bfdec_t *r, const bfdec_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = a->ctx;
    int ret, k;
    limb_t *a1, v;
    slimb_t n, n1;
    limb_t res;

    assert(r != a);

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
        } else if (a->expn == BF_EXP_INF && a->sign) {
            /* sqrt(-Inf) is invalid */
            goto invalid_op;
        } else {
            /* sqrt(+/-0) and sqrt(+Inf) return the operand unchanged */
            bfdec_set(r, a);
        }
        ret = 0;
    } else if (a->sign || prec == BF_PREC_INF) {
 invalid_op:
        bfdec_set_nan(r);
        ret = BF_ST_INVALID_OP;
    } else {
        /* convert the mantissa to an integer with at least 2 *
           prec + 4 digits */
        n = (2 * (prec + 2) + 2 * LIMB_DIGITS - 1) / (2 * LIMB_DIGITS);
        if (bfdec_resize(r, n))
            goto fail;
        a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
        if (!a1)
            goto fail;
        /* copy the top n1 limbs of 'a', zero-extended to 2*n limbs */
        n1 = bf_min(2 * n, a->len);
        memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
        memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
        if (a->expn & 1) {
            /* odd exponent: shift one digit so the exponent becomes even */
            res = mp_shr_dec(a1, a1, 2 * n, 1, 0);
        } else {
            res = 0;
        }
        /* normalize so that a1 >= B^(2*n)/4. Not need for n = 1
           because mp_sqrtrem2_dec already does it */
        k = 0;
        if (n > 1) {
            v = a1[2 * n - 1];
            while (v < BF_DEC_BASE / 4) {
                k++;
                v *= 4;
            }
            if (k != 0)
                mp_mul1_dec(a1, a1, 2 * n, 1 << (2 * k), 0);
        }
        if (mp_sqrtrem_dec(s, r->tab, a1, n)) {
            bf_free(s, a1);
            goto fail;
        }
        /* undo the pre-normalization: the root was scaled by 2^k */
        if (k != 0)
            mp_div1_dec(r->tab, r->tab, n, 1 << k, 0);
        /* 'res' records any inexactness: the shifted-out digit, then the
           remainder of the integer square root, then the truncated low
           limbs of 'a' */
        if (!res) {
            res = mp_scan_nz(a1, n + 1);
        }
        bf_free(s, a1);
        if (!res) {
            res = mp_scan_nz(a->tab, a->len - n1);
        }
        if (res != 0)
            /* sticky digit so that the final rounding is correct */
            r->tab[0] |= 1;
        r->sign = 0;
        r->expn = (a->expn + 1) >> 1;
        ret = bfdec_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6859 
/* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
   is an overflow and 0 otherwise. No memory error is possible. */
int bfdec_get_int32(int *pres, const bfdec_t *a)
{
    uint32_t v;
    int ret;
    if (a->expn >= BF_EXP_INF) {
        /* NaN or +/-Inf */
        ret = 0;
        if (a->expn == BF_EXP_INF) {
            /* +Inf -> INT32_MAX, -Inf -> INT32_MIN (after negation below
               of the unsigned value) */
            v = (uint32_t)INT32_MAX + a->sign;
             /* XXX: return overflow ? */
        } else {
            v = INT32_MAX;
        }
    } else if (a->expn <= 0) {
        /* abs(a) < 1: truncates to 0 */
        v = 0;
        ret = 0;
    } else if (a->expn <= 9) {
        /* at most 9 integer digits: always fits in an int32 */
        v = fast_udiv(a->tab[a->len - 1], &mp_pow_div[LIMB_DIGITS - a->expn]);
        if (a->sign)
            v = -v;
        ret = 0;
    } else if (a->expn == 10) {
        /* exactly 10 integer digits: may or may not fit */
        uint64_t v1;
        uint32_t v_max;
#if LIMB_BITS == 64
        v1 = fast_udiv(a->tab[a->len - 1], &mp_pow_div[LIMB_DIGITS - a->expn]);
#else
        /* 32 bit limbs: the 10th digit comes from the next limb */
        v1 = (uint64_t)a->tab[a->len - 1] * 10 +
            get_digit(a->tab, a->len, (a->len - 1) * LIMB_DIGITS - 1);
#endif
        /* abs(INT32_MIN) = INT32_MAX + 1, hence the '+ a->sign' */
        v_max = (uint32_t)INT32_MAX + a->sign;
        if (v1 > v_max) {
            v = v_max;
            ret = BF_ST_OVERFLOW;
        } else {
            v = v1;
            if (a->sign)
                v = -v;
            ret = 0;
        }
    } else {
        /* more than 10 integer digits: always overflows */
        v = (uint32_t)INT32_MAX + a->sign;
        ret = BF_ST_OVERFLOW;
    }
    *pres = v;
    return ret;
}
6908 
6909 /* power to an integer with infinite precision */
bfdec_pow_ui(bfdec_t * r,const bfdec_t * a,limb_t b)6910 int bfdec_pow_ui(bfdec_t *r, const bfdec_t *a, limb_t b)
6911 {
6912     int ret, n_bits, i;
6913 
6914     assert(r != a);
6915     if (b == 0)
6916         return bfdec_set_ui(r, 1);
6917     ret = bfdec_set(r, a);
6918     n_bits = LIMB_BITS - clz(b);
6919     for(i = n_bits - 2; i >= 0; i--) {
6920         ret |= bfdec_mul(r, r, r, BF_PREC_INF, BF_RNDZ);
6921         if ((b >> i) & 1)
6922             ret |= bfdec_mul(r, r, a, BF_PREC_INF, BF_RNDZ);
6923     }
6924     return ret;
6925 }
6926 
/* Convert the decimal float 'a' to a string (radix 10); wraps the
   shared conversion routine with its decimal variant selected. */
char *bfdec_ftoa(size_t *plen, const bfdec_t *a, limb_t prec, bf_flags_t flags)
{
    return bf_ftoa_internal(plen, (const bf_t *)a, 10, prec, flags, TRUE);
}
6931 
/* Parse a decimal float from 'str' into 'r' (radix 10); '*pnext', when
   non NULL, receives a pointer past the parsed characters (see
   bf_atof_internal()). */
int bfdec_atof(bfdec_t *r, const char *str, const char **pnext,
               limb_t prec, bf_flags_t flags)
{
    slimb_t ignored_exp;

    /* the separate exponent output of the internal parser is not needed */
    return bf_atof_internal((bf_t *)r, &ignored_exp, str, pnext, 10, prec,
                            flags, TRUE);
}
6939 
6940 #endif /* USE_BF_DEC */
6941 
6942 #ifdef USE_FFT_MUL
6943 /***************************************************************/
6944 /* Integer multiplication with FFT */
6945 
6946 /* or LIMB_BITS at bit position 'pos' in tab */
put_bits(limb_t * tab,limb_t len,slimb_t pos,limb_t val)6947 static inline void put_bits(limb_t *tab, limb_t len, slimb_t pos, limb_t val)
6948 {
6949     limb_t i;
6950     int p;
6951 
6952     i = pos >> LIMB_LOG2_BITS;
6953     p = pos & (LIMB_BITS - 1);
6954     if (i < len)
6955         tab[i] |= val << p;
6956     if (p != 0) {
6957         i++;
6958         if (i < len) {
6959             tab[i] |= val >> (LIMB_BITS - p);
6960         }
6961     }
6962 }
6963 
/* NTT parameters: NB_MODS moduli, each congruent to 1 modulo
   2^NTT_PROOT_2EXP so that roots of unity of order up to
   2^NTT_PROOT_2EXP exist. With AVX2 the NTT limbs are doubles, hence
   the ~51 bit moduli; otherwise full-width integer limbs are used. */
#if defined(__AVX2__)

typedef double NTTLimb;

/* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
#define NTT_MOD_LOG2_MIN 50
#define NTT_MOD_LOG2_MAX 51
#define NB_MODS 5
#define NTT_PROOT_2EXP 39
/* usable integer payload in bits when combining the first i+1 moduli —
   TODO(review): confirm against the NTT driver code */
static const int ntt_int_bits[NB_MODS] = { 254, 203, 152, 101, 50, };

static const limb_t ntt_mods[NB_MODS] = { 0x00073a8000000001, 0x0007858000000001, 0x0007a38000000001, 0x0007a68000000001, 0x0007fd8000000001,
};

/* primitive roots modulo ntt_mods[i] ([0]) and presumably their
   inverses ([1]) — TODO(review): confirm index meaning */
static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x00056198d44332c8, 0x0002eb5d640aad39, 0x00047e31eaa35fd0, 0x0005271ac118a150, 0x00075e0ce8442bd5, },
    { 0x000461169761bcc5, 0x0002dac3cb2da688, 0x0004abc97751e3bf, 0x000656778fc8c485, 0x0000dc6469c269fa, },
};

/* pairwise constants for the Chinese remainder reconstruction */
static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x00020e4da740da8e, 0x0004c3dc09c09c1d, 0x000063bd097b4271, 0x000799d8f18f18fd,
 0x0005384222222264, 0x000572b07c1f07fe, 0x00035cd08888889a,
 0x00066015555557e3, 0x000725960b60b623,
 0x0002fc1fa1d6ce12,
};

#else

typedef limb_t NTTLimb;

#if LIMB_BITS == 64

#define NTT_MOD_LOG2_MIN 61
#define NTT_MOD_LOG2_MAX 62
#define NB_MODS 5
#define NTT_PROOT_2EXP 51
static const int ntt_int_bits[NB_MODS] = { 307, 246, 185, 123, 61, };

static const limb_t ntt_mods[NB_MODS] = { 0x28d8000000000001, 0x2a88000000000001, 0x2ed8000000000001, 0x3508000000000001, 0x3aa8000000000001,
};

static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x1b8ea61034a2bea7, 0x21a9762de58206fb, 0x02ca782f0756a8ea, 0x278384537a3e50a1, 0x106e13fee74ce0ab, },
    { 0x233513af133e13b8, 0x1d13140d1c6f75f1, 0x12cde57f97e3eeda, 0x0d6149e23cbe654f, 0x36cd204f522a1379, },
};

static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x08a9ed097b425eea, 0x18a44aaaaaaaaab3, 0x2493f57f57f57f5d, 0x126b8d0649a7f8d4,
 0x09d80ed7303b5ccc, 0x25b8bcf3cf3cf3d5, 0x2ce6ce63398ce638,
 0x0e31fad40a57eb59, 0x02a3529fd4a7f52f,
 0x3a5493e93e93e94a,
};

#elif LIMB_BITS == 32

/* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
#define NTT_MOD_LOG2_MIN 29
#define NTT_MOD_LOG2_MAX 30
#define NB_MODS 5
#define NTT_PROOT_2EXP 20
static const int ntt_int_bits[NB_MODS] = { 148, 119, 89, 59, 29, };

static const limb_t ntt_mods[NB_MODS] = { 0x0000000032b00001, 0x0000000033700001, 0x0000000036d00001, 0x0000000037300001, 0x000000003e500001,
};

static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x0000000032525f31, 0x0000000005eb3b37, 0x00000000246eda9f, 0x0000000035f25901, 0x00000000022f5768, },
    { 0x00000000051eba1a, 0x00000000107be10e, 0x000000001cd574e0, 0x00000000053806e6, 0x000000002cd6bf98, },
};

static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x000000000449559a, 0x000000001eba6ca9, 0x000000002ec18e46, 0x000000000860160b,
 0x000000000d321307, 0x000000000bf51120, 0x000000000f662938,
 0x000000000932ab3e, 0x000000002f40eef8,
 0x000000002e760905,
};

#endif /* LIMB_BITS */

#endif /* !AVX2 */

/* maximum log2 FFT size for which twiddle tables are kept */
#if defined(__AVX2__)
#define NTT_TRIG_K_MAX 18
#else
#define NTT_TRIG_K_MAX 19
#endif
7050 
/* Per-context state of the NTT multiplier: precomputed root powers,
   twiddle factor tables and constants for the modular reductions. */
typedef struct BFNTTState {
    bf_context_t *ctx;

    /* used for mul_mod_fast() */
    limb_t ntt_mods_div[NB_MODS];

    /* powers of the primitive roots (and their inverses), per modulus */
    limb_t ntt_proot_pow[NB_MODS][2][NTT_PROOT_2EXP + 1];
    limb_t ntt_proot_pow_inv[NB_MODS][2][NTT_PROOT_2EXP + 1];
    /* twiddle factor tables filled on demand by get_trig(), indexed by
       [modulus][direct/inverse][log2 FFT size] */
    NTTLimb *ntt_trig[NB_MODS][2][NTT_TRIG_K_MAX + 1];
    /* 1/2^n mod m */
    limb_t ntt_len_inv[NB_MODS][NTT_PROOT_2EXP + 1][2];
#if defined(__AVX2__)
    __m256d ntt_mods_cr_vec[NB_MODS * (NB_MODS - 1) / 2];
    __m256d ntt_mods_vec[NB_MODS];
    __m256d ntt_mods_inv_vec[NB_MODS];
#else
    /* scalar counterparts of the CRT constants above */
    limb_t ntt_mods_cr_inv[NB_MODS * (NB_MODS - 1) / 2];
#endif
} BFNTTState;
7070 
7071 static NTTLimb *get_trig(BFNTTState *s, int k, int inverse, int m_idx);
7072 
7073 /* add modulo with up to (LIMB_BITS-1) bit modulo */
add_mod(limb_t a,limb_t b,limb_t m)7074 static inline limb_t add_mod(limb_t a, limb_t b, limb_t m)
7075 {
7076     limb_t r;
7077     r = a + b;
7078     if (r >= m)
7079         r -= m;
7080     return r;
7081 }
7082 
7083 /* sub modulo with up to LIMB_BITS bit modulo */
sub_mod(limb_t a,limb_t b,limb_t m)7084 static inline limb_t sub_mod(limb_t a, limb_t b, limb_t m)
7085 {
7086     limb_t r;
7087     r = a - b;
7088     if (r > a)
7089         r += m;
7090     return r;
7091 }
7092 
/* return (r0+r1*B) mod m
   precondition: 0 <= r0+r1*B < 2^(64+NTT_MOD_LOG2_MIN)
*/
static inline limb_t mod_fast(dlimb_t r,
                                limb_t m, limb_t m_inv)
{
    limb_t a1, q, t0, r1, r0;

    /* m_inv = floor(2^(LIMB_BITS+NTT_MOD_LOG2_MIN) / m), see
       init_mul_mod_fast(): q approximates the quotient from below */
    a1 = r >> NTT_MOD_LOG2_MIN;

    q = ((dlimb_t)a1 * m_inv) >> LIMB_BITS;
    /* subtract 2*m extra so the branchless corrections below only ever
       need to add m back */
    r = r - (dlimb_t)q * m - m * 2;
    /* conditionally add m (mask from the sign of the high limb) while
       the intermediate result is negative */
    r1 = r >> LIMB_BITS;
    t0 = (slimb_t)r1 >> 1;
    r += m & t0;
    r0 = r;
    r1 = r >> LIMB_BITS;
    r0 += m & r1;
    return r0;
}
7113 
7114 /* faster version using precomputed modulo inverse.
7115    precondition: 0 <= a * b < 2^(64+NTT_MOD_LOG2_MIN) */
mul_mod_fast(limb_t a,limb_t b,limb_t m,limb_t m_inv)7116 static inline limb_t mul_mod_fast(limb_t a, limb_t b,
7117                                     limb_t m, limb_t m_inv)
7118 {
7119     dlimb_t r;
7120     r = (dlimb_t)a * (dlimb_t)b;
7121     return mod_fast(r, m, m_inv);
7122 }
7123 
/* Precompute floor(2^(LIMB_BITS + NTT_MOD_LOG2_MIN) / m) for use by
   mod_fast() / mul_mod_fast(). */
static inline limb_t init_mul_mod_fast(limb_t m)
{
    assert(m < (limb_t)1 << NTT_MOD_LOG2_MAX);
    assert(m >= (limb_t)1 << NTT_MOD_LOG2_MIN);
    return (limb_t)((((dlimb_t)1 << (LIMB_BITS + NTT_MOD_LOG2_MIN)) / m));
}
7132 
7133 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7134    0 <= b < m. */
mul_mod_fast2(limb_t a,limb_t b,limb_t m,limb_t b_inv)7135 static inline limb_t mul_mod_fast2(limb_t a, limb_t b,
7136                                      limb_t m, limb_t b_inv)
7137 {
7138     limb_t r, q;
7139 
7140     q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7141     r = a * b - q * m;
7142     if (r >= m)
7143         r -= m;
7144     return r;
7145 }
7146 
7147 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7148    0 <= b < m. Let r = a * b mod m. The return value is 'r' or 'r +
7149    m'. */
mul_mod_fast3(limb_t a,limb_t b,limb_t m,limb_t b_inv)7150 static inline limb_t mul_mod_fast3(limb_t a, limb_t b,
7151                                      limb_t m, limb_t b_inv)
7152 {
7153     limb_t r, q;
7154 
7155     q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7156     r = a * b - q * m;
7157     return r;
7158 }
7159 
/* Precompute floor(b * 2^LIMB_BITS / m) for mul_mod_fast2()/3(). */
static inline limb_t init_mul_mod_fast2(limb_t b, limb_t m)
{
    dlimb_t t = (dlimb_t)b << LIMB_BITS;
    return t / m;
}
7164 
7165 #ifdef __AVX2__
7166 
/* Convert a double NTT limb (which may lie outside [0, m)) back to a
   canonical integer residue in [0, m). */
static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
{
    slimb_t t = a;

    if (t < 0)
        t += m;
    if (t >= m)
        t -= m;
    return t;
}
7177 
static inline NTTLimb int_to_ntt_limb(limb_t a, limb_t m)
{
    /* 'm' is unused here: values below 2^NTT_MOD_LOG2_MAX are exactly
       representable as doubles */
    return (slimb_t)a;
}
7182 
int_to_ntt_limb2(limb_t a,limb_t m)7183 static inline NTTLimb int_to_ntt_limb2(limb_t a, limb_t m)
7184 {
7185     if (a >= (m / 2))
7186         a -= m;
7187     return (slimb_t)a;
7188 }
7189 
/* return r + m if r < 0 otherwise r. */
static inline __m256d ntt_mod1(__m256d r, __m256d m)
{
    /* blendv selects (r + m) in the lanes where the sign bit of r is set */
    return _mm256_blendv_pd(r, r + m, r);
}
7195 
/* input: abs(r) < 2 * m. Output: abs(r) < m */
static inline __m256d ntt_mod(__m256d r, __m256d mf, __m256d m2f)
{
    /* add 2*m to the negative lanes, then subtract m from all lanes */
    return _mm256_blendv_pd(r, r + m2f, r) - mf;
}
7201 
/* input: abs(a*b) < 2 * m^2, output: abs(r) < m */
static inline __m256d ntt_mul_mod(__m256d a, __m256d b, __m256d mf,
                                  __m256d m_inv)
{
    /* double-double style reduction: the FMA fmsub recovers the exact
       low part of each product, so (ab1 - qm1) + (ab0 - qm0) equals
       a*b - round(a*b/m)*m computed without rounding error */
    __m256d r, q, ab1, ab0, qm0, qm1;
    ab1 = a * b;
    q = _mm256_round_pd(ab1 * m_inv, 0); /* round to nearest */
    qm1 = q * mf;
    qm0 = _mm256_fmsub_pd(q, mf, qm1); /* low part */
    ab0 = _mm256_fmsub_pd(a, b, ab1); /* low part */
    r = (ab1 - qm1) + (ab0 - qm0);
    return r;
}
7215 
/* Allocate 'size' bytes aligned to 'align' (a power of two); the base
   pointer is stashed just before the returned area so that
   bf_aligned_free() can recover it. Returns NULL on failure. */
static void *bf_aligned_malloc(bf_context_t *s, size_t size, size_t align)
{
    void *base;
    void **aligned;

    /* over-allocate: room for the saved base pointer plus the worst
       case alignment slack */
    base = bf_malloc(s, size + sizeof(void *) + align - 1);
    if (!base)
        return NULL;
    aligned = (void **)(((uintptr_t)base + sizeof(void *) + align - 1) &
                        ~(align - 1));
    aligned[-1] = base;
    return aligned;
}
7228 
/* Free a block obtained from bf_aligned_malloc(); NULL is a no-op. */
static void bf_aligned_free(bf_context_t *s, void *ptr)
{
    if (ptr)
        bf_free(s, ((void **)ptr)[-1]);
}
7235 
static void *ntt_malloc(BFNTTState *s, size_t size)
{
    /* NTT buffers are accessed with aligned AVX2 loads/stores: allocate
       them on a 64 byte boundary */
    return bf_aligned_malloc(s->ctx, size, 64);
}
7240 
static void ntt_free(BFNTTState *s, void *ptr)
{
    /* must match ntt_malloc() which uses the aligned allocator */
    bf_aligned_free(s->ctx, ptr);
}
7245 
/* AVX2 NTT of 'in_buf' (2^fft_len_log2 elements, at least 8) modulo
   ntt_mods[m_idx]; 'in_buf' and 'tmp_buf' are used alternately as
   scratch and the result goes to 'out_buf'. Returns 0, or -1 if a
   twiddle factor table cannot be obtained. */
static no_inline int ntt_fft(BFNTTState *s,
                             NTTLimb *out_buf, NTTLimb *in_buf,
                             NTTLimb *tmp_buf, int fft_len_log2,
                             int inverse, int m_idx)
{
    limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j;
    NTTLimb *tab_in, *tab_out, *tmp, *trig;
    __m256d m_inv, mf, m2f, c, a0, a1, b0, b1;
    limb_t m;
    int l;

    m = ntt_mods[m_idx];

    m_inv = _mm256_set1_pd(1.0 / (double)m);
    mf = _mm256_set1_pd(m);
    m2f = _mm256_set1_pd(m * 2);

    n = (limb_t)1 << fft_len_log2;
    assert(n >= 8);
    stride_in = n / 2;

    tab_in = in_buf;
    tab_out = tmp_buf;
    trig = get_trig(s, fft_len_log2, inverse, m_idx);
    if (!trig)
        return -1;
    p = 0;
    /* first pass: one twiddle factor per lane; the permutes interleave
       the butterfly outputs into consecutive positions */
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        c = _mm256_load_pd(trig);
        trig += 4;
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
        a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
        a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
        a0 = _mm256_permute4x64_pd(a0, 0xd8);
        a1 = _mm256_permute4x64_pd(a1, 0xd8);
        _mm256_store_pd(&tab_out[p], a0);
        _mm256_store_pd(&tab_out[p + 4], a1);
        p += 2 * 4;
    }
    /* ping-pong the working buffers between passes */
    tmp = tab_in;
    tab_in = tab_out;
    tab_out = tmp;

    /* second pass: one twiddle factor per pair of lanes */
    trig = get_trig(s, fft_len_log2 - 1, inverse, m_idx);
    if (!trig)
        return -1;
    p = 0;
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        c = _mm256_setr_pd(trig[0], trig[0], trig[1], trig[1]);
        trig += 2;
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
        a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
        a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
        _mm256_store_pd(&tab_out[p], a0);
        _mm256_store_pd(&tab_out[p + 4], a1);
        p += 2 * 4;
    }
    tmp = tab_in;
    tab_in = tab_out;
    tab_out = tmp;

    nb_blocks = n / 4;
    fft_per_block = 4;

    /* middle passes: one twiddle factor per block of at least 4 lanes */
    l = fft_len_log2 - 2;
    while (nb_blocks != 2) {
        nb_blocks >>= 1;
        p = 0;
        k = 0;
        trig = get_trig(s, l, inverse, m_idx);
        if (!trig)
            return -1;
        for(i = 0; i < nb_blocks; i++) {
            c = _mm256_set1_pd(trig[0]);
            trig++;
            for(j = 0; j < fft_per_block; j += 4) {
                a0 = _mm256_load_pd(&tab_in[k + j]);
                a1 = _mm256_load_pd(&tab_in[k + j + stride_in]);
                b0 = ntt_mod(a0 + a1, mf, m2f);
                b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
                _mm256_store_pd(&tab_out[p + j], b0);
                _mm256_store_pd(&tab_out[p + j + fft_per_block], b1);
            }
            k += fft_per_block;
            p += 2 * fft_per_block;
        }
        fft_per_block <<= 1;
        l--;
        tmp = tab_in;
        tab_in = tab_out;
        tab_out = tmp;
    }

    /* last pass: no twiddle factor needed */
    tab_out = out_buf;
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mod(a0 - a1, mf, m2f);
        _mm256_store_pd(&tab_out[k], b0);
        _mm256_store_pd(&tab_out[k + stride_in], b1);
    }
    return 0;
}
7356 
/* Pointwise product of the transformed vectors: tab1[i] = tab1[i] *
   tab2[i] / 2^k_tot (mod m), in place (AVX2 version). */
static void ntt_vec_mul(BFNTTState *s,
                        NTTLimb *tab1, NTTLimb *tab2, limb_t fft_len_log2,
                        int k_tot, int m_idx)
{
    limb_t i, c_inv, n, m;
    __m256d m_inv, mf, a, b, c;

    m = ntt_mods[m_idx];
    c_inv = s->ntt_len_inv[m_idx][k_tot][0]; /* 1/2^k_tot mod m */
    m_inv = _mm256_set1_pd(1.0 / (double)m);
    mf = _mm256_set1_pd(m);
    c = _mm256_set1_pd(int_to_ntt_limb(c_inv, m));
    n = (limb_t)1 << fft_len_log2;
    for(i = 0; i < n; i += 4) {
        a = _mm256_load_pd(&tab1[i]);
        b = _mm256_load_pd(&tab2[i]);
        a = ntt_mul_mod(a, b, mf, m_inv);
        /* fold in the length normalization factor */
        a = ntt_mul_mod(a, c, mf, m_inv);
        _mm256_store_pd(&tab1[i], a);
    }
}
7378 
/* Multiply buf[i] by c1^i modulo m, four elements per iteration
   (AVX2 version). 'm_inv1' is the scalar inverse for mul_mod_fast(). */
static no_inline void mul_trig(NTTLimb *buf,
                               limb_t n, limb_t c1, limb_t m, limb_t m_inv1)
{
    limb_t i, c2, c3, c4;
    __m256d c, c_mul, a0, mf, m_inv;
    assert(n >= 2);

    mf = _mm256_set1_pd(m);
    m_inv = _mm256_set1_pd(1.0 / (double)m);

    /* c = (1, c1, c1^2, c1^3); c_mul = c1^4 advances c by 4 powers */
    c2 = mul_mod_fast(c1, c1, m, m_inv1);
    c3 = mul_mod_fast(c2, c1, m, m_inv1);
    c4 = mul_mod_fast(c2, c2, m, m_inv1);
    c = _mm256_setr_pd(1, int_to_ntt_limb(c1, m),
                       int_to_ntt_limb(c2, m), int_to_ntt_limb(c3, m));
    c_mul = _mm256_set1_pd(int_to_ntt_limb(c4, m));
    for(i = 0; i < n; i += 4) {
        a0 = _mm256_load_pd(&buf[i]);
        a0 = ntt_mul_mod(a0, c, mf, m_inv);
        _mm256_store_pd(&buf[i], a0);
        c = ntt_mul_mod(c, c_mul, mf, m_inv);
    }
}
7402 
7403 #else
7404 
static void *ntt_malloc(BFNTTState *s, size_t size)
{
    /* scalar code has no alignment requirement: plain allocation */
    return bf_malloc(s->ctx, size);
}
7409 
static void ntt_free(BFNTTState *s, void *ptr)
{
    /* counterpart of the scalar ntt_malloc() */
    bf_free(s->ctx, ptr);
}
7414 
/* Reduce an NTT limb from [0, 2*m) to the canonical range [0, m). */
static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
{
    return (a >= m) ? a - m : a;
}
7421 
static inline NTTLimb int_to_ntt_limb(slimb_t a, limb_t m)
{
    /* identity: the scalar code stores NTT limbs directly as integers */
    return a;
}
7426 
/* Scalar NTT of 'in_buf' (2^fft_len_log2 elements) modulo
   ntt_mods[m_idx]; 'in_buf' and 'tmp_buf' alternate as scratch and the
   result goes to 'out_buf'. Values are kept in the redundant range
   [0, 2*m). Returns 0, or -1 if a twiddle table cannot be obtained. */
static no_inline int ntt_fft(BFNTTState *s, NTTLimb *out_buf, NTTLimb *in_buf,
                             NTTLimb *tmp_buf, int fft_len_log2,
                             int inverse, int m_idx)
{
    limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j, m, m2;
    NTTLimb *tab_in, *tab_out, *tmp, a0, a1, b0, b1, c, *trig, c_inv;
    int l;

    m = ntt_mods[m_idx];
    m2 = 2 * m;
    n = (limb_t)1 << fft_len_log2;
    nb_blocks = n;
    fft_per_block = 1;
    stride_in = n / 2;
    tab_in = in_buf;
    tab_out = tmp_buf;
    l = fft_len_log2;
    /* each pass halves the number of blocks; the two buffers ping-pong */
    while (nb_blocks != 2) {
        nb_blocks >>= 1;
        p = 0;
        k = 0;
        trig = get_trig(s, l, inverse, m_idx);
        if (!trig)
            return -1;
        for(i = 0; i < nb_blocks; i++) {
            /* the table stores each twiddle factor together with its
               precomputed constant for mul_mod_fast3() */
            c = trig[0];
            c_inv = trig[1];
            trig += 2;
            for(j = 0; j < fft_per_block; j++) {
                a0 = tab_in[k + j];
                a1 = tab_in[k + j + stride_in];
                /* butterfly: b0 = a0 + a1, b1 = (a0 - a1) * c (mod m) */
                b0 = add_mod(a0, a1, m2);
                b1 = a0 - a1 + m2;
                b1 = mul_mod_fast3(b1, c, m, c_inv);
                tab_out[p + j] = b0;
                tab_out[p + j + fft_per_block] = b1;
            }
            k += fft_per_block;
            p += 2 * fft_per_block;
        }
        fft_per_block <<= 1;
        l--;
        tmp = tab_in;
        tab_in = tab_out;
        tab_out = tmp;
    }
    /* no twiddle in last step */
    tab_out = out_buf;
    for(k = 0; k < stride_in; k++) {
        a0 = tab_in[k];
        a1 = tab_in[k + stride_in];
        b0 = add_mod(a0, a1, m2);
        b1 = sub_mod(a0, a1, m2);
        tab_out[k] = b0;
        tab_out[k + stride_in] = b1;
    }
    return 0;
}
7485 
/* Pointwise product of the two NTT vectors tab1 and tab2 (length
   2^fft_len_log2) modulo ntt_mods[m_idx]; the result goes to tab1. The
   1/2^k_tot normalization factor of the inverse transform is folded
   into the product. */
static void ntt_vec_mul(BFNTTState *s,
                        NTTLimb *tab1, NTTLimb *tab2, int fft_len_log2,
                        int k_tot, int m_idx)
{
    limb_t idx, count, mod, mod_inv, scale, scale_inv, v;

    mod = ntt_mods[m_idx];
    mod_inv = s->ntt_mods_div[m_idx];
    scale = s->ntt_len_inv[m_idx][k_tot][0];
    scale_inv = s->ntt_len_inv[m_idx][k_tot][1];
    count = (limb_t)1 << fft_len_log2;
    for (idx = 0; idx < count; idx++) {
        v = tab1[idx];
        /* reduce to [0, mod) so that the product stays below
           2^(LIMB_BITS + NTT_MOD_LOG2_MIN) */
        if (v >= mod)
            v -= mod;
        v = mul_mod_fast(v, tab2[idx], mod, mod_inv);
        v = mul_mod_fast3(v, scale, mod, scale_inv);
        tab1[idx] = v;
    }
}
7508 
mul_trig(NTTLimb * buf,limb_t n,limb_t c_mul,limb_t m,limb_t m_inv)7509 static no_inline void mul_trig(NTTLimb *buf,
7510                                limb_t n, limb_t c_mul, limb_t m, limb_t m_inv)
7511 {
7512     limb_t i, c0, c_mul_inv;
7513 
7514     c0 = 1;
7515     c_mul_inv = init_mul_mod_fast2(c_mul, m);
7516     for(i = 0; i < n; i++) {
7517         buf[i] = mul_mod_fast(buf[i], c0, m, m_inv);
7518         c0 = mul_mod_fast2(c0, c_mul, m, c_mul_inv);
7519     }
7520 }
7521 
7522 #endif /* !AVX2 */
7523 
/* Return the twiddle factor table for an FFT pass of length 2^k
   (direction 'inverse', modulus ntt_mods[m_idx]), computing and
   caching it on first use. Return NULL if k exceeds NTT_TRIG_K_MAX or
   in case of memory allocation error. */
static no_inline NTTLimb *get_trig(BFNTTState *s,
                                   int k, int inverse, int m_idx)
{
    NTTLimb *tab;
    limb_t i, n2, c, c_mul, m, c_mul_inv;

    if (k > NTT_TRIG_K_MAX)
        return NULL;

    tab = s->ntt_trig[m_idx][inverse][k];
    if (tab)
        return tab; /* already cached */
    n2 = (limb_t)1 << (k - 1); /* only half of the roots are needed */
    m = ntt_mods[m_idx];
#ifdef __AVX2__
    tab = ntt_malloc(s, sizeof(NTTLimb) * n2);
#else
    /* scalar layout: each entry stores the root followed by its
       precomputed constant for mul_mod_fast3() */
    tab = ntt_malloc(s, sizeof(NTTLimb) * n2 * 2);
#endif
    if (!tab)
        return NULL;
    c = 1;
    /* generator of the table: a root of unity for this pass and its
       precomputed constant for mul_mod_fast2() */
    c_mul = s->ntt_proot_pow[m_idx][inverse][k];
    c_mul_inv = s->ntt_proot_pow_inv[m_idx][inverse][k];
    for(i = 0; i < n2; i++) {
#ifdef __AVX2__
        tab[i] = int_to_ntt_limb2(c, m);
#else
        tab[2 * i] = int_to_ntt_limb(c, m);
        tab[2 * i + 1] = init_mul_mod_fast2(c, m);
#endif
        c = mul_mod_fast2(c, c_mul, m, c_mul_inv);
    }
    s->ntt_trig[m_idx][inverse][k] = tab;
    return tab;
}
7560 
/* Free the cached NTT twiddle factor tables and the NTT state itself,
   if allocated. */
void fft_clear_cache(bf_context_t *s1)
{
    BFNTTState *s = s1->ntt_state;
    int m_idx, inverse, k;
    NTTLimb **ptab;

    if (!s)
        return;
    for (m_idx = 0; m_idx < NB_MODS; m_idx++) {
        for (inverse = 0; inverse < 2; inverse++) {
            for (k = 0; k <= NTT_TRIG_K_MAX; k++) {
                ptab = &s->ntt_trig[m_idx][inverse][k];
                if (*ptab) {
                    ntt_free(s, *ptab);
                    *ptab = NULL;
                }
            }
        }
    }
    /* the state itself must be released with the allocator that
       created it (aligned allocation in the AVX2 build) */
#if defined(__AVX2__)
    bf_aligned_free(s1, s);
#else
    bf_free(s1, s);
#endif
    s1->ntt_state = NULL;
}
7584 
7585 #define STRIP_LEN 16
7586 
/* Partial pass of the split FFT: view buf1 (length n1 * n2, with
   n1 = 2^k1, n2 = 2^k2) as an n1 x n2 matrix and run an FFT of length
   n1 on each of its n2 columns, applying the inter-stage twiddle
   factors (before the FFT for the inverse direction, after it
   otherwise). If k2 == 0 this degenerates to a single full-length
   FFT. Columns are processed STRIP_LEN at a time, transposed into the
   contiguous buffer buf2. In-place in buf1. Return 0 if OK, -1 in
   case of memory allocation error. */
static int ntt_fft_partial(BFNTTState *s, NTTLimb *buf1,
                           int k1, int k2, limb_t n1, limb_t n2, int inverse,
                           limb_t m_idx)
{
    limb_t i, j, c_mul, c0, m, m_inv, strip_len, l;
    NTTLimb *buf2, *buf3;

    buf2 = NULL;
    buf3 = ntt_malloc(s, sizeof(NTTLimb) * n1); /* FFT scratch buffer */
    if (!buf3)
        goto fail;
    if (k2 == 0) {
        /* plain full-length FFT: no twiddle multiplication needed */
        if (ntt_fft(s, buf1, buf1, buf3, k1, inverse, m_idx))
            goto fail;
    } else {
        strip_len = STRIP_LEN;
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * n1 * strip_len);
        if (!buf2)
            goto fail;
        m = ntt_mods[m_idx];
        m_inv = s->ntt_mods_div[m_idx];
        /* root generating the inter-stage twiddle factors; c_mul holds
           its successive powers, one per column */
        c0 = s->ntt_proot_pow[m_idx][inverse][k1 + k2];
        c_mul = 1;
        assert((n2 % strip_len) == 0);
        for(j = 0; j < n2; j += strip_len) {
            /* gather a strip of strip_len columns into buf2 (transpose) */
            for(i = 0; i < n1; i++) {
                for(l = 0; l < strip_len; l++) {
                    buf2[i + l * n1] = buf1[i * n2 + (j + l)];
                }
            }
            for(l = 0; l < strip_len; l++) {
                if (inverse)
                    mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
                if (ntt_fft(s, buf2 + l * n1, buf2 + l * n1, buf3, k1, inverse, m_idx))
                    goto fail;
                if (!inverse)
                    mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
                c_mul = mul_mod_fast(c_mul, c0, m, m_inv);
            }

            /* scatter the transformed strip back into buf1 */
            for(i = 0; i < n1; i++) {
                for(l = 0; l < strip_len; l++) {
                    buf1[i * n2 + (j + l)] = buf2[i + l *n1];
                }
            }
        }
        ntt_free(s, buf2);
    }
    ntt_free(s, buf3);
    return 0;
 fail:
    /* bf_free-based ntt_free() is assumed NULL-safe here (buf2 may
       still be NULL on this path) — same convention as the rest of
       the file */
    ntt_free(s, buf2);
    ntt_free(s, buf3);
    return -1;
}
7643 
7644 
7645 /* dst = buf1, src = buf2, tmp = buf3 */
/* Modular convolution: multiply the two sequences buf1 and buf2 of
   length 2^k modulo ntt_mods[m_idx]; the result is stored in buf1.
   k_tot is the log2 of the outermost FFT length and selects the 1/N
   normalization factor. The FFT is split recursively when k exceeds
   the size of the cached twiddle tables. Return 0 if OK, -1 in case
   of memory allocation error. */
static int ntt_conv(BFNTTState *s, NTTLimb *buf1, NTTLimb *buf2,
                    int k, int k_tot, limb_t m_idx)
{
    limb_t n1, n2, i;
    int k1, k2;

    if (k <= NTT_TRIG_K_MAX) {
        k1 = k;
    } else {
        /* recursive split of the FFT */
        k1 = bf_min(k / 2, NTT_TRIG_K_MAX);
    }
    k2 = k - k1;
    n1 = (limb_t)1 << k1;
    n2 = (limb_t)1 << k2;

    /* forward transform of both operands */
    if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 0, m_idx))
        return -1;
    if (ntt_fft_partial(s, buf2, k1, k2, n1, n2, 0, m_idx))
        return -1;
    if (k2 == 0) {
        /* base case: pointwise product with normalization */
        ntt_vec_mul(s, buf1, buf2, k, k_tot, m_idx);
    } else {
        for(i = 0; i < n1; i++) {
            /* bug fix: propagate memory allocation failures from the
               recursive convolution instead of silently ignoring the
               return value */
            if (ntt_conv(s, buf1 + i * n2, buf2 + i * n2, k2, k_tot, m_idx))
                return -1;
        }
    }
    /* inverse transform of the product */
    if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 1, m_idx))
        return -1;
    return 0;
}
7677 
7678 
/* Convert the number 'taba' (a_len limbs) into 'nb_mods' NTT input
   vectors of length fft_len stored consecutively in 'tabr': the input
   is cut into chunks of 'dpl' bits and each chunk is reduced modulo
   ntt_mods[first_m_idx + j]. Unused coefficients are zeroed. */
static no_inline void limb_to_ntt(BFNTTState *s,
                                  NTTLimb *tabr, limb_t fft_len,
                                  const limb_t *taba, limb_t a_len, int dpl,
                                  int first_m_idx, int nb_mods)
{
    slimb_t i, n;
    dlimb_t a, b;
    int j, shift;
    limb_t base_mask1, a0, a1, a2, r, m, m_inv;

#if 0
    for(i = 0; i < a_len; i++) {
        printf("%" PRId64 ": " FMT_LIMB "\n",
               (int64_t)i, taba[i]);
    }
#endif
    memset(tabr, 0, sizeof(NTTLimb) * fft_len * nb_mods);
    /* mask for the topmost (partial) limb of each dpl-bit chunk */
    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    /* number of dpl-bit chunks actually present in the input */
    n = bf_min(fft_len, (a_len * LIMB_BITS + dpl - 1) / dpl);
    for(i = 0; i < n; i++) {
        a0 = get_bits(taba, a_len, i * dpl);
        if (dpl <= LIMB_BITS) {
            /* chunk fits in one limb */
            a0 &= base_mask1;
            a = a0;
        } else {
            a1 = get_bits(taba, a_len, i * dpl + LIMB_BITS);
            if (dpl <= (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
                /* chunk fits in a double limb and can be reduced by a
                   single mod_fast() */
                a = a0 | ((dlimb_t)(a1 & base_mask1) << LIMB_BITS);
            } else {
                /* chunk may span up to three limbs: keep the low bits
                   of a0 aside ('a0' below) and reduce the high part
                   first; the reduction is finished in two steps in the
                   modulus loop */
                if (dpl > 2 * LIMB_BITS) {
                    a2 = get_bits(taba, a_len, i * dpl + LIMB_BITS * 2) &
                        base_mask1;
                } else {
                    a1 &= base_mask1;
                    a2 = 0;
                }
                //            printf("a=0x%016lx%016lx%016lx\n", a2, a1, a0);
                a = (a0 >> (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) |
                    ((dlimb_t)a1 << (NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN)) |
                    ((dlimb_t)a2 << (LIMB_BITS + NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN));
                a0 &= ((limb_t)1 << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) - 1;
            }
        }
        for(j = 0; j < nb_mods; j++) {
            m = ntt_mods[first_m_idx + j];
            m_inv = s->ntt_mods_div[first_m_idx + j];
            r = mod_fast(a, m, m_inv);
            if (dpl > (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
                /* second reduction step: fold in the low bits saved in
                   a0 above */
                b = ((dlimb_t)r << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) | a0;
                r = mod_fast(b, m, m_inv);
            }
            tabr[i + j * fft_len] = int_to_ntt_limb(r, m);
        }
    }
}
7738 
7739 #if defined(__AVX2__)
7740 
7741 #define VEC_LEN 4
7742 
/* 4 x double AVX2 vector, also accessible as individual lanes */
typedef union {
    __m256d v;
    double d[4];
} VecUnion;
7747 
/* AVX2 version. Convert the 'nb_mods' NTT result vectors in 'buf'
   (each of length 2^fft_len_log2) back to the limb array 'tabr'
   (r_len limbs): each coefficient is reconstructed with the Chinese
   remainder theorem (4 coefficients at a time using double
   arithmetic), then the dpl-bit digits are written out with carry
   propagation between consecutive coefficients. */
static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
                                  const NTTLimb *buf, int fft_len_log2, int dpl,
                                  int nb_mods)
{
    const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
    const __m256d *mods_cr_vec, *mf, *m_inv;
    VecUnion y[NB_MODS];
    limb_t u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
    slimb_t i, len, pos;
    int j, k, l, shift, n_limb1, p;
    dlimb_t t;

    /* offset of the first CRT constant for this modulus count in the
       triangular ntt_mods_cr_vec table */
    j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
    mods_cr_vec = s->ntt_mods_cr_vec + j;
    mf = s->ntt_mods_vec + NB_MODS - nb_mods;
    m_inv = s->ntt_mods_inv_vec + NB_MODS - nb_mods;

    /* mask for the topmost (partial) limb of each dpl-bit digit */
    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS; /* full limbs per digit */
    for(j = 0; j < NB_MODS; j++)
        carry[j] = 0;
    for(j = 0; j < NB_MODS; j++)
        u[j] = 0; /* avoid warnings */
    memset(tabr, 0, sizeof(limb_t) * r_len);
    fft_len = (limb_t)1 << fft_len_log2;
    /* number of coefficients needed to fill r_len limbs, rounded up to
       a whole vector */
    len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
    len = (len + VEC_LEN - 1) & ~(VEC_LEN - 1);
    i = 0;
    while (i < len) {
        for(j = 0; j < nb_mods; j++)
            y[j].v = *(__m256d *)&buf[i + fft_len * j];

        /* Chinese remainder to get mixed radix representation */
        l = 0;
        for(j = 0; j < nb_mods - 1; j++) {
            y[j].v = ntt_mod1(y[j].v, mf[j]);
            for(k = j + 1; k < nb_mods; k++) {
                y[k].v = ntt_mul_mod(y[k].v - y[j].v,
                                     mods_cr_vec[l], mf[k], m_inv[k]);
                l++;
            }
        }
        y[j].v = ntt_mod1(y[j].v, mf[j]);

        /* handle the 4 vector lanes one by one */
        for(p = 0; p < VEC_LEN; p++) {
            /* back to normal representation: accumulate the mixed
               radix digits into the multi-limb value u[] */
            u[0] = (int64_t)y[nb_mods - 1].d[p];
            l = 1;
            for(j = nb_mods - 2; j >= 1; j--) {
                r = (int64_t)y[j].d[p];
                for(k = 0; k < l; k++) {
                    t = (dlimb_t)u[k] * mods[j] + r;
                    r = t >> LIMB_BITS;
                    u[k] = t;
                }
                u[l] = r;
                l++;
            }
            /* XXX: for nb_mods = 5, l should be 4 */

            /* last step adds the carry */
            r = (int64_t)y[0].d[p];
            for(k = 0; k < l; k++) {
                t = (dlimb_t)u[k] * mods[j] + r + carry[k];
                r = t >> LIMB_BITS;
                u[k] = t;
            }
            u[l] = r + carry[l];

#if 0
            printf("%" PRId64 ": ", i);
            for(j = nb_mods - 1; j >= 0; j--) {
                printf(" %019" PRIu64, u[j]);
            }
            printf("\n");
#endif

            /* write the digits */
            pos = i * dpl;
            for(j = 0; j < n_limb1; j++) {
                put_bits(tabr, r_len, pos, u[j]);
                pos += LIMB_BITS;
            }
            put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
            /* shift by dpl digits and set the carry */
            if (shift == 0) {
                for(j = n_limb1 + 1; j < nb_mods; j++)
                    carry[j - (n_limb1 + 1)] = u[j];
            } else {
                for(j = n_limb1; j < nb_mods - 1; j++) {
                    carry[j - n_limb1] = (u[j] >> shift) |
                        (u[j + 1] << (LIMB_BITS - shift));
                }
                carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
            }
            i++;
        }
    }
}
7851 #else
/* Scalar version. Convert the 'nb_mods' NTT result vectors in 'buf'
   (each of length 2^fft_len_log2) back to the limb array 'tabr'
   (r_len limbs): each coefficient is reconstructed with the Chinese
   remainder theorem, then the dpl-bit digits are written out with
   carry propagation between consecutive coefficients. */
static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
                                  const NTTLimb *buf, int fft_len_log2, int dpl,
                                  int nb_mods)
{
    const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
    const limb_t *mods_cr, *mods_cr_inv;
    limb_t y[NB_MODS], u[NB_MODS+2], carry[NB_MODS], fft_len, base_mask1, r;
    slimb_t i, len, pos;
    int j, k, l, shift, n_limb1;
    dlimb_t t;

    /* offset of the first CRT constant for this modulus count in the
       triangular ntt_mods_cr table */
    j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
    mods_cr = ntt_mods_cr + j;
    mods_cr_inv = s->ntt_mods_cr_inv + j;

    /* mask for the topmost (partial) limb of each dpl-bit digit */
    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS; /* full limbs per digit */
    for(j = 0; j < NB_MODS; j++)
        carry[j] = 0;
    for(j = 0; j < NB_MODS; j++)
        u[j] = 0; /* avoid warnings */
    memset(tabr, 0, sizeof(limb_t) * r_len);
    fft_len = (limb_t)1 << fft_len_log2;
    /* number of coefficients needed to fill r_len limbs */
    len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
    for(i = 0; i < len; i++) {
        for(j = 0; j < nb_mods; j++)  {
            y[j] = ntt_limb_to_int(buf[i + fft_len * j], mods[j]);
        }

        /* Chinese remainder to get mixed radix representation */
        l = 0;
        for(j = 0; j < nb_mods - 1; j++) {
            for(k = j + 1; k < nb_mods; k++) {
                limb_t m;
                m = mods[k];
                /* Note: there is no overflow in the sub_mod() because
                   the modulos are sorted by increasing order */
                y[k] = mul_mod_fast2(y[k] - y[j] + m,
                                     mods_cr[l], m, mods_cr_inv[l]);
                l++;
            }
        }

        /* back to normal representation: accumulate the mixed radix
           digits into the multi-limb value u[] */
        u[0] = y[nb_mods - 1];
        l = 1;
        for(j = nb_mods - 2; j >= 1; j--) {
            r = y[j];
            for(k = 0; k < l; k++) {
                t = (dlimb_t)u[k] * mods[j] + r;
                r = t >> LIMB_BITS;
                u[k] = t;
            }
            u[l] = r;
            l++;
        }

        /* last step adds the carry */
        r = y[0];
        for(k = 0; k < l; k++) {
            t = (dlimb_t)u[k] * mods[j] + r + carry[k];
            r = t >> LIMB_BITS;
            u[k] = t;
        }
        u[l] = r + carry[l];

#if 0
        printf("%" PRId64 ": ", (int64_t)i);
        for(j = nb_mods - 1; j >= 0; j--) {
            printf(" " FMT_LIMB, u[j]);
        }
        printf("\n");
#endif

        /* write the digits */
        pos = i * dpl;
        for(j = 0; j < n_limb1; j++) {
            put_bits(tabr, r_len, pos, u[j]);
            pos += LIMB_BITS;
        }
        put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
        /* shift by dpl digits and set the carry */
        if (shift == 0) {
            for(j = n_limb1 + 1; j < nb_mods; j++)
                carry[j - (n_limb1 + 1)] = u[j];
        } else {
            for(j = n_limb1; j < nb_mods - 1; j++) {
                carry[j - n_limb1] = (u[j] >> shift) |
                    (u[j + 1] << (LIMB_BITS - shift));
            }
            carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
        }
    }
}
7950 #endif
7951 
/* Allocate and initialize the per-context NTT state: fast-division
   constants for each modulus, inverse-length normalization factors,
   powers of the primitive roots of unity and Chinese remainder
   constants. Idempotent. Return 0 if OK, -1 in case of memory
   allocation error. */
static int ntt_static_init(bf_context_t *s1)
{
    BFNTTState *s;
    int inverse, i, j, k, l;
    limb_t c, c_inv, c_inv2, m, m_inv;

    if (s1->ntt_state)
        return 0; /* already initialized */
#if defined(__AVX2__)
    /* the AVX2 build stores __m256d members, hence the alignment */
    s = bf_aligned_malloc(s1, sizeof(*s), 64);
#else
    s = bf_malloc(s1, sizeof(*s));
#endif
    if (!s)
        return -1;
    memset(s, 0, sizeof(*s));
    s1->ntt_state = s;
    s->ctx = s1;

    for(j = 0; j < NB_MODS; j++) {
        m = ntt_mods[j];
        m_inv = init_mul_mod_fast(m);
        s->ntt_mods_div[j] = m_inv;
#if defined(__AVX2__)
        s->ntt_mods_vec[j] = _mm256_set1_pd(m);
        s->ntt_mods_inv_vec[j] = _mm256_set1_pd(1.0 / (double)m);
#endif
        /* ntt_len_inv[j][i] = 1/2^i mod m, with its precomputed
           constant for mul_mod_fast2(): normalization factors for
           inverse transforms of length 2^i */
        c_inv2 = (m + 1) / 2; /* 1/2 */
        c_inv = 1;
        for(i = 0; i <= NTT_PROOT_2EXP; i++) {
            s->ntt_len_inv[j][i][0] = c_inv;
            s->ntt_len_inv[j][i][1] = init_mul_mod_fast2(c_inv, m);
            c_inv = mul_mod_fast(c_inv, c_inv2, m, m_inv);
        }

        /* repeated squaring of the primitive root gives the root used
           for each transform length 2^k */
        for(inverse = 0; inverse < 2; inverse++) {
            c = ntt_proot[inverse][j];
            for(i = 0; i < NTT_PROOT_2EXP; i++) {
                s->ntt_proot_pow[j][inverse][NTT_PROOT_2EXP - i] = c;
                s->ntt_proot_pow_inv[j][inverse][NTT_PROOT_2EXP - i] =
                    init_mul_mod_fast2(c, m);
                c = mul_mod_fast(c, c, m, m_inv);
            }
        }
    }

    /* Chinese remainder constants for each pair of moduli (j, k),
       stored as a flattened triangular table */
    l = 0;
    for(j = 0; j < NB_MODS - 1; j++) {
        for(k = j + 1; k < NB_MODS; k++) {
#if defined(__AVX2__)
            s->ntt_mods_cr_vec[l] = _mm256_set1_pd(int_to_ntt_limb2(ntt_mods_cr[l],
                                                                    ntt_mods[k]));
#else
            s->ntt_mods_cr_inv[l] = init_mul_mod_fast2(ntt_mods_cr[l],
                                                       ntt_mods[k]);
#endif
            l++;
        }
    }
    return 0;
}
8013 
/* Choose the NTT parameters for a product of 'len' limbs: return the
   log2 of the FFT length and set *pdpl to the number of bits stored
   per FFT coefficient and *pnb_mods to the number of moduli. The
   combination with the smallest estimated cost is selected. */
int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    int dpl, fft_len_log2, n_bits, nb_mods, dpl_found, fft_len_log2_found;
    int int_bits, nb_mods_found;
    limb_t cost, min_cost;

    min_cost = -1; /* limb_t is unsigned: this is the maximum value */
    dpl_found = 0;
    nb_mods_found = 4;
    fft_len_log2_found = 0;
    for(nb_mods = 3; nb_mods <= NB_MODS; nb_mods++) {
        /* usable bits with this many moduli */
        int_bits = ntt_int_bits[NB_MODS - nb_mods];
        /* start from the largest dpl allowed by the reduction in
           limb_to_ntt() and shrink until the convolution fits */
        dpl = bf_min((int_bits - 4) / 2,
                     2 * LIMB_BITS + 2 * NTT_MOD_LOG2_MIN - NTT_MOD_LOG2_MAX);
        for(;;) {
            fft_len_log2 = ceil_log2((len * LIMB_BITS + dpl - 1) / dpl);
            if (fft_len_log2 > NTT_PROOT_2EXP)
                goto next; /* no root of unity of this order */
            /* each convolution output needs 2*dpl bits plus
               fft_len_log2 bits of accumulation */
            n_bits = fft_len_log2 + 2 * dpl;
            if (n_bits <= int_bits) {
                /* cost model: FFT work is (log2(N)+1)*N per modulus */
                cost = ((limb_t)(fft_len_log2 + 1) << fft_len_log2) * nb_mods;
                //                printf("n=%d dpl=%d: cost=%" PRId64 "\n", nb_mods, dpl, (int64_t)cost);
                if (cost < min_cost) {
                    min_cost = cost;
                    dpl_found = dpl;
                    nb_mods_found = nb_mods;
                    fft_len_log2_found = fft_len_log2;
                }
                break;
            }
            dpl--;
            if (dpl == 0)
                break;
        }
    next: ;
    }
    if (!dpl_found)
        abort();
    /* limit dpl if possible to reduce fixed cost of limb/NTT conversion */
    if (dpl_found > (LIMB_BITS + NTT_MOD_LOG2_MIN) &&
        ((limb_t)(LIMB_BITS + NTT_MOD_LOG2_MIN) << fft_len_log2_found) >=
        len * LIMB_BITS) {
        dpl_found = LIMB_BITS + NTT_MOD_LOG2_MIN;
    }
    *pnb_mods = nb_mods_found;
    *pdpl = dpl_found;
    return fft_len_log2_found;
}
8062 
8063 /* return 0 if OK, -1 if memory error */
/* NTT-based multiplication: res = a_tab (a_len limbs) * b_tab (b_len
   limbs). mul_flags tells whether res->tab aliases one of the
   operands (FFT_MUL_R_OVERLAP_A/B) and whether res may be resized;
   aliased operand buffers are released (via bf_resize(res, 0)) as
   soon as their data has been converted, to reduce peak memory use.
   Return 0 if OK, -1 if memory error. */
static no_inline int fft_mul(bf_context_t *s1,
                             bf_t *res, limb_t *a_tab, limb_t a_len,
                             limb_t *b_tab, limb_t b_len, int mul_flags)
{
    BFNTTState *s;
    int dpl, fft_len_log2, j, nb_mods, reduced_mem;
    slimb_t len, fft_len;
    NTTLimb *buf1, *buf2, *ptr;
#if defined(USE_MUL_CHECK)
    limb_t ha, hb, hr, h_ref;
#endif

    if (ntt_static_init(s1))
        return -1;
    s = s1->ntt_state;

    /* find the optimal number of digits per limb (dpl) */
    len = a_len + b_len;
    fft_len_log2 = bf_get_fft_size(&dpl, &nb_mods, len);
    fft_len = (uint64_t)1 << fft_len_log2;
    //    printf("len=%" PRId64 " fft_len_log2=%d dpl=%d\n", len, fft_len_log2, dpl);
#if defined(USE_MUL_CHECK)
    /* checksums of the operands to verify the product afterwards */
    ha = mp_mod1(a_tab, a_len, BF_CHKSUM_MOD, 0);
    hb = mp_mod1(b_tab, b_len, BF_CHKSUM_MOD, 0);
#endif
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) == 0) {
        /* res aliases neither operand: free its storage right away */
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    } else if (mul_flags & FFT_MUL_R_OVERLAP_B) {
        limb_t *tmp_tab, tmp_len;
        /* it is better to free 'b' first */
        tmp_tab = a_tab;
        a_tab = b_tab;
        b_tab = tmp_tab;
        tmp_len = a_len;
        a_len = b_len;
        b_len = tmp_len;
    }
    buf1 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
    if (!buf1)
        return -1;
    /* convert 'a' for all moduli at once */
    limb_to_ntt(s, buf1, fft_len, a_tab, a_len, dpl,
                NB_MODS - nb_mods, nb_mods);
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) ==
        FFT_MUL_R_OVERLAP_A) {
        /* 'a' (== res) has been converted: its buffer can go now */
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    }
    /* for large FFTs, convert 'b' one modulus at a time to save a
       factor nb_mods of temporary memory */
    reduced_mem = (fft_len_log2 >= 14);
    if (!reduced_mem) {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
        if (!buf2)
            goto fail;
        limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                    NB_MODS - nb_mods, nb_mods);
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0); /* in case res == b */
    } else {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len);
        if (!buf2)
            goto fail;
    }
    /* one convolution per modulus */
    for(j = 0; j < nb_mods; j++) {
        if (reduced_mem) {
            limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                        NB_MODS - nb_mods + j, 1);
            ptr = buf2;
        } else {
            ptr = buf2 + fft_len * j;
        }
        if (ntt_conv(s, buf1 + fft_len * j, ptr,
                     fft_len_log2, fft_len_log2, j + NB_MODS - nb_mods))
            goto fail;
    }
    if (!(mul_flags & FFT_MUL_R_NORESIZE))
        bf_resize(res, 0); /* in case res == b and reduced mem */
    ntt_free(s, buf2);
    buf2 = NULL;
    if (!(mul_flags & FFT_MUL_R_NORESIZE)) {
        if (bf_resize(res, len))
            goto fail;
    }
    /* CRT reconstruction of the product into res */
    ntt_to_limb(s, res->tab, len, buf1, fft_len_log2, dpl, nb_mods);
    ntt_free(s, buf1);
#if defined(USE_MUL_CHECK)
    hr = mp_mod1(res->tab, len, BF_CHKSUM_MOD, 0);
    h_ref = mul_mod(ha, hb, BF_CHKSUM_MOD);
    if (hr != h_ref) {
        printf("ntt_mul_error: len=%" PRId_LIMB " fft_len_log2=%d dpl=%d nb_mods=%d\n",
               len, fft_len_log2, dpl, nb_mods);
        //        printf("ha=0x" FMT_LIMB" hb=0x" FMT_LIMB " hr=0x" FMT_LIMB " expected=0x" FMT_LIMB "\n", ha, hb, hr, h_ref);
        exit(1);
    }
#endif
    return 0;
 fail:
    ntt_free(s, buf1);
    ntt_free(s, buf2);
    return -1;
}
8164 
8165 #else /* USE_FFT_MUL */
8166 
/* Stub used when FFT multiplication is disabled. */
int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    (void)pdpl;
    (void)pnb_mods;
    (void)len;
    return 0;
}
8171 
8172 #endif /* !USE_FFT_MUL */
8173