1"""
2Low-level functions for arbitrary-precision floating-point arithmetic.
3"""
4
# Docstring markup hint for documentation tools (presumably the epydoc
# convention of reading __docformat__): docstrings here are plain text.
__docformat__ = 'plaintext'
6
7import math
8
9from bisect import bisect
10
11import sys
12
# Importing random is slow, so it is deferred: mpf_rand imports it and
# binds random.getrandbits to this global on first use.
getrandbits = None
16
17from .backend import (MPZ, MPZ_TYPE, MPZ_ZERO, MPZ_ONE, MPZ_TWO, MPZ_FIVE,
18    BACKEND, STRICT, HASH_MODULUS, HASH_BITS, gmpy, sage, sage_utils)
19
20from .libintmath import (giant_steps,
21    trailtable, bctable, lshift, rshift, bitcount, trailing,
22    sqrt_fixed, numeral, isqrt, isqrt_fast, sqrtrem,
23    bin_to_radix)
24
25# We don't pickle tuples directly for the following reasons:
26#   1: pickle uses str() for ints, which is inefficient when they are large
27#   2: pickle doesn't work for gmpy mpzs
28# Both problems are solved by using hex()
29
if BACKEND != 'sage':
    def to_pickable(x):
        """Return a picklable form of the raw mpf x, with the mantissa
        written as a hexadecimal string (the leading '0x' removed)."""
        sign, mantissa, exponent, bitcnt = x
        return sign, hex(mantissa)[2:], exponent, bitcnt
else:
    def to_pickable(x):
        """Return a picklable form of the raw mpf x, with the mantissa
        written as a hexadecimal string. No prefix is stripped here;
        presumably Sage's hex() emits none (TODO confirm)."""
        sign, mantissa, exponent, bitcnt = x
        return sign, hex(mantissa), exponent, bitcnt
38
def from_pickable(x):
    """Inverse of to_pickable: rebuild a raw mpf whose mantissa was
    stored as a hexadecimal string."""
    sign, hex_mantissa, exp, bc = x
    mantissa = MPZ(hex_mantissa, 16)
    return (sign, mantissa, exp, bc)
42
class ComplexResult(ValueError):
    """ValueError subclass. Going by its name, it is presumably raised when
    a real-valued operation would have a complex result; its raisers are not
    in this part of the file."""
45
try:
    intern
except NameError:
    # Python 3 moved the builtin intern() into the sys module; use it rather
    # than an identity function so the mode strings really are interned.
    from sys import intern

# All supported rounding modes
round_nearest = intern('n')   # to nearest, ties to even
round_floor = intern('f')     # toward -infinity
round_ceiling = intern('c')   # toward +infinity
round_up = intern('u')        # away from zero
round_down = intern('d')      # toward zero (truncation)
# Default mode used by the functions in this module that take rnd
round_fast = round_down
58
def prec_to_dps(n):
    """Return the number of decimal digits (at least 1) that an n-bit
    binary precision represents accurately."""
    # 3.3219... = log2(10) bits per decimal digit; one digit is held back
    digits = int(round(int(n) / 3.3219280948873626)) - 1
    return max(1, digits)
63
def dps_to_prec(n):
    """Return the binary precision in bits (at least 1) needed to hold
    n decimal digits, counting one extra guard digit."""
    bits = round((int(n) + 1) * 3.3219280948873626)
    return max(1, int(bits))
68
def repr_dps(n):
    """Return how many decimal digits a repr of an n-bit number needs so
    that the number can be reconstructed uniquely from it."""
    dps = prec_to_dps(n)
    # 15 dps corresponds to double-like precision, presumably why the
    # well-known 17-digit round-trip count is used there (TODO confirm)
    return 17 if dps == 15 else dps + 3
77
78#----------------------------------------------------------------------------#
79#                    Some commonly needed float values                       #
80#----------------------------------------------------------------------------#
81
# Regular number format:
#   value = (-1)**sign * mantissa * 2**exponent, stored as the tuple
#   (sign, mantissa, exponent, bitcount of mantissa)
fzero = (0, MPZ_ZERO, 0, 0)     # 0
fnzero = (1, MPZ_ZERO, 0, 0)    # zero with the sign bit set
fone = (0, MPZ_ONE, 0, 1)       # 1
fnone = (1, MPZ_ONE, 0, 1)      # -1
ftwo = (0, MPZ_ONE, 1, 1)       # 2 == 1 * 2**1
ften = (0, MPZ_FIVE, 1, 3)      # 10 == 5 * 2**1
fhalf = (0, MPZ_ONE, -1, 1)     # 0.5 == 1 * 2**-1

# Special numbers use an arbitrary encoding: a zero mantissa combined with
# a nonzero exponent (the specific exponent/bitcount values are just tags).
fnan = (0, MPZ_ZERO, -123, -1)
finf = (0, MPZ_ZERO, -456, -2)
fninf = (1, MPZ_ZERO, -789, -3)

# Python float infinity, produced by overflow. (The literal 1e1000 was
# avoided because it was broken in Python 2.4.)
math_float_inf = 1e300 * 1e300
99
100
101#----------------------------------------------------------------------------#
102#                                  Rounding                                  #
103#----------------------------------------------------------------------------#
104
# This function can be used to round a mantissa generally. However,
# we will try to do most rounding inline for efficiency.
def round_int(x, n, rnd):
    """
    Return the (signed) integer x divided by 2**n, rounded to an integer
    in direction rnd:

    * round_nearest: to nearest, ties to even (requires n >= 1)
    * round_floor: toward -infinity
    * round_ceiling: toward +infinity
    * round_down: toward zero
    * round_up: away from zero

    Raise ValueError for any other rounding mode.
    """
    if rnd == round_nearest:
        if x >= 0:
            # t keeps one extra bit, which is the highest discarded bit.
            # Round up if it is set and either the truncated result is odd
            # (a tie goes to even) or a lower discarded bit is set (> half).
            t = x >> (n-1)
            if t & 1 and ((t & 2) or (x & h_mask[n<300][n])):
                return (t>>1)+1
            else:
                return t>>1
        else:
            # Round-half-even is symmetric about zero
            return -round_int(-x, n, rnd)
    if rnd == round_floor:
        return x >> n
    if rnd == round_ceiling:
        return -((-x) >> n)
    if rnd == round_down:
        if x >= 0:
            return x >> n
        return -((-x) >> n)
    if rnd == round_up:
        if x >= 0:
            return -((-x) >> n)
        return x >> n
    # Falling through used to return None silently
    raise ValueError("unknown rounding mode: %r" % (rnd,))
129
# These masks are used to pick out segments of numbers to determine
# which direction to round when rounding to nearest: h_mask[...][n] is
# 2**(n-1) - 1, i.e. the bits below the highest of n discarded bits.
class h_mask_big:
    """Mask provider for arbitrary n; computes 2**(n-1) - 1 on demand."""
    def __getitem__(self, n):
        mask = (MPZ_ONE << (n - 1)) - 1
        return mask

# Precomputed masks for 1 <= n < 300 (entry 0 is an unused placeholder).
h_mask_small = [0] + [((MPZ_ONE << (_ - 1)) - 1) for _ in range(1, 300)]
# Indexed as h_mask[n < 300][n]: False -> computed on demand, True -> table.
h_mask = [h_mask_big(), h_mask_small]
138
# shifts_down[rnd][sign] is true when a plain >> of the (nonnegative)
# magnitude already rounds in direction rnd for a number with that sign;
# when false, the magnitude must be rounded up instead via -((-m) >> n).
shifts_down = {
    round_floor: (1, 0),
    round_ceiling: (0, 1),
    round_down: (1, 1),
    round_up: (0, 0),
}
144
145
146#----------------------------------------------------------------------------#
147#                          Normalization of raw mpfs                         #
148#----------------------------------------------------------------------------#
149
150# This function is called almost every time an mpf is created.
151# It has been optimized accordingly.
152
def _normalize(sign, man, exp, bc, prec, rnd):
    """
    Create a raw mpf tuple with value (-1)**sign * man * 2**exp and
    normalized mantissa. The mantissa is rounded in the specified
    direction if its size exceeds the precision. Trailing zero bits
    are also stripped from the mantissa to ensure that the
    representation is canonical.

    Conditions on the input:
    * The input must represent a regular (finite) number
    * The sign bit must be 0 or 1
    * The mantissa must be positive
    * The exponent must be an integer
    * The bitcount must be exact

    If these conditions are not met, use from_man_exp, mpf_pos, or any
    of the conversion functions to create normalized raw mpf tuples.

    Returns the tuple (sign, man, exp, bc); a zero mantissa yields fzero.
    prec is the maximum number of mantissa bits kept and rnd one of the
    round_* modes.
    """
    if not man:
        return fzero
    # Cut mantissa down to size if larger than target precision
    n = bc - prec
    if n > 0:
        if rnd == round_nearest:
            # t keeps one extra bit, which is the highest discarded bit.
            # Round up if it is set and either the truncated result is odd
            # (a tie goes to even) or a lower discarded bit is set (> half).
            t = man >> (n-1)
            if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
                man = (t>>1)+1
            else:
                man = t>>1
        elif shifts_down[rnd][sign]:
            # Truncating the magnitude rounds in the requested direction
            man >>= n
        else:
            # Otherwise round the magnitude up: -((-m) >> n) == ceil(m / 2**n)
            man = -((-man)>>n)
        exp += n
        bc = prec
    # Strip trailing bits
    if not man & 1:
        # trailtable presumably maps a byte to its count of trailing zero
        # bits, with 0 for a zero byte; whole zero bytes are skipped first.
        t = trailtable[int(man & 255)]
        if not t:
            while not man & 255:
                man >>= 8
                exp += 8
                bc -= 8
            t = trailtable[int(man & 255)]
        man >>= t
        exp += t
        bc -= t
    # Bit count can be wrong if the input mantissa was 1 less than
    # a power of 2 and got rounded up, thereby adding an extra bit.
    # With trailing bits removed, all powers of two have mantissa 1,
    # so this is easy to check for.
    if man == 1:
        bc = 1
    return sign, man, exp, bc
207
def _normalize1(sign, man, exp, bc, prec, rnd):
    """same as normalize, but with the added condition that
       man is odd or zero

    An odd mantissa has no trailing zero bits. If it already fits in prec
    bits, the components are therefore returned unchanged. Otherwise it is
    rounded and stripped exactly as in _normalize.
    """
    if not man:
        return fzero
    if bc <= prec:
        return sign, man, exp, bc
    n = bc - prec
    if rnd == round_nearest:
        # t keeps one extra bit, which is the highest discarded bit.
        # Round up if it is set and either the truncated result is odd
        # (a tie goes to even) or a lower discarded bit is set (> half).
        t = man >> (n-1)
        if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
            man = (t>>1)+1
        else:
            man = t>>1
    elif shifts_down[rnd][sign]:
        # Truncating the magnitude rounds in the requested direction
        man >>= n
    else:
        # Otherwise round the magnitude up: -((-m) >> n) == ceil(m / 2**n)
        man = -((-man)>>n)
    exp += n
    bc = prec
    # Strip trailing bits
    if not man & 1:
        # trailtable presumably maps a byte to its count of trailing zero
        # bits, with 0 for a zero byte; whole zero bytes are skipped first.
        t = trailtable[int(man & 255)]
        if not t:
            while not man & 255:
                man >>= 8
                exp += 8
                bc -= 8
            t = trailtable[int(man & 255)]
        man >>= t
        exp += t
        bc -= t
    # Bit count can be wrong if the input mantissa was 1 less than
    # a power of 2 and got rounded up, thereby adding an extra bit.
    # With trailing bits removed, all powers of two have mantissa 1,
    # so this is easy to check for.
    if man == 1:
        bc = 1
    return sign, man, exp, bc
248
249try:
250    _exp_types = (int, long)
251except NameError:
252    _exp_types = (int,)
253
def strict_normalize(sign, man, exp, bc, prec, rnd):
    """Checked variant of _normalize: assert that man is an MPZ_TYPE,
    that exp and bc are integers and that bc is the exact bit count of
    man, then normalize. Enable tests by setting the environment
    variable MPMATH_STRICT to Y."""
    assert type(man) == MPZ_TYPE
    assert type(bc) in _exp_types and type(exp) in _exp_types
    assert bc == bitcount(man)
    return _normalize(sign, man, exp, bc, prec, rnd)
262
def strict_normalize1(sign, man, exp, bc, prec, rnd):
    """Checked variant of _normalize1: besides the type and bit-count
    checks of strict_normalize, assert that man is zero or odd. Enable
    tests by setting the environment variable MPMATH_STRICT to Y."""
    assert type(man) == MPZ_TYPE
    assert type(bc) in _exp_types and type(exp) in _exp_types
    assert bc == bitcount(man)
    assert (not man) or (man & 1)
    return _normalize1(sign, man, exp, bc, prec, rnd)
272
# Prefer the backend's compiled normalization routine when it has one.
if BACKEND == 'gmpy' and '_mpmath_normalize' in dir(gmpy):
    _normalize = _normalize1 = gmpy._mpmath_normalize
elif BACKEND == 'sage':
    _normalize = _normalize1 = sage_utils.normalize

# Public entry points; strict mode wraps them with consistency checks.
if STRICT:
    normalize, normalize1 = strict_normalize, strict_normalize1
else:
    normalize, normalize1 = _normalize, _normalize1
286
287#----------------------------------------------------------------------------#
288#                            Conversion functions                            #
289#----------------------------------------------------------------------------#
290
def from_man_exp(man, exp, prec=None, rnd=round_fast):
    """Create raw mpf from (man, exp) pair. The mantissa may be signed.
    If no precision is specified, the mantissa is stored exactly.

    With prec None or 0, trailing zero bits are stripped but nothing is
    rounded; otherwise the result is normalized to prec bits in
    direction rnd."""
    man = MPZ(man)
    sign = 0
    if man < 0:
        sign = 1
        man = -man
    # Small mantissas: table lookup instead of computing the bit count
    if man < 1024:
        bc = bctable[int(man)]
    else:
        bc = bitcount(man)
    if not prec:
        if not man:
            return fzero
        if not man & 1:
            # Fast path: exactly one trailing zero bit
            if man & 2:
                return (sign, man >> 1, exp + 1, bc - 1)
            # General case: same trailing-zero stripping as in _normalize
            t = trailtable[int(man & 255)]
            if not t:
                while not man & 255:
                    man >>= 8
                    exp += 8
                    bc -= 8
                t = trailtable[int(man & 255)]
            man >>= t
            exp += t
            bc -= t
        return (sign, man, exp, bc)
    return normalize(sign, man, exp, bc, prec, rnd)
321
# Exact raw mpfs for small integers, consulted by from_int. Built with the
# pure-Python from_man_exp defined above, before any backend replacement.
int_cache = {k: from_man_exp(k, 0) for k in range(-10, 257)}

# Swap in the backend's faster implementation when one is available.
if BACKEND == 'gmpy' and '_mpmath_create' in dir(gmpy):
    from_man_exp = gmpy._mpmath_create
elif BACKEND == 'sage':
    from_man_exp = sage_utils.from_man_exp
329
def from_int(n, prec=0, rnd=round_fast):
    """Convert the integer n to a raw mpf. Without a precision the value
    is exact, and small integers are served from int_cache."""
    if prec:
        return from_man_exp(n, 0, prec, rnd)
    cached = int_cache.get(n)
    if cached is None:
        return from_man_exp(n, 0, prec, rnd)
    return cached
337
def to_man_exp(s):
    """Return the pair (man, exp) of a raw mpf, where man is the unsigned
    mantissa (the sign is not applied). Raise ValueError for inf/nan."""
    sign, man, exp, bc = s
    if (not man) and exp:
        # Special values are a zero mantissa with a nonzero exponent, so
        # interpolating the mantissa (always 0) would be meaningless.
        raise ValueError("mantissa and exponent are undefined for inf or nan")
    return man, exp
344
def to_int(s, rnd=None):
    """Convert a raw mpf to an int. With rnd=None the value is truncated
    toward zero (like int(float)); otherwise it is rounded in direction
    rnd via round_int. Raise ValueError for inf/nan."""
    sign, man, exp, bc = s
    if (not man) and exp:
        raise ValueError("cannot convert inf or nan to int")
    if exp >= 0:
        # Already an integer: shift the signed mantissa left
        return ((-man) if sign else man) << exp
    shift = -exp
    if rnd:
        return round_int(-man if sign else man, shift, rnd)
    # Fast default: truncate the magnitude, then apply the sign
    truncated = man >> shift
    return -truncated if sign else truncated
366
def mpf_round_int(s, rnd):
    """
    Round the raw mpf s to an integer-valued raw mpf in direction rnd.

    inf/nan and values that are already integers (exp >= 0) are returned
    unchanged. Supported modes are round_nearest (ties to even),
    round_floor, round_ceiling, round_down and round_up; any other mode
    raises NotImplementedError when |s| < 1.
    """
    sign, man, exp, bc = s
    if (not man) and exp:
        return s
    if exp >= 0:
        return s
    # 2**(mag-1) <= |s| < 2**mag
    mag = exp+bc
    if mag < 1:
        # 0 < |s| < 1, so the result is 0 or +-1
        if rnd == round_ceiling:
            return fzero if sign else fone
        elif rnd == round_floor:
            return fnone if sign else fzero
        elif rnd == round_nearest:
            # mag < 0: |s| < 1/2; man == 1 with mag == 0: exactly 1/2,
            # which ties to the even value 0
            if mag < 0 or man == MPZ_ONE:
                return fzero
            return fnone if sign else fone
        elif rnd == round_down:
            # Truncation toward zero
            return fzero
        elif rnd == round_up:
            # Away from zero (s is nonzero here)
            return fnone if sign else fone
        else:
            raise NotImplementedError
    # Keep exactly the integer bits of s, rounding the fraction away
    return mpf_pos(s, min(bc, mag), rnd)
388
def mpf_floor(s, prec=0, rnd=round_fast):
    """Return floor(s) as a raw mpf; a nonzero prec additionally rounds
    the result to prec bits in direction rnd."""
    result = mpf_round_int(s, round_floor)
    return mpf_pos(result, prec, rnd) if prec else result
394
def mpf_ceil(s, prec=0, rnd=round_fast):
    """Return ceil(s) as a raw mpf; a nonzero prec additionally rounds
    the result to prec bits in direction rnd."""
    result = mpf_round_int(s, round_ceiling)
    return mpf_pos(result, prec, rnd) if prec else result
400
def mpf_nint(s, prec=0, rnd=round_fast):
    """Return the nearest integer to s (ties to even) as a raw mpf; a
    nonzero prec additionally rounds the result to prec bits."""
    result = mpf_round_int(s, round_nearest)
    return mpf_pos(result, prec, rnd) if prec else result
406
def mpf_frac(s, prec=0, rnd=round_fast):
    """Return the fractional part s - floor(s) (before rounding, in
    [0, 1) for finite s), rounded to prec bits in direction rnd."""
    integer_part = mpf_floor(s)
    return mpf_sub(s, integer_part, prec, rnd)
409
def from_float(x, prec=53, rnd=round_fast):
    """Create a raw mpf from a Python float, rounding if necessary.
    If prec >= 53, the result is guaranteed to represent exactly the
    same number as the input. If prec is not specified, use prec=53.

    nan and +-inf map to fnan, finf and fninf. Inputs that math.frexp
    rejects (e.g. non-numbers) raise instead of silently becoming nan."""
    # Special values are handled before frexp, whose treatment of them has
    # varied across platforms and Python versions (it may raise).
    if x != x:
        return fnan
    if x == math_float_inf:
        return finf
    if x == -math_float_inf:
        return fninf
    m, e = math.frexp(x)
    # 0.5 <= |m| < 1 with at most 53 significant bits, so m*2**53 is an
    # exact integer
    return from_man_exp(int(m*(1<<53)), e-53, prec, rnd)
428
def from_npfloat(x, prec=113, rnd=round_fast):
    """Create a raw mpf from a numpy float, rounding if necessary.
    If prec >= 113, the result is guaranteed to represent exactly the
    same number as the input. If prec is not specified, use prec=113."""
    as_float = float(x)
    if x == as_float:
        # Exactly a double: use the float path (ldexp overflows for float16)
        return from_float(as_float, prec, rnd)
    import numpy as np
    if not np.isfinite(x):
        if np.isposinf(x):
            return finf
        if np.isneginf(x):
            return fninf
        return fnan
    m, e = np.frexp(x)
    return from_man_exp(int(np.ldexp(m, 113)), int(e-113), prec, rnd)
443
def from_Decimal(x, prec=None, rnd=round_fast):
    """Create a raw mpf from a Decimal, rounding if necessary.
    If prec is not specified, use the equivalent bit precision of the
    number of significant digits in x, i.e. ceil(digits * log2(10))
    bits -- enough to hold any integer with that many decimal digits.
    nan and +-infinity map to fnan, finf and fninf."""
    if x.is_nan():
        return fnan
    if x.is_infinite():
        return fninf if x.is_signed() else finf
    if prec is None:
        ndigits = len(x.as_tuple()[1])
        # Round the bit count up: truncating it gives a single digit only
        # 3 bits, too few for 9 (Decimal('9') would round to 8).
        prec = int(math.ceil(ndigits*3.3219280948873626))
    return from_str(str(x), prec, rnd)
453
454def to_float(s, strict=False, rnd=round_fast):
455    """
456    Convert a raw mpf to a Python float. The result is exact if the
457    bitcount of s is <= 53 and no underflow/overflow occurs.
458
459    If the number is too large or too small to represent as a regular
460    float, it will be converted to inf or 0.0. Setting strict=True
461    forces an OverflowError to be raised instead.
462
463    Warning: with a directed rounding mode, the correct nearest representable
464    floating-point number in the specified direction might not be computed
465    in case of overflow or (gradual) underflow.
466    """
467    sign, man, exp, bc = s
468    if not man:
469        if s == fzero: return 0.0
470        if s == finf: return math_float_inf
471        if s == fninf: return -math_float_inf
472        return math_float_inf/math_float_inf
473    if bc > 53:
474        sign, man, exp, bc = normalize1(sign, man, exp, bc, 53, rnd)
475    if sign:
476        man = -man
477    try:
478        return math.ldexp(man, exp)
479    except OverflowError:
480        if strict:
481            raise
482        # Overflow to infinity
483        if exp + bc > 0:
484            if sign:
485                return -math_float_inf
486            else:
487                return math_float_inf
488        # Underflow to zero
489        return 0.0
490
def from_rational(p, q, prec, rnd=round_fast):
    """Return the raw mpf p/q, rounded to prec bits in direction rnd."""
    numerator = from_int(p)
    denominator = from_int(q)
    return mpf_div(numerator, denominator, prec, rnd)
495
def to_rational(s):
    """Convert a raw mpf to a rational number. Return integers (p, q),
    where q is a positive power of two, such that s = p/q exactly.
    Raise ValueError if s is inf or nan."""
    sign, man, exp, bc = s
    # All special values (nan, +inf, -inf) have a zero mantissa and a
    # nonzero exponent; checking bc == -1 alone would only catch nan.
    if (not man) and exp:
        raise ValueError("cannot convert inf or nan to a rational number")
    if sign:
        man = -man
    if exp >= 0:
        return man << exp, 1
    return man, 1 << (-exp)
508
def to_fixed(s, prec):
    """Convert a raw mpf to a fixed-point big integer with prec fraction
    bits, i.e. floor(s * 2**prec) (exact when no bits are dropped)."""
    sign, man, exp, bc = s
    value = -man if sign else man
    shift = exp + prec
    if shift >= 0:
        return value << shift
    # Arithmetic right shift floors, also for negative values
    return value >> -shift
519
520
521##############################################################################
522##############################################################################
523
524#----------------------------------------------------------------------------#
525#                       Arithmetic operations, etc.                          #
526#----------------------------------------------------------------------------#
527
def mpf_rand(prec):
    """Return a raw mpf drawn uniformly from [0, 1) with prec random bits
    in the mantissa."""
    global getrandbits
    if not getrandbits:
        # Deferred import: loading random at module import time is slow
        import random
        getrandbits = random.getrandbits
    bits = getrandbits(prec)
    return from_man_exp(bits, -prec, prec, round_floor)
536
def mpf_eq(s, t):
    """Test two raw mpfs for equality: plain tuple comparison, except that
    nan is never equal to anything (itself included)."""
    # nan has a zero mantissa, so only then is the nan check needed
    if (not s[1] or not t[1]) and fnan in (s, t):
        return False
    return s == t
544
def mpf_hash(s):
    """
    Return a hash of the raw mpf s intended to agree with hash() of an
    equal Python int or float.

    On Python >= 3.2 this reproduces CPython's numeric hash: the value is
    reduced modulo HASH_MODULUS. On older versions the float value is
    hashed when representable, and the raw tuple otherwise.
    """
    # Duplicate the new hash algorithm introduced in Python 3.2.
    if sys.version_info >= (3, 2):
        ssign, sman, sexp, sbc = s

        # Handle special numbers
        if not sman:
            if s == fnan: return sys.hash_info.nan
            if s == finf: return sys.hash_info.inf
            if s == fninf: return -sys.hash_info.inf
        h = sman % HASH_MODULUS
        # Reduce the exponent into [0, HASH_BITS); this relies on
        # 2**HASH_BITS == 1 modulo HASH_MODULUS (presumably
        # HASH_MODULUS == 2**HASH_BITS - 1 as in CPython -- see backend).
        if sexp >= 0:
            sexp = sexp % HASH_BITS
        else:
            sexp = HASH_BITS - 1 - ((-1 - sexp) % HASH_BITS)
        h = (h << sexp) % HASH_MODULUS
        if ssign: h = -h
        # -1 is reserved by CPython's hash protocol; numeric hashes use -2
        if h == -1: h = -2
        return int(h)
    else:
        try:
            # Try to be compatible with hash values for floats and ints
            return hash(to_float(s, strict=1))
        except OverflowError:
            # We must unfortunately sacrifice compatibility with ints here.
            # We could do hash(man << exp) when the exponent is positive, but
            # this would cause unreasonable inefficiency for large numbers.
            return hash(s)
573
def mpf_cmp(s, t):
    """Compare the raw mpfs s and t. Return -1 if s < t, 0 if s == t,
    and 1 if s > t. (Same convention as Python's cmp() function.)

    Results involving nan follow the special-case checks below and do not
    form a consistent order (e.g. fzero compares equal to nan, since
    mpf_sign(fnan) is 0); mpf_lt, mpf_le, mpf_gt and mpf_ge screen out nan
    before calling this."""

    # In principle, a comparison amounts to determining the sign of s-t.
    # A full subtraction is relatively slow, however, so we first try to
    # look at the components.
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t

    # Handle zeros and special numbers
    if not sman or not tman:
        if s == fzero: return -mpf_sign(t)
        if t == fzero: return mpf_sign(s)
        if s == t: return 0
        # Follow same convention as Python's cmp for float nan
        if t == fnan: return 1
        if s == finf: return 1
        if t == fninf: return 1
        return -1
    # Different sides of zero
    if ssign != tsign:
        if not ssign: return 1
        return -1
    # This reduces to direct integer comparison
    if sexp == texp:
        if sman == tman:
            return 0
        if sman > tman:
            if ssign: return -1
            else:     return 1
        else:
            if ssign: return 1
            else:     return -1
    # Check position of the highest set bit in each number. If
    # different, there is certainly an inequality.
    a = sbc + sexp
    b = tbc + texp
    if ssign:
        if a < b: return 1
        if a > b: return -1
    else:
        if a < b: return -1
        if a > b: return 1

    # Both numbers have the same highest bit. Subtract to find
    # how the lower bits compare. Only the sign of the difference matters,
    # so 5 bits suffice. Assuming normalized inputs, equal values would
    # have had equal exponents above, so the difference is nonzero here.
    delta = mpf_sub(s, t, 5, round_floor)
    if delta[0]:
        return -1
    return 1
625
def mpf_lt(s, t):
    """Return True if s < t; any comparison involving nan is False."""
    if fnan in (s, t):
        return False
    return mpf_cmp(s, t) < 0
630
def mpf_le(s, t):
    """Return True if s <= t; any comparison involving nan is False."""
    if fnan in (s, t):
        return False
    return mpf_cmp(s, t) <= 0
635
def mpf_gt(s, t):
    """Return True if s > t; any comparison involving nan is False."""
    if fnan in (s, t):
        return False
    return mpf_cmp(s, t) > 0
640
def mpf_ge(s, t):
    """Return True if s >= t; any comparison involving nan is False."""
    if fnan in (s, t):
        return False
    return mpf_cmp(s, t) >= 0
645
def mpf_min_max(seq):
    """Return (smallest, largest) of the non-empty sequence of raw mpfs.
    Comparisons with nan are false, so a nan first element is kept and
    later nan elements are ignored."""
    lo = hi = seq[0]
    for item in seq[1:]:
        if mpf_lt(item, lo):
            lo = item
        if mpf_gt(item, hi):
            hi = item
    return lo, hi
652
653def mpf_pos(s, prec=0, rnd=round_fast):
654    """Calculate 0+s for a raw mpf (i.e., just round s to the specified
655    precision)."""
656    if prec:
657        sign, man, exp, bc = s
658        if (not man) and exp:
659            return s
660        return normalize1(sign, man, exp, bc, prec, rnd)
661    return s
662
663def mpf_neg(s, prec=None, rnd=round_fast):
664    """Negate a raw mpf (return -s), rounding the result to the
665    specified precision. The prec argument can be omitted to do the
666    operation exactly."""
667    sign, man, exp, bc = s
668    if not man:
669        if exp:
670            if s == finf: return fninf
671            if s == fninf: return finf
672        return s
673    if not prec:
674        return (1-sign, man, exp, bc)
675    return normalize1(1-sign, man, exp, bc, prec, rnd)
676
677def mpf_abs(s, prec=None, rnd=round_fast):
678    """Return abs(s) of the raw mpf s, rounded to the specified
679    precision. The prec argument can be omitted to generate an
680    exact result."""
681    sign, man, exp, bc = s
682    if (not man) and exp:
683        if s == fninf:
684            return finf
685        return s
686    if not prec:
687        if sign:
688            return (0, man, exp, bc)
689        return s
690    return normalize1(0, man, exp, bc, prec, rnd)
691
def mpf_sign(s):
    """Return -1, 0, or 1 (as a Python int, not a raw mpf) depending on
    whether s is negative, zero, or positive. (Nan is taken to give 0.)"""
    sign, man, exp, bc = s
    if man:
        return -1 if sign else 1
    if s == finf:
        return 1
    if s == fninf:
        return -1
    return 0
701
def mpf_add(s, t, prec=0, rnd=round_fast, _sub=0):
    """
    Add the two raw mpf values s and t.

    With prec=0, no rounding is performed. Note that this can
    produce a very large mantissa (potentially too large to fit
    in memory) if exponents are far apart.

    Internal parameter: _sub=1 flips the sign of t, computing s - t
    (mpf_sub is implemented this way).

    Special values: nan propagates, inf + (-inf) gives nan, and an
    infinity plus anything finite gives that infinity.
    """
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    tsign ^= _sub
    # Standard case: two nonzero, regular numbers
    if sman and tman:
        offset = sexp - texp
        if offset:
            if offset > 0:
                # Outside precision range; only need to perturb
                # If t lies more than prec+4 bits below s's leading bit, it
                # can only affect rounding: replace it by a single unit far
                # below the precision (a sticky bit) instead of shifting s
                # by the full offset.
                if offset > 100 and prec:
                    delta = sbc + sexp - tbc - texp
                    if delta > prec + 4:
                        offset = prec + 4
                        sman <<= offset
                        if tsign == ssign: sman += 1
                        else:              sman -= 1
                        return normalize1(ssign, sman, sexp-offset,
                            bitcount(sman), prec, rnd)
                # Add
                if ssign == tsign:
                    man = tman + (sman << offset)
                # Subtract
                else:
                    if ssign: man = tman - (sman << offset)
                    else:     man = (sman << offset) - tman
                    if man >= 0:
                        ssign = 0
                    else:
                        man = -man
                        ssign = 1
                bc = bitcount(man)
                # prec=0 means exact: "round" to the full bit count
                return normalize1(ssign, man, texp, bc, prec or bc, rnd)
            elif offset < 0:
                # Outside precision range; only need to perturb
                # (mirror image of the branch above, with s the small one)
                if offset < -100 and prec:
                    delta = tbc + texp - sbc - sexp
                    if delta > prec + 4:
                        offset = prec + 4
                        tman <<= offset
                        if ssign == tsign: tman += 1
                        else:              tman -= 1
                        return normalize1(tsign, tman, texp-offset,
                            bitcount(tman), prec, rnd)
                # Add
                if ssign == tsign:
                    man = sman + (tman << -offset)
                # Subtract
                else:
                    if tsign: man = sman - (tman << -offset)
                    else:     man = (tman << -offset) - sman
                    if man >= 0:
                        ssign = 0
                    else:
                        man = -man
                        ssign = 1
                bc = bitcount(man)
                return normalize1(ssign, man, sexp, bc, prec or bc, rnd)
        # Equal exponents; no shifting necessary
        if ssign == tsign:
            man = tman + sman
        else:
            if ssign: man = tman - sman
            else:     man = sman - tman
            if man >= 0:
                ssign = 0
            else:
                man = -man
                ssign = 1
        bc = bitcount(man)
        # The sum of two odd mantissas is even, so use normalize (which
        # strips trailing zeros for any input) rather than normalize1
        return normalize(ssign, man, texp, bc, prec or bc, rnd)
    # Handle zeros and special numbers
    if _sub:
        t = mpf_neg(t)
    if not sman:
        if sexp:
            # s is inf or nan: keep s unless t is the opposite infinity
            # or nan
            if s == t or tman or not texp:
                return s
            return fnan
        # s is zero: the result is t (rounded if finite)
        if tman:
            return normalize1(tsign, tman, texp, tbc, prec or tbc, rnd)
        return t
    # s is finite nonzero, so t is zero or special
    if texp:
        return t
    if sman:
        return normalize1(ssign, sman, sexp, sbc, prec or sbc, rnd)
    return s
796
def mpf_sub(s, t, prec=0, rnd=round_fast):
    """Return s - t for raw mpfs, rounded to prec bits in direction rnd
    (exact when prec=0). Delegates to mpf_add with t's sign flipped."""
    return mpf_add(s, t, prec, rnd, _sub=1)
801
def mpf_sum(xs, prec=0, rnd=round_fast, absolute=False):
    """
    Sum a list of mpf values efficiently and accurately
    (typically no temporary roundoff occurs). If prec=0,
    the final result will not be rounded either.

    There may be roundoff error or cancellation if extremely
    large exponent differences occur.

    With absolute=True, sums the absolute values.

    The running sum is kept as a signed fixed-point pair (man, exp),
    meaning man * 2**exp. Special values (inf/nan) are accumulated
    separately and, if any occur, returned instead of the finite sum.
    """
    man = 0
    exp = 0
    # Largest exponent gap (in bits) handled exactly; beyond it, a term
    # is treated as negligible relative to the running sum.
    max_extra_prec = prec*2 or 1000000  # XXX
    special = None
    for x in xs:
        xsign, xman, xexp, xbc = x
        if xman:
            if xsign and not absolute:
                xman = -xman
            delta = xexp - exp
            if xexp >= exp:
                # x much larger than existing sum?
                # first: quick test
                if (delta > max_extra_prec) and \
                    ((not man) or delta-bitcount(abs(man)) > max_extra_prec):
                    # The old sum is negligible: restart from x alone
                    man = xman
                    exp = xexp
                else:
                    # Align x to the sum's exponent and add
                    man += (xman << delta)
            else:
                delta = -delta
                # x much smaller than existing sum?
                if delta-xbc > max_extra_prec:
                    # Drop x, unless the sum is still zero
                    if not man:
                        man, exp = xman, xexp
                else:
                    # Lower the sum's exponent to x's, then add
                    man = (man << delta) + xman
                    exp = xexp
        elif xexp:
            # Zero mantissa with nonzero exponent: inf or nan
            if absolute:
                x = mpf_abs(x)
            special = mpf_add(special or fzero, x, 1)
    # Will be inf or nan
    if special:
        return special
    return from_man_exp(man, exp, prec, rnd)
849
def gmpy_mpf_mul(s, t, prec=0, rnd=round_fast):
    """Multiply two raw mpfs, computing the product's bit count with
    bitcount(). Selected as mpf_mul when BACKEND == 'gmpy'."""
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    sign = ssign ^ tsign
    product = sman*tman
    if product:
        exp = sexp + texp
        bc = bitcount(product)
        if not prec:
            return (sign, product, exp, bc)
        return normalize1(sign, product, exp, bc, prec, rnd)
    # Some factor has a zero mantissa: it is zero, inf or nan
    s_special = (not sman) and sexp
    t_special = (not tman) and texp
    if not (s_special or t_special):
        # zero times a finite number
        return fzero
    if fnan in (s, t):
        return fnan
    # Make s the infinite operand
    if t_special:
        s, t = t, s
    if t == fzero:
        # inf * 0
        return fnan
    return {1: finf, -1: fninf}[mpf_sign(s) * mpf_sign(t)]
870
def gmpy_mpf_mul_int(s, n, prec, rnd=round_fast):
    """Multiply the raw mpf s by the Python integer n, rounding to prec
    bits in direction rnd (gmpy-backend variant)."""
    sign, man, exp, bc = s
    if not man:
        # zero or special: defer to general multiplication
        return mpf_mul(s, from_int(n), prec, rnd)
    if not n:
        return fzero
    if n < 0:
        sign, n = sign ^ 1, -n
    product = man * n
    return normalize(sign, product, exp, bitcount(product), prec, rnd)
883
884def python_mpf_mul(s, t, prec=0, rnd=round_fast):
885    """Multiply two raw mpfs"""
886    ssign, sman, sexp, sbc = s
887    tsign, tman, texp, tbc = t
888    sign = ssign ^ tsign
889    man = sman*tman
890    if man:
891        bc = sbc + tbc - 1
892        bc += int(man>>bc)
893        if prec:
894            return normalize1(sign, man, sexp+texp, bc, prec, rnd)
895        else:
896            return (sign, man, sexp+texp, bc)
897    s_special = (not sman) and sexp
898    t_special = (not tman) and texp
899    if not s_special and not t_special:
900        return fzero
901    if fnan in (s, t): return fnan
902    if (not tman) and texp: s, t = t, s
903    if t == fzero: return fnan
904    return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
905
def python_mpf_mul_int(s, n, prec, rnd=round_fast):
    """Multiply the raw mpf s by the Python integer n, rounding to prec
    bits in direction rnd (pure-Python variant)."""
    sign, man, exp, bc = s
    if not man:
        # zero or special: defer to general multiplication
        return mpf_mul(s, from_int(n), prec, rnd)
    if not n:
        return fzero
    if n < 0:
        sign ^= 1
        n = -n
    man *= n
    # n is usually small, so try the lookup table for its bit count first
    nbits = bctable[int(n)] if n < 1024 else bitcount(n)
    # The product has bc+nbits-1 or bc+nbits bits; the shift test decides
    bc += nbits - 1
    bc += int(man >> bc)
    return normalize(sign, man, exp, bc, prec, rnd)
924
925
# Select the multiplication kernels for the active integer backend. The
# gmpy variants size the product with bitcount(); the pure-Python
# variants derive its bit count from the operands' bit counts.
if BACKEND == 'gmpy':
    mpf_mul, mpf_mul_int = gmpy_mpf_mul, gmpy_mpf_mul_int
else:
    mpf_mul, mpf_mul_int = python_mpf_mul, python_mpf_mul_int
932
def mpf_shift(s, n):
    """Return the raw mpf s multiplied by 2**n.

    This only adjusts the exponent, so it is exact and never rounds.
    Values with a zero mantissa (zero and the special values) are
    returned unchanged.
    """
    sign, man, exp, bc = s
    if man:
        return sign, man, exp + n, bc
    return s
939
def mpf_frexp(x):
    """Split the raw mpf x into (y, n) with x == y * 2**n.

    For nonzero x, abs(y) lies in [0.5, 1). Zero gives (fzero, 0), and
    any other value with a zero mantissa (inf, nan) raises ValueError.
    """
    sign, man, exp, bc = x
    if man:
        # x = man * 2**exp with man having bc bits, so dividing by
        # 2**(bc+exp) leaves man * 2**-bc, whose magnitude is in [0.5, 1).
        n = bc + exp
        return mpf_shift(x, -n), n
    if x != fzero:
        raise ValueError
    return (fzero, 0)
949
def mpf_div(s, t, prec, rnd=round_fast):
    """Floating-point division s/t of two raw mpfs, rounded to prec bits
    in direction rnd.

    Special values:
      * x/0 raises ZeroDivisionError for every x (including 0, inf, nan);
      * 0/nan, nan/x, x/nan and +-inf/+-inf give nan;
      * 0/inf and finite/inf give zero;
      * +-inf divided by a finite nonzero value gives a signed infinity.
    """
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    if not sman or not tman:
        if s == fzero:
            if t == fzero: raise ZeroDivisionError
            if t == fnan: return fnan
            return fzero
        if t == fzero:
            raise ZeroDivisionError
        # A zero mantissa with a nonzero exponent marks inf/nan.
        s_special = (not sman) and sexp
        t_special = (not tman) and texp
        if s_special and t_special:
            return fnan
        if s == fnan or t == fnan:
            return fnan
        if not t_special:
            # s is +-inf and t is finite and nonzero (t == fzero has
            # already raised above), so the quotient is a signed infinity.
            return {1:finf, -1:fninf}[mpf_sign(s) * mpf_sign(t)]
        # s is finite and nonzero, t is +-inf.
        return fzero
    sign = ssign ^ tsign
    if tman == 1:
        # Dividing by a power of two only changes the exponent.
        return normalize1(sign, sman, sexp-texp, sbc, prec, rnd)
    # Same strategy as for addition: if there is a remainder, perturb
    # the result a few bits outside the precision range before rounding
    extra = prec - sbc + tbc + 5
    if extra < 5:
        extra = 5
    quot, rem = divmod(sman<<extra, tman)
    if rem:
        # The exact quotient lies strictly between quot and quot+1;
        # represent it as quot + 1/2 (one extra low bit) so rounding
        # still goes in the right direction.
        quot = (quot<<1) + 1
        extra += 1
        return normalize1(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)
    return normalize(sign, quot, sexp-texp-extra, bitcount(quot), prec, rnd)
986
def mpf_rdiv_int(n, t, prec, rnd=round_fast):
    """Compute n/t for a Python integer numerator n and a raw mpf t,
    rounded to prec bits in direction rnd.

    A zero numerator, or a t with zero mantissa (zero, inf, nan), is
    handed to mpf_div so the special-value rules live in one place.
    """
    sign, man, exp, bc = t
    if not (n and man):
        return mpf_div(from_int(n), t, prec, rnd)
    if n < 0:
        n, sign = -n, sign ^ 1
    # Scale the numerator so the integer quotient has at least prec+5 bits.
    extra = prec + bc + 5
    quot, rem = divmod(n << extra, man)
    if not rem:
        return normalize(sign, quot, -exp - extra, bitcount(quot), prec, rnd)
    # Inexact: the true quotient lies strictly between quot and quot+1.
    # Represent it as quot + 1/2 (one extra low bit) before rounding.
    perturbed = (quot << 1) + 1
    return normalize1(sign, perturbed, -exp - extra - 1,
                      bitcount(perturbed), prec, rnd)
1002
def mpf_mod(s, t, prec, rnd=round_fast):
    """
    Compute s mod t for raw mpfs s and t, rounding the result to prec
    bits in direction rnd. As with Python's % operator on integers, a
    nonzero result takes the sign of t.

    Returns nan if either operand is inf or nan. Raises
    ZeroDivisionError if t is zero.
    """
    ssign, sman, sexp, sbc = s
    tsign, tman, texp, tbc = t
    # A zero mantissa with a nonzero exponent marks inf/nan.
    if ((not sman) and sexp) or ((not tman) and texp):
        return fnan
    # Reject a zero divisor before the shortcuts below. Otherwise the
    # "t is larger" test (texp == 0 > sexp+sbc) would silently return s
    # for small positive s, while every other s raised from the integer %.
    if not tman:
        raise ZeroDivisionError
    # Important special case: do nothing if t is larger
    if ssign == tsign and texp > sexp+sbc:
        return s
    # Another important special case: this allows us to do e.g. x % 1.0
    # to find the fractional part of x, and it will work when x is huge.
    if tman == 1 and sexp > texp+tbc:
        return fzero
    # Align both operands to the smaller exponent and reduce exactly
    # with signed integer arithmetic.
    base = min(sexp, texp)
    sman = (-1)**ssign * sman
    tman = (-1)**tsign * tman
    man = (sman << (sexp-base)) % (tman << (texp-base))
    if man >= 0:
        sign = 0
    else:
        man = -man
        sign = 1
    return normalize(sign, man, base, bitcount(man), prec, rnd)
1025
# Rounding mode to apply to a denominator so that the quotient 1/x is
# rounded in the original direction: 1/x shrinks as |x| grows, so every
# directed mode is swapped for its opposite; round-to-nearest is kept.
reciprocal_rnd = {
    round_down: round_up,
    round_up: round_down,
    round_floor: round_ceiling,
    round_ceiling: round_floor,
    round_nearest: round_nearest,
}
1033
# Rounding mode to use on -x so that negating the result back rounds x in
# the original direction. Toward/away-from-zero modes are symmetric under
# negation; floor and ceiling trade places.
negative_rnd = {
    round_down: round_down,
    round_up: round_up,
    round_floor: round_ceiling,
    round_ceiling: round_floor,
    round_nearest: round_nearest,
}
1041
def mpf_pow_int(s, n, prec, rnd=round_fast):
    """Compute s**n, where s is a raw mpf and n is a Python integer.

    The result is rounded to prec bits in direction rnd. Special cases:
    s**0 is 1 for every finite s (including zero); inf**0, -inf**0 and
    nan**n are nan; +-inf raised to a negative power is zero.
    """
    sign, man, exp, bc = s

    # Zero mantissa with nonzero exponent: inf, -inf or nan.
    if (not man) and exp:
        if s == finf:
            if n > 0: return s
            if n == 0: return fnan
            return fzero
        if s == fninf:
            if n > 0: return [finf, fninf][n & 1]
            if n == 0: return fnan
            return fzero
        return fnan

    n = int(n)
    if n == 0: return fone
    if n == 1: return mpf_pos(s, prec, rnd)
    if n == 2:
        # Squaring: the result is nonnegative. A square of a bc-bit
        # mantissa has 2*bc-1 or 2*bc bits; the table lookup on the top
        # bits decides which.
        _, man, exp, bc = s
        if not man:
            return fzero
        man = man*man
        if man == 1:
            return (0, MPZ_ONE, exp+exp, 1)
        bc = bc + bc - 2
        bc += bctable[int(man>>bc)]
        return normalize1(0, man, exp+exp, bc, prec, rnd)
    if n == -1: return mpf_div(fone, s, prec, rnd)
    if n < 0:
        # Compute s**|n| with 5 guard bits and the opposite rounding
        # direction, then take the reciprocal, so both roundings push
        # the final result the same way.
        inverse = mpf_pow_int(s, -n, prec+5, reciprocal_rnd[rnd])
        return mpf_div(fone, inverse, prec, rnd)

    # Negative only when s is negative and n is odd.
    result_sign = sign & n

    # Use exact integer power when the exact mantissa is small
    if man == 1:
        return (result_sign, MPZ_ONE, exp*n, 1)
    if bc*n < 1000:
        man **= n
        return normalize1(result_sign, man, exp*n, bitcount(man), prec, rnd)

    # Use directed rounding all the way through to maintain rigorous
    # bounds for interval arithmetic
    # rounds_down selects truncation of intermediate magnitudes below;
    # otherwise they are rounded up (away from zero). shifts_down is
    # defined elsewhere in this module.
    rounds_down = (rnd == round_nearest) or \
        shifts_down[rnd][result_sign]

    # Now we perform binary exponentiation. Need to estimate precision
    # to avoid rounding errors from temporary operations. Roughly log_2(n)
    # operations are performed.
    workprec = prec + 4*bitcount(n) + 4
    # Accumulator (pm, pe, pbc) starts at 1.
    _, pm, pe, pbc = fone
    while 1:
        if n & 1:
            # Multiply the accumulator by the current square; its bit
            # count is pbc+bc-1 or pbc+bc, decided by the top bits.
            pm = pm*man
            pe = pe+exp
            pbc += bc - 2
            pbc = pbc + bctable[int(pm >> pbc)]
            if pbc > workprec:
                # Trim back to workprec bits in the chosen direction.
                if rounds_down:
                    pm = pm >> (pbc-workprec)
                else:
                    pm = -((-pm) >> (pbc-workprec))
                pe += pbc - workprec
                pbc = workprec
            n -= 1
            if not n:
                break
        # Square the base for the next bit of n.
        man = man*man
        exp = exp+exp
        bc = bc + bc - 2
        bc = bc + bctable[int(man >> bc)]
        if bc > workprec:
            if rounds_down:
                man = man >> (bc-workprec)
            else:
                man = -((-man) >> (bc-workprec))
            exp += bc - workprec
            bc = workprec
        n = n // 2

    return normalize(result_sign, pm, pe, pbc, prec, rnd)
1124
1125
def mpf_perturb(x, eps_sign, prec, rnd):
    """
    Return x + eps rounded to prec bits in direction rnd, where x is
    nonzero and eps is a positive (eps_sign 0) or negative (eps_sign 1)
    quantity smaller than the precision relative to x.

    Under round-to-nearest, the perturbation is ignored and x is simply
    normalized to prec bits.
    """
    if rnd == round_nearest:
        return mpf_pos(x, prec, rnd)
    sign, man, exp, bc = x
    # Does rnd round toward +infinity for a value of this sign?
    if sign:
        rounds_upward = rnd in (round_down, round_ceiling)
    else:
        rounds_upward = rnd in (round_up, round_ceiling)
    # The perturbation only matters when rounding goes the same way as eps.
    if not (rounds_upward ^ eps_sign):
        return mpf_pos(x, prec, rnd)
    tiny = (eps_sign, MPZ_ONE, exp + bc - prec - 1, 1)
    return mpf_add(x, tiny, prec, rnd)
1147
1148
1149#----------------------------------------------------------------------------#
1150#                              Radix conversion                              #
1151#----------------------------------------------------------------------------#
1152
def to_digits_exp(s, dps):
    """Helper function for representing the floating-point number s as
    a decimal with dps digits. Returns (sign, string, exponent) where
    sign is '' or '-', string is the digit string, and exponent is
    the decimal exponent as an int.

    If inexact, the decimal representation is rounded toward zero."""

    # Extract sign first so it doesn't mess up the string digit count
    if s[0]:
        sign = '-'
        s = mpf_neg(s)
    else:
        sign = ''
    _sign, man, exp, bc = s

    if not man:
        return '', '0', 0

    # Binary working precision: dps decimal digits need about
    # dps*log2(10) bits; 10 guard bits are added on top.
    bitprec = int(dps * math.log(10,2)) + 10

    # Cut down to size
    # TODO: account for precision when doing this
    exp_from_1 = exp + bc
    if abs(exp_from_1) > 3500:
        from .libelefun import mpf_ln2, mpf_ln10
        # Set b = int(exp * log(2)/log(10))
        # If exp is huge, we must use high-precision arithmetic to
        # find the nearest power of ten
        expprec = bitcount(abs(exp)) + 5
        tmp = from_int(exp)
        tmp = mpf_mul(tmp, mpf_ln2(expprec))
        tmp = mpf_div(tmp, mpf_ln10(expprec), expprec)
        b = to_int(tmp)
        # Divide out 10**b so the fixed-point conversion below works on
        # a value of moderate size; b is added back to the exponent.
        s = mpf_div(s, mpf_pow_int(ften, b, bitprec), bitprec)
        _sign, man, exp, bc = s
        exponent = b
    else:
        exponent = 0

    # First, calculate mantissa digits by converting to a binary
    # fixed-point number and then converting that number to
    # a decimal fixed-point number.
    # fixprec fractional bits keep about bitprec significant bits;
    # fixdps is the matching number of fractional decimal digits.
    fixprec = max(bitprec - exp - bc, 0)
    fixdps = int(fixprec / math.log(10,2) + 0.5)
    sf = to_fixed(s, fixprec)
    sd = bin_to_radix(sf, fixprec, 10, fixdps)
    digits = numeral(sd, base=10, size=dps)

    # sd carries fixdps fractional decimal digits; the leading digit's
    # position follows from the total digit count.
    exponent += len(digits) - fixdps - 1
    return sign, digits, exponent
1204
def to_str(s, dps, strip_zeros=True, min_fixed=None, max_fixed=None,
    show_zero_exponent=False):
    """
    Convert a raw mpf to a decimal floating-point literal with at
    most `dps` decimal digits in the mantissa (not counting extra zeros
    that may be inserted for visual purposes).

    The number will be printed in fixed-point format if the position
    of the leading digit is strictly between min_fixed
    (default = min(-(dps//3), -5)) and max_fixed (default = dps).

    To force fixed-point format always, set min_fixed = -inf,
    max_fixed = +inf. To force floating-point format, set
    min_fixed >= max_fixed.

    If strip_zeros is true, trailing zeros after the decimal point are
    removed (keeping at least one digit after the point). If
    show_zero_exponent is true, an exponent of zero is printed as 'e+0'.
    Infinities and nan are returned as '+inf', '-inf' and 'nan'.

    The literal is formatted so that it can be parsed back to a number
    by from_str, float() or Decimal().
    """

    # Special numbers
    if not s[1]:
        if s == fzero:
            if dps: t = '0.0'
            else:   t = '.0'
            if show_zero_exponent:
                t += 'e+0'
            return t
        if s == finf: return '+inf'
        if s == fninf: return '-inf'
        if s == fnan: return 'nan'
        raise ValueError

    if min_fixed is None: min_fixed = min(-(dps//3), -5)
    if max_fixed is None: max_fixed = dps

    # to_digits_exp rounds to floor.
    # This sometimes kills some instances of "...00001"
    # Three extra digits are requested so the rounding below can look at
    # the digit just past position dps.
    sign, digits, exponent = to_digits_exp(s, dps+3)

    # No digits: show only .0; round exponent to nearest
    if not dps:
        if digits[0] in '56789':
            exponent += 1
        digits = ".0"

    else:
        # Rounding up kills some instances of "...99999"
        if len(digits) > dps and digits[dps] in '56789':
            digits = digits[:dps]
            # Propagate the carry left over any trailing 9s.
            i = dps - 1
            while i >= 0 and digits[i] == '9':
                i -= 1
            if i >= 0:
                digits = digits[:i] + str(int(digits[i]) + 1) + '0' * (dps - i - 1)
            else:
                # All nines: the carry produces a new leading digit.
                digits = '1' + '0' * (dps - 1)
                exponent += 1
        else:
            digits = digits[:dps]

        # Prettify numbers close to unit magnitude
        if min_fixed < exponent < max_fixed:
            if exponent < 0:
                digits = ("0"*int(-exponent)) + digits
                split = 1
            else:
                split = exponent + 1
                if split > dps:
                    digits += "0"*(split-dps)
            exponent = 0
        else:
            split = 1

        digits = (digits[:split] + "." + digits[split:])

        if strip_zeros:
            # Clean up trailing zeros
            digits = digits.rstrip('0')
            if digits[-1] == ".":
                digits += "0"

    if exponent == 0 and dps and not show_zero_exponent: return sign + digits
    if exponent >= 0: return sign + digits + "e+" + str(exponent)
    if exponent < 0: return sign + digits + "e" + str(exponent)
1289
def str_to_man_exp(x, base=10):
    """Helper function for from_str (and from_bstr).

    Split the float literal x into an integer mantissa and an exponent.
    The mantissa digits are read in the given base with the radix point
    removed. The exponent is the explicit exponent (after 'e') minus the
    number of significant digits after the radix point. Callers treat
    the exponent as a power of `base`, i.e. x == man * base**exp.

    Raises ValueError if x is not a valid float literal or is inf/nan.
    """
    # A trailing 'l' is the legacy long-integer suffix. Surrounding
    # whitespace is stripped because float() ignores it, whereas a
    # trailing space would otherwise be counted as a fractional digit
    # below ("1.5 " would parse as 0.15).
    x = x.lower().strip().rstrip('l')
    # Verify that the input is a valid float literal
    float(x)
    # float() also accepts PEP 515 digit-grouping underscores. Remove them
    # so they are not counted as fractional digits ("1.2_5" would parse as
    # 0.125) and cannot be left dangling by rstrip('0') below ("1.5_0"
    # would make int() fail).
    x = x.strip().replace('_', '')
    # Split into mantissa, exponent
    parts = x.split('e')
    if len(parts) == 1:
        exp = 0
    else: # == 2
        x = parts[0]
        exp = int(parts[1])
    # Look for radix point in mantissa
    parts = x.split('.')
    if len(parts) == 2:
        a, b = parts[0], parts[1].rstrip('0')
        exp -= len(b)
        x = a + b
    x = MPZ(int(x, base))
    return x, exp
1310
# Spellings (after from_str lower-cases and strips its input) that map
# directly to special values.
special_str = {
    'inf': finf,
    '+inf': finf,
    '-inf': fninf,
    'nan': fnan,
}
1312
def from_str(x, prec, rnd=round_fast):
    """Create a raw mpf from a decimal literal, rounding in the
    specified direction if the input number cannot be represented
    exactly as a binary floating-point number with the given number of
    bits. The literal syntax accepted is the same as for Python
    floats; 'inf', '+inf', '-inf', 'nan' and rational literals of the
    form 'p/q' are also accepted.

    TODO: the rounding does not work properly for large exponents.
    """
    x = x.lower().strip()
    special = special_str.get(x)
    if special is not None:
        return special

    if '/' in x:
        num, den = x.split('/')
        return from_rational(int(num.rstrip('l')), int(den.rstrip('l')),
                             prec, rnd)

    man, exp = str_to_man_exp(x, base=10)

    if abs(exp) <= 400:
        # Moderate exponent: build an exact integer or rational and
        # round it once.
        if exp < 0:
            return from_rational(man, 10**-exp, prec, rnd)
        return from_int(man * 10**exp, prec, rnd)

    # XXX: appropriate cutoffs & track direction
    # note no factors of 5
    scaled = from_int(man, prec+10)
    return mpf_mul(scaled, mpf_pow_int(ften, exp, prec+10), prec, rnd)
1344
1345# Binary string conversion. These are currently mainly used for debugging
1346# and could use some improvement in the future
1347
def from_bstr(x):
    """Create a raw mpf from a binary literal such as '-101.01e3', where
    the digits are base 2 and the exponent is a power of two.

    No bits are discarded: the result is normalized with a precision
    equal to the mantissa's own bit count.
    """
    man, exp = str_to_man_exp(x, base=2)
    man = MPZ(man)
    sign = 0
    if man < 0:
        sign, man = 1, -man
    nbits = bitcount(man)
    return normalize(sign, man, exp, nbits, nbits, round_floor)
1357
def to_bstr(x):
    """Format the raw mpf x as a binary literal '[-]<bits>e<exp>', in the
    form accepted by from_bstr."""
    sign, man, exp, bc = x
    prefix = ('', '-')[sign]
    bits = numeral(man, size=bitcount(man), base=2)
    return prefix + bits + ("e%i" % exp)
1361
1362
1363#----------------------------------------------------------------------------#
1364#                                Square roots                                #
1365#----------------------------------------------------------------------------#
1366
1367
def mpf_sqrt(s, prec, rnd=round_fast):
    """
    Compute the square root of a nonnegative mpf value. The
    result is correctly rounded.

    Raises ComplexResult for negative input (including -inf). Zero and
    other values with a zero mantissa (+inf, nan) are returned unchanged.
    """
    sign, man, exp, bc = s
    if sign:
        raise ComplexResult("square root of a negative number")
    if not man:
        return s
    # Make the exponent even so it can be halved exactly.
    if exp & 1:
        exp -= 1
        man <<= 1
        bc += 1
    elif man == 1:
        # Even power of two: the root is exact.
        return normalize1(sign, man, exp//2, bc, prec, rnd)
    # Scale the mantissa by an even power of two so its integer root has
    # a few bits more than prec.
    shift = max(4, 2*prec-bc+4)
    shift += shift & 1
    if rnd in 'fd':
        # Floor/down: the truncated integer root already rounds correctly.
        man = isqrt(man<<shift)
    else:
        man, rem = sqrtrem(man<<shift)
        # Perturb up
        # An inexact root lies strictly above man; append a 1 bit (and
        # account for it in shift) so rounding sees that.
        if rem:
            man = (man<<1)+1
            shift += 2
    return from_man_exp(man, (exp-shift)//2, prec, rnd)
1395
def mpf_hypot(x, y, prec, rnd=round_fast):
    """Return sqrt(x**2 + y**2) for raw mpfs x and y, rounded to prec
    bits in direction rnd. If either argument is zero, the absolute
    value of the other is returned directly; otherwise the sum of
    squares is rounded with 4 guard bits before the square root."""
    if y == fzero:
        return mpf_abs(x, prec, rnd)
    if x == fzero:
        return mpf_abs(y, prec, rnd)
    xx = mpf_mul(x, x)
    yy = mpf_mul(y, y)
    return mpf_sqrt(mpf_add(xx, yy, prec + 4), prec, rnd)
1403
1404
if BACKEND == 'sage':
    # Use Sage's ext_libmp versions of the core operations when that
    # module can be imported; otherwise keep the definitions above.
    try:
        import sage.libs.mpmath.ext_libmp as ext_lib
    except ImportError:
        pass
    else:
        mpf_add = ext_lib.mpf_add
        mpf_sub = ext_lib.mpf_sub
        mpf_mul = ext_lib.mpf_mul
        mpf_div = ext_lib.mpf_div
        mpf_sqrt = ext_lib.mpf_sqrt
1415