1"""
2Adaptive numerical evaluation of Diofant expressions, using mpmath
3for mathematical functions.
4"""
5
6import math
7import numbers
8
9from mpmath import inf as mpmath_inf
10from mpmath import (libmp, make_mpc, make_mpf, mp, mpc, mpf, nsum, quadosc,
11                    quadts, workprec)
12from mpmath.libmp import bitcount as mpmath_bitcount
13from mpmath.libmp import (fhalf, fnan, fnone, fone, from_int, from_man_exp,
14                          from_rational, fzero, mpf_abs, mpf_add, mpf_atan,
15                          mpf_atan2, mpf_cmp, mpf_cos, mpf_e, mpf_exp, mpf_log,
16                          mpf_lt, mpf_mul, mpf_neg, mpf_pi, mpf_pow,
17                          mpf_pow_int, mpf_shift, mpf_sin, mpf_sqrt, normalize,
18                          round_nearest)
19from mpmath.libmp.backend import MPZ
20from mpmath.libmp.gammazeta import mpf_bernoulli
21from mpmath.libmp.libmpc import _infs_nan
22from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps
23
24from .compatibility import is_sequence
25from .singleton import S
26from .sympify import sympify
27
28
# log2(10): the number of bits per decimal digit, used to convert
# between decimal digits and binary precision.
LG10 = math.log(10, 2)
# Rounding mode passed to the mpmath low-level functions in this module.
rnd = round_nearest
31
32
def bitcount(n):
    """Return the number of bits needed to represent the integer ``n``.

    ``n`` is converted with ``int()`` first so integer-like objects are
    accepted by the underlying mpmath helper.
    """
    as_int = int(n)
    return mpmath_bitcount(as_int)
35
36
# Used in a few places as placeholder values to denote exponents and
# precision levels, e.g. of exact numbers. Must be careful to avoid
# passing these to mpmath functions or returning them in final results.
INF = float(mpmath_inf)
MINUS_INF = float(-mpmath_inf)

# Default cap on the working precision, in bits: 110 decimal digits
# (int(110*LG10) == 365 bits). Setting it to INF removes the cap.
DEFAULT_MAXPREC = int(110*LG10)  # keep in sync with maxn kwarg of evalf
45
46
class PrecisionExhausted(ArithmeticError):
    """Signal that the requested accuracy could not be reached.

    Raised, e.g. by ``check_target``, when the working precision is used
    up before a result is accurate to the target number of bits.
    """
49
50
51############################################################################
52#                                                                          #
53#              Helper functions for arithmetic and complex parts           #
54#                                                                          #
55############################################################################
56
57
58"""
59An mpf value tuple is a tuple of integers (sign, man, exp, bc)
60representing a floating-point number: [1, -1][sign]*man*2**exp where
61sign is 0 or 1 and bc should correspond to the number of bits used to
62represent the mantissa (man) in binary notation, e.g.
63
64>>> sign, man, exp, bc = 0, 5, 1, 3
65>>> n = [1, -1][sign]*man*2**exp
66>>> n, bitcount(man)
67(10, 3)
68
69A temporary result is a tuple (re, im, re_acc, im_acc) where
70re and im are nonzero mpf value tuples representing approximate
71numbers, or None to denote exact zeros.
72
re_acc, im_acc are integers denoting -log2(e) where e is the estimated
relative error of the respective complex part (i.e. roughly the number
of correct bits), but may be anything if the corresponding complex part
is None.
76
77"""
78
79
def fastlog(x):
    """Fast estimate of log2(abs(x)) for an mpf value tuple x.

    The estimate is the exponent plus the bit width of the mantissa,
    ``x[2] + x[3]``. For values that are not exact powers of 2 this equals
    ceil(log2(abs(x))); for exact powers of 2 it is one too high.

    Correcting the power-of-2 case would mean testing whether the (always
    odd) mpf mantissa is 1, i.e. returning
    ``x[2] + (x[3] if x[1] != 1 else 0)``. That costs speed and is not
    needed, because the value is only used as an estimate of the number
    of bits in x.

    A falsy x (e.g. None) or ``fzero`` gives ``MINUS_INF``.

    Examples
    ========

    >>> s, m, e = 0, 5, 1
    >>> bc = bitcount(m)
    >>> n = [1, -1][s]*m*2**e
    >>> n, (log(n)/log(2)).evalf(2), fastlog((s, m, e, bc))
    (10, 3.3, 4)

    """
    if x and x != fzero:
        exponent, width = x[2], x[3]
        return exponent + width
    return MINUS_INF
109
110
def pure_complex(v):
    """Split ``v`` of the form ``a + b*I`` into the Numbers ``(a, b)``.

    Returns None when ``v`` does not have that form, e.g. when it has
    no ``I`` term at all.

    >>> a, b = Tuple(2, 3)
    >>> pure_complex(a)
    >>> pure_complex(a + b*I)
    (2, 3)
    >>> pure_complex(I)
    (0, 1)

    """
    from .numbers import I
    real_part, rest = v.as_coeff_Add()
    coeff, factor = rest.as_coeff_Mul()
    if factor is not I:
        return None
    return real_part, coeff
128
129
def scaled_zero(mag, sign=1):
    """Build, or unwrap, a "scaled zero" mpf.

    Given an integer ``mag``, return ``(rv, -1)``. ``rv`` is the mpf for
    2**mag, with its sign bit wrapped in a one-element list so the tuple
    cannot be mistaken for an ordinary mpf. The ``-1`` is the precision.

    Given such a wrapped tuple instead, return it with the sign unwrapped,
    i.e. an ordinary mpf value tuple.

    Raises ValueError if ``sign`` is not +/-1, or if ``mag`` is neither an
    integer nor a scaled-zero tuple.

    Examples
    ========

    >>> z, p = scaled_zero(100)
    >>> z, p
    (([0], 1, 100, 1), -1)
    >>> ok = scaled_zero(z)
    >>> ok
    (0, 1, 100, 1)
    >>> Float(ok)
    1.26765060022823e+30
    >>> Float(ok, p)
    0.e+30
    >>> ok, p = scaled_zero(100, -1)
    >>> Float(scaled_zero(ok), p)
    -0.e+30

    """
    if type(mag) is tuple and len(mag) == 4 and iszero(mag, scaled=True):
        wrapped_sign, *rest = mag
        return (wrapped_sign[0],) + tuple(rest)
    if not isinstance(mag, numbers.Integral):
        raise ValueError('scaled zero expects int or scaled_zero tuple.')
    if sign not in (-1, 1):
        raise ValueError('sign must be +/-1')
    power = mpf_shift(fone, mag)
    sign_bit = 1 if sign == -1 else 0
    return ([sign_bit],) + power[1:], -1
165
166
def iszero(mpf, scaled=False):
    """Test whether ``mpf`` represents a zero.

    With ``scaled=False``: True for a falsy value (e.g. None) or an mpf
    tuple whose mantissa and bitcount are both zero.

    With ``scaled=True``: true only for a scaled-zero tuple, i.e. one
    whose sign is wrapped in a list and whose mantissa and bitcount are 1.
    A falsy ``mpf`` is returned unchanged.
    """
    if scaled:
        if not mpf:
            return mpf
        return type(mpf[0]) is list and mpf[1] == mpf[-1] == 1
    if not mpf:
        return True
    return not (mpf[1] or mpf[-1])
171
172
def complex_accuracy(result):
    """
    Return the relative accuracy, in bits, of an evalf result tuple
    ``(re, im, re_acc, im_acc)``.

    If only one part is present, its own accuracy is returned. If both
    parts are exact zeros, the result is INF. Otherwise the absolute error
    of the complex value is taken as the larger of the two parts' absolute
    errors (the max norm approximates the complex norm) and related to the
    larger part's magnitude.

    In the worst case (re and im of equal size) the max norm is off by a
    factor sqrt(2), i.e. by 0.5 bit.

    """
    re, im, re_acc, im_acc = result
    if not re and not im:
        return INF
    if not im:
        return re_acc
    if not re:
        return im_acc
    re_size, im_size = fastlog(re), fastlog(im)
    abs_err = max(re_size - re_acc, im_size - im_acc)
    return max(re_size, im_size) - abs_err
199
200
def get_abs(expr, prec, options):
    """Evaluate abs(expr) to about ``prec`` bits.

    Returns an evalf result tuple whose imaginary slot is always None.
    The accuracy reported is that of the part that ends up in the real
    slot. NOTE(review): for genuinely complex values this is not the
    combined complex accuracy.
    """
    re, im, re_acc, im_acc = evalf(expr, prec + 2, options)
    if not re:
        # move a lone imaginary part into the real slot: |b*I| == |b|
        re, im = im, re
        re_acc, im_acc = im_acc, re_acc
    if not re:
        # both parts are exact zeros
        return None, None, None, None
    if not im:
        return mpf_abs(re), None, re_acc, None
    return libmp.mpc_abs((re, im), prec), None, re_acc, None
211
212
def get_complex_part(expr, no, prec, options):
    """Evaluate the real (``no = 0``) or imaginary (``no = 1``) part of expr.

    ``expr`` is re-evaluated at increasing working precision until one of
    these holds:

    - the selected part is an exact zero;
    - the selected part reaches ``prec`` bits of accuracy;
    - ``expr`` is a Float, whose value cannot be refined.

    Refinement also stops once the working precision has reached
    ``options['maxprec']``. Without that cap, a part that is really zero
    but evaluates to a tiny, low-accuracy number (e.g. a real part that
    cancels) would be refined forever. In that case the best value found
    is returned together with its (insufficient) accuracy.

    Returns ``(value, None, accuracy, None)``.
    """
    maxprec = options.get('maxprec', INF)
    workprec = prec
    i = 0
    while 1:
        res = evalf(expr, workprec, options)
        value, accuracy = res[no::2]
        if (not value) or accuracy >= prec or expr.is_Float:
            return value, None, accuracy, None
        if workprec >= maxprec:
            # can't increase accuracy any more; return the best effort
            return value, None, accuracy, None
        workprec = min(workprec + max(30, 2**i), maxprec)
        i += 1
224
225
def evalf_abs(expr, prec, options):
    """Evaluate ``Abs(arg)`` by applying ``get_abs`` to its first argument."""
    operand = expr.args[0]
    return get_abs(operand, prec, options)
228
229
def evalf_re(expr, prec, options):
    """Evaluate ``re(arg)``: the real part (selector 0) of the argument."""
    operand = expr.args[0]
    return get_complex_part(operand, 0, prec, options)
232
233
def evalf_im(expr, prec, options):
    """Evaluate ``im(arg)``: the imaginary part (selector 1) of the argument."""
    operand = expr.args[0]
    return get_complex_part(operand, 1, prec, options)
236
237
def finalize_complex(re, im, prec):
    """Package an mpc pair ``(re, im)`` as an evalf result tuple.

    An exactly-zero part (``fzero``) becomes None. Otherwise the larger
    part (by ``fastlog``) gets accuracy ``prec``, and the smaller part
    gets ``prec`` minus the difference of the two magnitudes.
    """
    assert re != fzero or im != fzero

    if re == fzero:
        return None, im, None, prec
    if im == fzero:
        return re, None, prec, None

    gap = fastlog(re) - fastlog(im)
    if gap > 0:
        return re, im, prec, prec - gap
    return re, im, prec + gap, prec
255
256
def chop_parts(value, prec):
    """Drop real or imaginary parts that are negligibly small.

    A part whose estimated magnitude ``2**fastlog(part)`` lies below
    ``2**(4 - prec)`` is replaced by None, together with its accuracy.
    Values listed in ``_infs_nan`` are never chopped.
    """
    re, im, re_acc, im_acc = value
    cutoff = -prec + 4

    def negligible(part):
        # chop based on absolute value
        return part and part not in _infs_nan and fastlog(part) < cutoff

    if negligible(re):
        re, re_acc = None, None
    if negligible(im):
        im, im_acc = None, None
    return re, im, re_acc, im_acc
266
267
def check_target(expr, result, prec):
    """Raise PrecisionExhausted if ``result`` is less accurate than ``prec`` bits.

    The accuracy is measured with ``complex_accuracy``.
    """
    achieved = complex_accuracy(result)
    if achieved < prec:
        message = ('Failed to distinguish the expression: \n\n%s\n\n'
                   'from zero. Try simplifying the input, using chop=True, or providing '
                   'a higher maxn for evalf' % expr)
        raise PrecisionExhausted(message)
274
275
276############################################################################
277#                                                                          #
278#                            Arithmetic operations                         #
279#                                                                          #
280############################################################################
281
282
def add_terms(terms, prec, target_prec):
    """
    Helper for evalf_add. Adds a list of (mpfval, accuracy) terms.

    Terms whose mpf value is an exact zero are discarded first.

    Returns
    =======

    - None, None if there are no non-zero terms;
    - terms[0] if there is only 1 term;
    - (value, accuracy) of the re-evaluated sum of the special terms if
      any term is NaN or infinite;
    - scaled_zero if the sum of the terms produces a zero by cancellation
      e.g. mpfs representing 1 and -1 would produce a scaled zero which need
      special handling since they are not actually zero and they are purposely
      malformed to ensure that they can't be used in anything but accuracy
      calculations;
    - a tuple that is scaled to target_prec that corresponds to the
      sum of the terms.

    The returned mpf tuple will be normalized to target_prec; the input
    prec is used to define the working precision.

    The mantissas are summed exactly as Python integers rather than by
    looping over ``mpf_add`` (which would round after every addition).
    This has three effects:

    - a single running absolute-error bound is kept;
    - exact cancellation is detected and reported as a scaled zero;
    - rounding happens only once, at the end.

    A term that is negligible relative to the running sum, measured
    against the working precision ``2*prec``, only contributes to the
    error bound.

    """
    # Each term is an (mpf, accuracy) pair, so test the mpf itself.
    # Calling iszero on the pair would test the *accuracy* instead and
    # silently drop nonzero terms whose accuracy happens to be 0.
    terms = [t for t in terms if not iszero(t[0])]
    if not terms:
        return None, None
    elif len(terms) == 1:
        return terms[0]

    # see if any argument is NaN or oo and thus warrants a special return
    special = []
    from .numbers import Float, nan
    for t in terms:
        arg = Float._new(t[0], 1)
        if arg is nan or arg.is_infinite:
            special.append(arg)
    if special:
        from .add import Add
        rv = evalf(Add(*special), prec + 4, {})
        return rv[0], rv[2]

    working_prec = 2*prec
    sum_man, sum_exp, absolute_error = 0, 0, MINUS_INF

    for x, accuracy in terms:
        sign, man, exp, bc = x
        if sign:
            man = -man
        # log2 of this term's absolute error: magnitude minus accuracy
        absolute_error = max(absolute_error, bc + exp - accuracy)
        delta = exp - sum_exp
        if exp >= sum_exp:
            # x much larger than existing sum?
            # first: quick test
            if ((delta > working_prec) and
                ((not sum_man) or
                 delta - bitcount(abs(sum_man)) > working_prec)):
                sum_man = man
                sum_exp = exp
            else:
                sum_man += (man << delta)
        else:
            delta = -delta
            # x much smaller than existing sum?
            if delta - bc > working_prec:
                if not sum_man:
                    sum_man, sum_exp = man, exp
            else:
                sum_man = (sum_man << delta) + man
                sum_exp = exp
    if not sum_man:
        return scaled_zero(absolute_error)
    if sum_man < 0:
        sum_sign = 1
        sum_man = -sum_man
    else:
        sum_sign = 0
    sum_bc = bitcount(sum_man)
    sum_accuracy = sum_exp + sum_bc - absolute_error
    r = normalize(sum_sign, sum_man, sum_exp, sum_bc, target_prec,
                  rnd), sum_accuracy
    return r
364
365
def evalf_add(v, prec, options):
    """Evaluate a sum ``v`` (an Add) to ``prec`` bits of accuracy.

    A pure complex number ``a + b*I`` is evaluated part by part.
    Otherwise all terms are evaluated and summed with ``add_terms``. If
    the result is not accurate enough (e.g. because of cancellation), the
    terms are re-evaluated at increasing precision. This stops when the
    target is met or when the extra precision exceeds
    ``options['maxprec']``.

    While iterating, ``options['maxprec']`` is capped at ``2*prec``. The
    caller's value is always restored, even if evaluating a term raises.

    Returns an evalf result tuple ``(re, im, re_acc, im_acc)``. A part
    that cancelled to zero is returned as a sign-unwrapped scaled zero.
    """
    res = pure_complex(v)
    if res:
        h, c = res
        re, _, re_acc, _ = evalf(h, prec, options)
        im, _, im_acc, _ = evalf(c, prec, options)
        return re, im, re_acc, im_acc

    oldmaxprec = options['maxprec']

    i = 0
    target_prec = prec
    try:
        while 1:
            options['maxprec'] = min(oldmaxprec, 2*prec)

            terms = [evalf(arg, prec + 10, options) for arg in v.args]
            re, re_acc = add_terms(
                [a[0::2] for a in terms if a[0]], prec, target_prec)
            im, im_acc = add_terms(
                [a[1::2] for a in terms if a[1]], prec, target_prec)
            acc = complex_accuracy((re, im, re_acc, im_acc))
            if acc >= target_prec:
                break
            # give up once the extra precision exceeds the cap
            if (prec - target_prec) > options['maxprec']:
                break

            prec = prec + max(10 + 2**i, target_prec - acc)
            i += 1
    finally:
        # restore the caller's cap even if evaluating a term raised
        options['maxprec'] = oldmaxprec

    if iszero(re, scaled=True):
        re = scaled_zero(re)
    if iszero(im, scaled=True):
        im = scaled_zero(im)
    return re, im, re_acc, im_acc
402
403
def evalf_mul(v, prec, options):
    """Evaluate a product ``v`` (a Mul) to about ``prec`` bits.

    A pure imaginary product ``b*I`` is handled directly. If any factor
    evaluates to NaN or an infinity, the product of the factors'
    numerical values is evaluated instead.

    Otherwise the product is built in two phases. First, real and pure
    imaginary factors are multiplied together as one exact integer
    mantissa/exponent pair, with the power of I tracked separately.
    Second, genuinely complex factors are multiplied in with mpf
    arithmetic.

    A zero factor makes the result ``(None, None, None, None)``.
    """
    res = pure_complex(v)
    if res:
        # the only pure complex that is a mul is h*I
        _, h = res
        im, _, im_acc, _ = evalf(h, prec, options)
        return None, im, None, im_acc
    args = list(v.args)

    # see if any argument is NaN or oo and thus warrants a special return
    # NOTE(review): only the real part of each factor is kept here, and
    # factors without a real part are skipped entirely -- confirm intended.
    special, other = [], []
    from .numbers import Float, nan
    for arg in args:
        arg = evalf(arg, prec, options)
        if arg[0] is None:
            continue
        arg = Float._new(arg[0], 1)
        if arg is nan or arg.is_infinite:
            special.append(arg)
        else:
            other.append(arg)
    if special:
        from .mul import Mul
        other = Mul(*other)
        special = Mul(*special)
        return evalf(special*other, prec + 4, {})

    # With guard digits, multiplication in the real case does not destroy
    # accuracy. This is also true in the complex case when considering the
    # total accuracy; however accuracy for the real or imaginary parts
    # separately may be lower.
    acc = prec

    # XXX: big overestimate
    working_prec = prec + len(args) + 5

    # Empty product is 1
    start = man, exp, bc = MPZ(1), 0, 1

    # First, we multiply all pure real or pure imaginary numbers.
    # direction tells us that the result should be multiplied by
    # I**direction; all other numbers get put into complex_factors
    # to be multiplied out after the first phase.
    last = len(args)
    direction = 0
    args.append(S.One)
    complex_factors = []

    for i, arg in enumerate(args):
        if i != last and pure_complex(arg):
            args[-1] = (args[-1]*arg).expand()
            continue
        elif i == last and arg is S.One:
            continue
        re, im, re_acc, im_acc = evalf(arg, working_prec, options)
        if re and im:
            complex_factors.append((re, im, re_acc, im_acc))
            continue
        elif re:
            (s, m, e, b), w_acc = re, re_acc
        elif im:
            (s, m, e, b), w_acc = im, im_acc
            direction += 1
        else:
            return None, None, None, None
        direction += 2*s
        man *= m
        exp += e
        bc += b
        if bc > 3*working_prec:
            # Keep the running mantissa bounded: drop working_prec low
            # bits and keep bc in step with the shift. Without the
            # decrement, every later factor would trigger another shift
            # and eventually discard the significant bits of man.
            man >>= working_prec
            exp += working_prec
            bc -= working_prec
        acc = min(acc, w_acc)
    sign = (direction & 2) >> 1
    if not complex_factors:
        v = normalize(sign, man, exp, bitcount(man), prec, rnd)
        # multiply by i
        if direction & 1:
            return None, v, None, acc
        else:
            return v, None, acc, None
    else:
        # initialize with the first term
        if (man, exp, bc) != start:
            # there was a real part; give it an imaginary part
            re, im = (sign, man, exp, bitcount(man)), (0, MPZ(0), 0, 0)
            i0 = 0
        else:
            # there is no real part to start (other than the starting 1)
            wre, wim, wre_acc, wim_acc = complex_factors[0]
            acc = min(acc,
                      complex_accuracy((wre, wim, wre_acc, wim_acc)))
            re = wre
            im = wim
            i0 = 1

        for wre, wim, wre_acc, wim_acc in complex_factors[i0:]:
            # acc is the overall accuracy of the product; we aren't
            # computing exact accuracies of the product.
            acc = min(acc,
                      complex_accuracy((wre, wim, wre_acc, wim_acc)))

            use_prec = working_prec
            # (re + im*I)*(wre + wim*I)
            A = mpf_mul(re, wre, use_prec)
            B = mpf_mul(mpf_neg(im), wim, use_prec)
            C = mpf_mul(re, wim, use_prec)
            D = mpf_mul(im, wre, use_prec)
            re = mpf_add(A, B, use_prec)
            im = mpf_add(C, D, use_prec)
        # multiply by I
        if direction & 1:
            re, im = mpf_neg(im), re
        return re, im, acc, acc
517
518
def evalf_pow(v, prec, options):
    """Evaluate a power ``v = base**exp`` to ``prec`` bits.

    These cases are special-cased:

    - integer exponents;
    - square roots (``exp`` is ``S.Half``);
    - powers of ``E``.

    In the general case the exponent is evaluated first to choose the
    working precision, then the base. The computation then dispatches on
    whether each of them is real, negative or complex.

    Returns an evalf result tuple ``(re, im, re_acc, im_acc)``.
    """
    from .numbers import E

    target_prec = prec
    base, exp = v.args

    # We handle x**n separately. This has two purposes: 1) it is much
    # faster, because we avoid calling evalf on the exponent, and 2) it
    # allows better handling of real/imaginary parts that are exactly zero
    if exp.is_Integer:
        p = exp.numerator
        # Exact
        if not p:
            return fone, None, prec, None
        # Exponentiation by p magnifies relative error by |p|, so the
        # base must be evaluated with increased precision if p is large
        prec += int(math.log(abs(p), 2))
        re, im, *_ = evalf(base, prec + 5, options)
        # Real to integer power
        if re and not im:
            return mpf_pow_int(re, p, target_prec), None, target_prec, None
        # (x*I)**n = I**n * x**n
        if im and not re:
            z = mpf_pow_int(im, p, target_prec)
            # I**n cycles with period 4: 1, I, -1, -I
            case = p % 4
            if case == 0:
                return z, None, target_prec, None
            elif case == 1:
                return None, z, None, target_prec
            elif case == 2:
                return mpf_neg(z), None, target_prec, None
            else:
                return None, mpf_neg(z), None, target_prec
        # Zero raised to an integer power
        # NOTE(review): this is also returned for negative p, where the
        # power of zero is not zero -- confirm callers never hit that case.
        if not re:
            return None, None, None, None
        # General complex number to arbitrary integer power
        re, im = libmp.mpc_pow_int((re, im), p, prec)
        # Assumes full accuracy in input
        return finalize_complex(re, im, target_prec)

    # Pure square root
    if exp is S.Half:
        xre, xim, _, _ = evalf(base, prec + 5, options)
        # General complex square root
        if xim:
            re, im = libmp.mpc_sqrt((xre or fzero, xim), prec)
            return finalize_complex(re, im, prec)
        if not xre:
            return None, None, None, None
        # Square root of a negative real number
        if mpf_lt(xre, fzero):
            return None, mpf_sqrt(mpf_neg(xre), prec), None, prec
        # Positive square root
        return mpf_sqrt(xre, prec), None, prec, None

    # We first evaluate the exponent to find its magnitude
    # This determines the working precision that must be used
    prec += 10
    yre, yim, _, _ = evalf(exp, prec, options)
    # Special cases: x**0
    if not (yre or yim):
        return fone, None, prec, None

    # Only the real part's size is used; fastlog(None) is MINUS_INF, so a
    # purely imaginary exponent never triggers the restart below.
    ysize = fastlog(yre)
    # Restart if too big
    # XXX: prec + ysize might exceed maxprec
    if ysize > 5:
        prec += ysize
        yre, yim, _, _ = evalf(exp, prec, options)

    # Pure exponential function; no need to evalf the base
    if base is E:
        if yim:
            re, im = libmp.mpc_exp((yre or fzero, yim), prec)
            return finalize_complex(re, im, target_prec)
        return mpf_exp(yre, target_prec), None, target_prec, None

    xre, xim, _, _ = evalf(base, prec + 5, options)
    # 0**y
    if not (xre or xim):
        return None, None, None, None

    # (real ** complex) or (complex ** complex)
    if yim:
        re, im = libmp.mpc_pow(
            (xre or fzero, xim or fzero), (yre or fzero, yim),
            target_prec)
        return finalize_complex(re, im, target_prec)
    # complex ** real
    if xim:
        re, im = libmp.mpc_pow_mpf((xre or fzero, xim), yre, target_prec)
        return finalize_complex(re, im, target_prec)
    # negative ** real
    elif mpf_lt(xre, fzero):
        re, im = libmp.mpc_pow_mpf((xre, fzero), yre, target_prec)
        return finalize_complex(re, im, target_prec)
    # positive ** real
    else:
        return mpf_pow(xre, yre, target_prec), None, target_prec, None
619
620
621############################################################################
622#                                                                          #
623#                            Special functions                             #
624#                                                                          #
625############################################################################
def evalf_trig(v, prec, options):
    """
    Evaluate ``sin(arg)`` or ``cos(arg)`` to ``prec`` bits.

    Real arguments are handled here with ``mpf_sin``/``mpf_cos``. The
    result loses accuracy when the argument is large or lies close to a
    root. The argument is therefore re-evaluated at higher precision
    until the estimated accuracy reaches ``prec``. If the working
    precision exceeds ``options['maxprec']`` first, the less accurate
    value is returned with its estimated accuracy.

    Complex arguments are delegated to the function's own
    ``_eval_evalf``.

    Raises NotImplementedError for any function other than sin and cos.

    TODO: should also handle tan of complex arguments.

    """
    from ..functions import cos, sin
    if isinstance(v, cos):
        func = mpf_cos
    elif isinstance(v, sin):
        func = mpf_sin
    else:
        raise NotImplementedError
    arg = v.args[0]
    # 20 extra bits is possibly overkill. It does make the need
    # to restart very unlikely
    xprec = prec + 20
    re, im, *_ = evalf(arg, xprec, options)
    if im:
        # complex argument: fall back to the function's own evaluation
        if 'subs' in options:
            v = v.subs(options['subs'])
        return evalf(v._eval_evalf(prec), prec, options)
    if not re:
        # argument is an exact zero: cos(0) == 1, sin(0) == 0
        if isinstance(v, cos):
            return fone, None, prec, None
        elif isinstance(v, sin):
            return None, None, None, None
        else:
            raise NotImplementedError
    # For trigonometric functions, we are interested in the
    # fixed-point (absolute) accuracy of the argument.
    xsize = fastlog(re)
    # Magnitude <= 1.0. OK to compute directly, because there is no
    # danger of hitting the first root of cos (with sin, magnitude
    # <= 2.0 would actually be ok)
    if xsize < 1:
        return func(re, prec, rnd), None, prec, None
    # Very large
    if xsize >= 10:
        xprec = prec + xsize
        re, im, *_ = evalf(arg, xprec, options)
    # Need to repeat in case the argument is very close to a
    # multiple of pi (or pi/2), hitting close to a root
    while 1:
        y = func(re, prec, rnd)
        ysize = fastlog(y)
        # a result of size 2**ysize < 1 means about -ysize bits were
        # lost to cancellation near a root
        gap = -ysize
        # fixed-point accuracy of the argument minus the cancelled bits
        accuracy = (xprec - xsize) - gap
        if accuracy < prec:
            if xprec > options['maxprec']:
                return y, None, accuracy, None
            xprec += gap
            re, im, *_ = evalf(arg, xprec, options)
            continue
        else:
            return y, None, prec, None
683
684
def evalf_log(expr, prec, options):
    """Evaluate ``log(arg)`` to ``prec`` bits.

    A multi-argument log is rewritten with ``doit()`` and evaluated again.
    For a complex ``arg`` the result is ``log(Abs(arg)) + I*atan2(im, re)``.
    For a real ``arg`` the log of its absolute value is computed, and a
    negative ``arg`` also gets the imaginary part ``pi``.

    Returns an evalf result tuple ``(re, im, re_acc, im_acc)``.
    """
    from ..functions import Abs, log
    from .add import Add

    if len(expr.args) > 1:
        # multi-argument form: rewrite with doit() and evaluate the result
        expr = expr.doit()
        return evalf(expr, prec, options)

    arg = expr.args[0]
    workprec = prec + 10
    xre, xim, *_ = evalf(arg, workprec, options)

    if xim:
        # log(z) = log(|z|) + I*atan2(im(z), re(z))
        # XXX: use get_abs etc instead
        re = evalf_log(
            log(Abs(arg, evaluate=False), evaluate=False), prec, options)
        im = mpf_atan2(xim, xre or fzero, prec)
        return re[0], im, re[2], prec

    # NOTE(review): if arg evaluates to an exact zero, xre is None here and
    # mpf_cmp/mpf_abs below will fail; log(0) is not handled specially.
    imaginary_term = (mpf_cmp(xre, fzero) < 0)

    re = mpf_log(mpf_abs(xre), prec, rnd)
    size = fastlog(re)
    if prec - size > workprec:
        # log|x| is tiny, i.e. |x| is close to 1. The relative accuracy of
        # log(x) then depends on the absolute accuracy of x - 1, so x - 1
        # is evaluated directly. evalf_add raises the precision as needed
        # when it detects cancellation, and x is then rebuilt from x - 1.
        # NOTE(review): for x close to -1 this re-evaluation does not gain
        # accuracy -- confirm whether that case matters.
        arg = Add(S.NegativeOne, arg, evaluate=False)
        xre, xim, _, _ = evalf_add(arg, prec, options)
        prec2 = workprec - fastlog(xre)
        # xre is now x - 1 so we add 1 back here to calculate x
        re = mpf_log(mpf_abs(mpf_add(xre, fone, prec2)), prec, rnd)

    re_acc = prec

    if imaginary_term:
        # log(x) = log|x| + I*pi for negative real x
        return re, mpf_pi(prec), re_acc, prec
    else:
        return re, None, re_acc, None
722
723
def evalf_atan(v, prec, options):
    """Evaluate ``atan(arg)`` for a real argument.

    An exactly-zero argument gives an exact zero. Complex arguments raise
    NotImplementedError.
    """
    xre, xim, *_ = evalf(v.args[0], prec + 5, options)
    if xre is None and xim is None:
        return (None,)*4
    if xim:
        raise NotImplementedError
    return mpf_atan(xre, prec, rnd), None, prec, None
732
733
def evalf_subs(prec, subs):
    """Return a copy of `subs` whose Float values are re-rounded to prec.

    Every value is sympified; non-Float values are kept as they are.
    """
    def _adjust(value):
        value = sympify(value)
        return value._eval_evalf(prec) if value.is_Float else value

    return {key: _adjust(value) for key, value in subs.items()}
743
744
def evalf_piecewise(expr, prec, options):
    """Evaluate a Piecewise after substituting ``options['subs']``.

    Without substitutions the conditions cannot be decided, so
    NotImplementedError is raised.
    """
    if 'subs' not in options:
        # We still have undefined symbols
        raise NotImplementedError
    substituted = expr.subs(evalf_subs(prec, options['subs']))
    remaining = options.copy()
    del remaining['subs']
    return evalf(substituted, prec, remaining)
754
755
def evalf_bernoulli(expr, prec, options):
    """Evaluate a Bernoulli number whose index is an Integer.

    An exactly-zero value is reported as None.
    Raises ValueError for a non-integer index.
    """
    index = expr.args[0]
    if not index.is_Integer:
        raise ValueError('Bernoulli number index must be an integer')
    value = mpf_bernoulli(int(index), prec, rnd)
    if value != fzero:
        return value, None, prec, None
    return None, None, None, None
765
766############################################################################
767#                                                                          #
768#                            High-level operations                         #
769#                                                                          #
770############################################################################
771
772
def as_mpmath(x, prec, options):
    """Convert ``x`` to an mpmath ``mpf`` or ``mpc`` value.

    Zero and +/-oo are mapped directly. Anything else is evaluated with
    ``evalf`` at ``prec`` bits. When ``x`` evaluates to an exact zero, the
    real part comes back as None; it is mapped to ``fzero`` instead of
    being passed to ``mpf``, which cannot convert None.
    """
    from .numbers import oo
    x = sympify(x)
    if x == 0:
        return mpf(0)
    if x == oo:
        return mpf('inf')
    if x == -oo:
        return mpf('-inf')
    # XXX
    re, im, _, _ = evalf(x, prec, options)
    if im:
        return mpc(re or fzero, im)
    return mpf(re or fzero)
787
788
def do_integral(expr, prec, options):
    """Numerically evaluate a one-dimensional definite integral.

    ``expr.args`` is ``(integrand, (x, xlow, xhigh))``. The integral is
    computed at ``prec + 5`` bits, with one of two mpmath routines:

    - ``quadosc`` when ``options['quad'] == 'osc'``;
    - ``quadts`` otherwise.

    ``options['maxprec']`` is lowered to at most ``2*prec`` while the
    limits are evaluated. The caller's value is always restored, even if
    evaluation or quadrature raises.

    Returns an evalf result tuple ``(re, im, re_acc, im_acc)``. Suppose
    the integrand produced a real (imaginary) part, but quadrature returns
    that part as exactly zero. The part is then reported as a scaled zero.
    Its accuracy is derived from the largest integrand value seen and from
    the quadrature error estimate.

    Raises ValueError if oscillatory quadrature is requested but the
    integrand does not match ``sin(A*x + B)*f(x)`` or
    ``cos(A*x + B)*f(x)``.
    """
    func = expr.args[0]
    x, xlow, xhigh = expr.args[1]
    if xlow == xhigh:
        xlow = xhigh = 0
    elif x not in func.free_symbols:
        # only the difference in limits matters in this case
        # so if there is a symbol in common that will cancel
        # out when taking the difference, then use that
        # difference
        if xhigh.free_symbols & xlow.free_symbols:
            diff = xhigh - xlow
            if not diff.free_symbols:
                xlow, xhigh = 0, diff

    oldmaxprec = options['maxprec']
    options['maxprec'] = min(oldmaxprec, 2*prec)

    try:
        with workprec(prec + 5):
            xlow = as_mpmath(xlow, prec + 15, options)
            xhigh = as_mpmath(xhigh, prec + 15, options)

            # Integration is like summation, and we can phone home from
            # the integrand function to update accuracy summation style
            # Note that this accuracy is inaccurate, since it fails
            # to account for the variable quadrature weights,
            # but it is better than nothing

            from ..functions import cos, sin
            from .numbers import pi
            from .symbol import Wild

            have_part = [False, False]
            max_real_term = [MINUS_INF]
            max_imag_term = [MINUS_INF]

            def f(t):
                re, im, *_ = evalf(func, mp.prec,
                                   {'subs': {x: t}, 'maxprec': DEFAULT_MAXPREC})

                have_part[0] = re or have_part[0]
                have_part[1] = im or have_part[1]

                max_real_term[0] = max(max_real_term[0], fastlog(re))
                max_imag_term[0] = max(max_imag_term[0], fastlog(im))

                if im:
                    return mpc(re or fzero, im)
                return mpf(re or fzero)

            if options.get('quad') == 'osc':
                A = Wild('A', exclude=[x])
                B = Wild('B', exclude=[x])
                D = Wild('D')
                m = func.match(cos(A*x + B)*D)
                if not m:
                    m = func.match(sin(A*x + B)*D)
                if not m:
                    raise ValueError('An integrand of the form sin(A*x+B)*f(x) '
                                     'or cos(A*x+B)*f(x) is required for oscillatory quadrature')
                period = as_mpmath(2*pi/m[A], prec + 15, options)
                result = quadosc(f, [xlow, xhigh], period=period)
                # XXX: quadosc does not do error detection yet
                quadrature_error = MINUS_INF
            else:
                result, quadrature_error = quadts(f, [xlow, xhigh], error=1)
                quadrature_error = fastlog(quadrature_error._mpf_)
    finally:
        # restore the caller's cap even if evaluation or quadrature raised
        options['maxprec'] = oldmaxprec

    if have_part[0]:
        re = result.real._mpf_
        if re == fzero:
            re, re_acc = scaled_zero(
                min(-prec, -max_real_term[0], -quadrature_error))
            re = scaled_zero(re)  # handled ok in evalf_integral
        else:
            re_acc = -max(max_real_term[0] - fastlog(re) -
                          prec, quadrature_error)
    else:
        re, re_acc = None, None

    if have_part[1]:
        im = result.imag._mpf_
        if im == fzero:
            im, im_acc = scaled_zero(
                min(-prec, -max_imag_term[0], -quadrature_error))
            im = scaled_zero(im)  # handled ok in evalf_integral
        else:
            im_acc = -max(max_imag_term[0] - fastlog(im) -
                          prec, quadrature_error)
    else:
        im, im_acc = None, None

    result = re, im, re_acc, im_acc
    return result
884
885
def evalf_integral(expr, prec, options):
    """Evaluate a one-dimensional definite integral to ``prec`` bits.

    ``do_integral`` is repeated with growing working precision until the
    result reaches ``prec`` bits of accuracy, or until the working
    precision reaches ``options['maxprec']`` (unbounded if absent).
    Raises NotImplementedError unless there is exactly one
    ``(x, a, b)`` limit.
    """
    limits = expr.limits
    if len(limits) != 1 or len(limits[0]) != 3:
        raise NotImplementedError
    maxprec = options.get('maxprec', INF)
    workprec = prec
    attempt = 0
    while True:
        result = do_integral(expr, workprec, options)
        accuracy = complex_accuracy(result)
        # done when accurate enough, or when precision can't grow any more
        if accuracy >= prec or workprec >= maxprec:
            return result
        if accuracy == -1:
            # The answer may really be zero, or the precision may just be
            # too low: double it so that maxprec is reached quickly.
            workprec *= 2
        else:
            workprec += max(prec, 2**attempt)
        workprec = min(workprec, maxprec)
        attempt += 1
910
911
def check_convergence(numer, denom, n):
    """
    Classify the asymptotic behaviour of a hypergeometric series.

    ``numer/denom`` is the ratio of consecutive terms ``t(n + 1)/t(n)``
    (as constructed by ``hypsum``), both viewed as polynomials in ``n``.

    Returns (h, g, p) where
    -- h = deg(denom) - deg(numer) is:
        > 0 for convergence of rate 1/factorial(n)**h
        < 0 for divergence of rate factorial(n)**(-h)
        = 0 for geometric or polynomial convergence or divergence

    -- g = LC(denom)/LC(numer) (None when h != 0); abs(g) is:
        > 1 for geometric convergence of rate 1/abs(g)**n
        < 1 for geometric divergence of rate (1/abs(g))**n
        = 1 for polynomial convergence or divergence

        (g < 0 indicates an alternating series)

    -- p (None unless h == 0 and abs(g) == 1; 0 when both polynomials are
       constants) is:
        > 1 for polynomial convergence of rate 1/n**p
        <= 1 for polynomial divergence of rate n**(-p)

    """
    npol = numer.as_poly(n)
    dpol = denom.as_poly(n)
    deg_num = npol.degree()
    deg_den = dpol.degree()
    rate = deg_den - deg_num
    if rate:
        return rate, None, None
    constant = dpol.LC() / npol.LC()
    if abs(constant) != 1:
        return rate, constant, None
    if deg_num == deg_den == 0:
        return rate, constant, 0
    # NOTE(review): index -2 of all_coeffs() is taken to be the coefficient
    # of the next-to-leading power of n -- confirm the coefficient ordering
    # of Poly.all_coeffs() in this code base.
    pc = npol.all_coeffs()[-2]
    qc = dpol.all_coeffs()[-2]
    return rate, constant, (qc - pc)/dpol.LC()
947
948
def hypsum(expr, n, start, prec):
    """
    Sum a rapidly convergent infinite hypergeometric series with
    given general term, e.g. ``hypsum(1/factorial(n), n, 0, prec)``
    gives e.  The quotient between successive terms must be a quotient
    of integer polynomials.

    Parameters
    ==========

    expr : general term of the series in ``n``
    n : summation variable
    start : integer index of the first term; ``expr`` is shifted so that
        the summation proceeds from ``n = 0``
    prec : binary working precision (must be finite)

    Returns the value of the sum as a raw mpmath ``mpf`` tuple.

    Raises NotImplementedError if ``prec`` is infinite, if ``hypersimp``
    cannot express the term ratio, or if the first term is not rational.
    Raises ValueError if the series diverges.

    """
    from ..simplify import hypersimp
    from ..utilities import lambdify
    from .numbers import Float

    if prec == float('inf'):
        raise NotImplementedError('does not support inf prec')

    # Shift the index so that the series starts at n = 0.
    if start:
        expr = expr.subs({n: n + start})
    hs = hypersimp(expr, n)
    if hs is None:
        raise NotImplementedError('a hypergeometric series is required')
    num, den = hs.as_numer_denom()

    # Numeric callables for the term ratio t(k + 1)/t(k) = num(k)/den(k).
    func1 = lambdify(n, num)
    func2 = lambdify(n, den)

    # h: factorial rate, g: geometric ratio, p: polynomial exponent
    # (see check_convergence).
    h, g, p = check_convergence(num, den, n)

    if h < 0:
        raise ValueError('Sum diverges like (n!)^%i' % (-h))

    term = expr.subs({n: 0})
    if not term.is_Rational:
        raise NotImplementedError('Non rational term functionality is not implemented.')

    # Direct summation if geometric or faster
    if h > 0 or (h == 0 and abs(g) > 1):
        # Fixed-point arithmetic: term and s are integers scaled by 2**prec.
        term = (MPZ(term.numerator) << prec) // term.denominator
        s = term
        k = 1
        # Stop once the scaled term has shrunk to a few units in the last
        # place.
        while abs(term) > 5:
            term *= MPZ(func1(k - 1))
            term //= MPZ(func2(k - 1))
            s += term
            k += 1
        return from_man_exp(s, -prec)
    else:
        alt = g < 0
        if abs(g) < 1:
            raise ValueError('Sum diverges like (%i)^n' % abs(1/g))
        if p < 1 or (p == 1 and not alt):
            raise ValueError('Sum diverges like n^%i' % (-p))
        # We have polynomial convergence: use Richardson extrapolation
        vold = None
        # Number of decimal digits that two successive passes must agree on;
        # computed once from the caller's precision.
        ndig = prec_to_dps(prec)
        while True:
            # Need to use at least quad precision because a lot of cancellation
            # might occur in the extrapolation process; we check the answer to
            # make sure that the desired precision has been reached, too.
            prec2 = 4*prec
            term0 = (MPZ(term.numerator) << prec2) // term.denominator

            # The mutable default list carries the running (scaled) term
            # between calls: each term is derived from the previous one.
            # NOTE(review): this assumes nsum requests k = 0, 1, 2, ... in
            # order -- confirm against mpmath's Richardson implementation.
            def summand(k, _term=[term0]):
                if k:
                    k = int(k)
                    _term[0] *= MPZ(func1(k - 1))
                    _term[0] //= MPZ(func2(k - 1))
                return make_mpf(from_man_exp(_term[0], -prec2))

            # workprec is mpmath's precision context manager.
            with workprec(prec):
                v = nsum(summand, [0, mpmath_inf], method='richardson')
            vf = Float(v, ndig)
            # Converged once two successive precisions agree to ndig digits.
            if vold is not None and vold == vf:
                break
            prec += prec  # double precision each time
            vold = vf

        return v._mpf_
1026
1027
def evalf_prod(expr, prec, options):
    """
    Numerically evaluate a ``Product``.

    When every limit spans an integer number of steps (upper minus lower
    bound is an Integer), the product is expanded with ``doit()``.
    Otherwise it is rewritten as a ``Sum`` before evaluation.
    Returns ``(re, im, re_acc, im_acc)``.
    """
    from ..concrete import Sum
    finite = all((lim[1] - lim[2]).is_Integer for lim in expr.limits)
    target = expr.doit() if finite else expr.rewrite(Sum)
    re_part, im_part, re_acc, im_acc = evalf(target, prec=prec, options=options)
    return re_part, im_part, re_acc, im_acc
1035
1036
def evalf_sum(expr, prec, options):
    """
    Numerically evaluate a ``Sum`` with a single ``(n, a, b)`` limit.

    Infinite sums starting at an integer are first tried with fast
    hypergeometric summation (``hypsum``).  Otherwise, or if that raises
    NotImplementedError, Euler-Maclaurin summation is used.

    Returns ``(re, im, re_acc, im_acc)``.  Raises NotImplementedError for
    more than one limit or a limit not of the form ``(n, a, b)``.
    """
    from .numbers import Float, oo
    if 'subs' in options:
        expr = expr.subs(options['subs'])
    func = expr.function
    limits = expr.limits
    if len(limits) != 1 or len(limits[0]) != 3:
        raise NotImplementedError
    if func is S.Zero:
        # Exact zero: use the same convention as the ``Zero`` entry of
        # evalf_table (a None part), never an mpmath ``mpf`` object --
        # callers expect raw mpf tuples or None.
        return None, None, prec, None
    prec2 = prec + 10
    try:
        n, a, b = limits[0]
        if b != oo or a != int(a):
            raise NotImplementedError
        # Use fast hypergeometric summation if possible
        v = hypsum(func, n, int(a), prec2)
        magnitude = fastlog(v)
        delta = prec - magnitude
        if magnitude < -10:
            # Small result: recompute at the larger precision ``delta``
            # (which exceeds prec2 in this branch).
            v = hypsum(func, n, int(a), delta)
        return v, None, min(prec, delta), None
    except NotImplementedError:
        # Euler-Maclaurin summation for general series: double m until the
        # error estimate drops to eps = 2**-prec.
        # NOTE(review): there is no iteration cap; termination relies on
        # the estimate eventually falling below eps.
        m, err, eps = prec, oo, Float(2.0)**(-prec)
        while err > eps:
            m <<= 1
            s, err = expr.euler_maclaurin(m=m, n=m, eps=eps,
                                          eval_integral=False)
            err = err.evalf(strict=False)
        err = fastlog(evalf(abs(err), 20, options)[0])
        re, im, re_acc, im_acc = evalf(s, prec2, options)
        if re_acc is None:
            re_acc = -err
        if im_acc is None:
            im_acc = -err
        return re, im, re_acc, im_acc
1073
1074
1075############################################################################
1076#                                                                          #
1077#                            Symbolic interface                            #
1078#                                                                          #
1079############################################################################
1080
def evalf_symbol(x, prec, options):
    """
    Evaluate the symbol ``x`` using the value given for it in
    ``options['subs']``.

    An mpmath ``mpf`` value is returned directly, with None for a zero
    value.  Any other value is sympified and evaluated recursively.  The
    result is memoized in ``options['_cache']`` together with the precision
    it was computed at, and reused for requests at that precision or lower.
    """
    val = options['subs'][x]
    if not isinstance(val, mpf):
        cache = options.setdefault('_cache', {})
        hit, hit_prec = cache.get(x, (None, MINUS_INF))
        if hit_prec >= prec:
            return hit
        fresh = evalf(sympify(val), prec, options)
        cache[x] = (fresh, prec)
        return fresh
    if val:
        return val._mpf_, None, prec, None
    return None, None, None, None
1097
1098
# Dispatch table mapping expression classes (looked up by ``x.func`` in
# evalf()) to handler functions ``handler(x, prec, options)``.  It starts
# as None and is filled in lazily by _create_evalf_table().
evalf_table = None
1100
1101
def _create_evalf_table():
    """
    Populate the module-level ``evalf_table`` dispatch dictionary.

    Each entry maps an expression class to a handler with the signature
    ``handler(x, prec, options) -> (re, im, re_acc, im_acc)``.  The
    imports are done at call time rather than at module import time.
    """
    global evalf_table
    from ..concrete.products import Product
    from ..concrete.summations import Sum
    from ..functions.combinatorial.numbers import bernoulli
    from ..functions.elementary.complexes import Abs, im, re
    from ..functions.elementary.exponential import log
    from ..functions.elementary.piecewise import Piecewise
    from ..functions.elementary.trigonometric import atan, cos, sin
    from ..integrals.integrals import Integral
    from .add import Add
    from .mul import Mul
    from .numbers import (Exp1, Float, Half, ImaginaryUnit, Integer, NaN,
                          NegativeOne, One, Pi, Rational, Zero)
    from .power import Pow
    from .symbol import Dummy, Symbol

    def exact_real(value):
        # Handler for a real constant whose raw mpf value is fixed.
        return lambda x, prec, options: (value, None, prec, None)

    def computed_real(compute):
        # Handler for a real constant computed to the requested precision.
        return lambda x, prec, options: (compute(prec), None, prec, None)

    def float_handler(x, prec, options):
        # A Float is never reported as more accurate than its own precision.
        return x._mpf_, None, min(prec, x._prec), None

    def rational_handler(x, prec, options):
        return from_rational(x.numerator, x.denominator, prec), None, prec, None

    def integer_handler(x, prec, options):
        return from_int(x.numerator, prec), None, prec, None

    def zero_handler(x, prec, options):
        return None, None, prec, None

    def imaginary_unit_handler(x, prec, options):
        return None, fone, None, prec

    evalf_table = {
        Symbol: evalf_symbol,
        Dummy: evalf_symbol,
        Float: float_handler,
        Rational: rational_handler,
        Integer: integer_handler,
        Zero: zero_handler,
        One: exact_real(fone),
        Half: exact_real(fhalf),
        Pi: computed_real(mpf_pi),
        Exp1: computed_real(mpf_e),
        ImaginaryUnit: imaginary_unit_handler,
        NegativeOne: exact_real(fnone),
        NaN: exact_real(fnan),

        cos: evalf_trig,
        sin: evalf_trig,

        Add: evalf_add,
        Mul: evalf_mul,
        Pow: evalf_pow,

        log: evalf_log,
        atan: evalf_atan,
        Abs: evalf_abs,

        re: evalf_re,
        im: evalf_im,

        Integral: evalf_integral,
        Sum: evalf_sum,
        Product: evalf_prod,
        Piecewise: evalf_piecewise,

        bernoulli: evalf_bernoulli,
    }
1156
1157
def evalf(x, prec, options):
    """
    Evaluate ``x`` numerically to about ``prec`` bits.

    Returns a tuple ``(re, im, re_acc, im_acc)`` where ``re``/``im`` are raw
    mpmath ``mpf`` tuples or None, and ``re_acc``/``im_acc`` are the
    corresponding accuracies in bits (or None).

    Dispatch is on ``x.func`` through ``evalf_table``.  If that raises
    KeyError, the expression's own ``_eval_evalf`` is used instead.  The
    KeyError handler also covers the handler call itself, so a KeyError
    raised inside a handler triggers the same fallback.  One example is
    ``evalf_symbol`` looking up a symbol missing from ``options['subs']``.

    Options read here: ``'subs'`` (applied in the fallback path), ``'chop'``
    (True, or a numeric tolerance) and ``'strict'`` (run ``check_target``
    on the result).

    NOTE(review): ``evalf_table`` must already be populated.  While it is
    still None, the lookup raises TypeError, which is not caught here.

    Raises NotImplementedError if the fallback cannot produce a numeric
    value.
    """
    from ..functions import im as im_
    from ..functions import re as re_
    try:
        rf = evalf_table[x.func]
        r = rf(x, prec, options)
    except KeyError:
        try:
            # Fall back to ordinary evalf if possible
            if 'subs' in options:
                x = x.subs(evalf_subs(prec, options['subs']))
            re, im = x._eval_evalf(prec).as_real_imag()
            # Unevaluated re()/im() parts mean no numeric value was found.
            if re.has(re_) or im.has(im_):
                raise NotImplementedError
            if re == 0:
                re = None
                reprec = None
            else:
                re = re._to_mpmath(prec)._mpf_
                reprec = prec
            if im == 0:
                im = None
                imprec = None
            else:
                im = im._to_mpmath(prec)._mpf_
                imprec = prec
            r = re, im, reprec, imprec
        except AttributeError:
            # e.g. _eval_evalf returned None (no as_real_imag attribute).
            raise NotImplementedError
    chop = options.get('chop', False)
    if chop:
        if chop is True:
            chop_prec = prec
        else:
            # convert (approximately) from given tolerance;
            # the formula here will make 1e-i round to 0 for
            # i in the range +/-27 while 2e-i will not be chopped
            chop_prec = round(-3.321*math.log10(chop) + 2.5)
        r = chop_parts(r, chop_prec)
    if options.get('strict'):
        check_target(x, r, prec)
    return r
1200
1201
class EvalfMixin:
    """Mixin class adding evalf capability."""

    def evalf(self, dps=15, subs=None, maxn=110, chop=False, strict=True, quad=None):
        """
        Evaluate the given formula to an accuracy of dps decimal digits.
        Optional keyword arguments:

            subs=<dict>
                Substitute numerical values for symbols, e.g.
                subs={x:3, y:1+pi}. The substitutions must be given as a
                dictionary.

            maxn=<integer>
                Allow a maximum temporary working precision of maxn digits
                (default=110)

            chop=<bool>
                Replace tiny real or imaginary parts in subresults
                by exact zeros (default=False)

            strict=<bool>
                Raise PrecisionExhausted if any subresult fails to evaluate
                to full accuracy, given the available maxprec
                (default=True)

            quad=<str>
                Choose algorithm for numerical quadrature. By default,
                tanh-sinh quadrature is used. For oscillatory
                integrals on an infinite interval, try quad='osc'.

        """
        from .numbers import Float, I

        if subs and is_sequence(subs):
            raise TypeError('subs must be given as a dictionary')

        if not evalf_table:
            _create_evalf_table()
        prec = dps_to_prec(dps)
        options = {'maxprec': max(prec, int(maxn*LG10)), 'chop': chop,
                   'strict': strict}
        if subs is not None:
            options['subs'] = subs
        if quad is not None:
            options['quad'] = quad
        try:
            # A few guard bits beyond the requested precision.
            result = evalf(self, prec + 4, options)
        except PrecisionExhausted:
            if self.is_Float and self._prec >= prec:
                return Float._new(self._mpf_, prec)
            else:
                raise
        except NotImplementedError:
            # Fall back to the ordinary evalf
            v = self._eval_evalf(prec)
            if v is None:
                return self
            else:
                # Normalize result
                return v.subs({_: _.evalf(dps, strict=strict)
                               for _ in v.atoms(Float)})
        re, im, re_acc, im_acc = result
        if re:
            p = max(min(prec, re_acc), 1)
            re = Float._new(re, p)
        else:
            re = S.Zero
        if im:
            p = max(min(prec, im_acc), 1)
            im = Float._new(im, p)
            return re + im*I
        else:
            return re

    def _evalf(self, prec):
        """Helper for evalf. Does the same thing but takes binary precision."""
        r = self._eval_evalf(prec)
        if r is None:
            r = self
        return r

    def _eval_evalf(self, prec):
        """Hook for subclasses; return a numeric approximation or None."""
        return

    def _to_mpmath(self, prec):
        """
        Convert ``self`` to an mpmath number at binary precision ``prec``.

        Objects providing ``_as_mpf_val`` are converted directly.  Otherwise
        the module-level ``evalf`` machinery is used, falling back to
        ``_eval_evalf``.  Returns an ``mpf`` for real values and an ``mpc``
        when a nonzero imaginary part is present.

        Raises ValueError if no numerical value can be obtained.
        """
        errmsg = 'cannot convert to mpmath number'
        if hasattr(self, '_as_mpf_val'):
            return make_mpf(self._as_mpf_val(prec))
        # evalf() dispatches through evalf_table, which is built lazily;
        # make sure it exists, exactly as EvalfMixin.evalf does.
        if not evalf_table:
            _create_evalf_table()
        try:
            re, im, _, _ = evalf(self, prec, {'maxprec': DEFAULT_MAXPREC})
            if im:
                if not re:
                    re = fzero
                return make_mpc((re, im))
            elif re:
                return make_mpf(re)
            else:
                return make_mpf(fzero)
        except NotImplementedError:
            v = self._eval_evalf(prec)
            if v is None:
                raise ValueError(errmsg)
            if v.is_Float:
                # Purely real result.  Splitting it with as_real_imag()
                # would yield an exact (non-Float) zero imaginary part, and
                # the checks below would then reject it.
                return make_mpf(v._mpf_)
            re, im = v.as_real_imag()
            if re.is_Float:
                re = re._mpf_
            else:
                raise ValueError(errmsg)
            if im.is_Float:
                im = im._mpf_
            else:
                raise ValueError(errmsg)
            return make_mpc((re, im))
1316
1317
def N(x, dps=15, **options):
    r"""
    Numerically evaluate ``x`` to ``dps`` significant digits.

    The argument is sympified first, then ``evalf(dps, \*\*options)`` is
    called on it; all keyword options are forwarded unchanged.

    Examples
    ========

    >>> Sum(1/k**k, (k, 1, oo))
    Sum(k**(-k), (k, 1, oo))
    >>> N(_, 4)
    1.291

    See Also
    ========

    diofant.core.evalf.EvalfMixin.evalf

    """
    expr = sympify(x)
    return expr.evalf(dps, **options)
1337