1from operator import gt, lt
2
3from .libmp.backend import xrange
4
5from .functions.functions import SpecialFunctions
6from .functions.rszeta import RSCache
7from .calculus.quadrature import QuadratureMethods
8from .calculus.inverselaplace import LaplaceTransformInversionMethods
9from .calculus.calculus import CalculusMethods
10from .calculus.optimization import OptimizationMethods
11from .calculus.odes import ODEMethods
12from .matrices.matrices import MatrixMethods
13from .matrices.calculus import MatrixCalculusMethods
14from .matrices.linalg import LinearAlgebraMethods
15from .matrices.eigen import Eigen
16from .identification import IdentificationMethods
17from .visualization import VisualizationMethods
18
19from . import libmp
20
class Context(object):
    """Empty root base class; StandardBaseContext derives from it."""
23
class StandardBaseContext(Context,
    SpecialFunctions,
    RSCache,
    QuadratureMethods,
    LaplaceTransformInversionMethods,
    CalculusMethods,
    MatrixMethods,
    MatrixCalculusMethods,
    LinearAlgebraMethods,
    Eigen,
    IdentificationMethods,
    OptimizationMethods,
    ODEMethods,
    VisualizationMethods):
    """
    Base context combining the numerical-method mixins.

    Provides generic fallbacks for arithmetic helpers (``fadd``, ``fsum``,
    ``fdot``, ...), utilities (``chop``, ``almosteq``, ``arange``,
    ``linspace``, ``maxcalls``, ``memoize``) and precision-adaptive
    summation/multiplication. These rely on attributes such as
    ``ctx.convert``, ``ctx.zero``, ``ctx.one``, ``ctx.mpf``, ``ctx.eps`` and
    ``ctx.mag``, which are not defined in this class (presumably supplied by
    a concrete subclass or another mixin).
    """

    # Exception types re-exported from libmp so callers can catch them as
    # ctx.NoConvergence / ctx.ComplexResult (maxcalls raises the former).
    NoConvergence = libmp.NoConvergence
    ComplexResult = libmp.ComplexResult

    def __init__(ctx):
        # Maps alias name -> name of the attribute it should refer to;
        # resolved by _init_aliases().
        ctx._aliases = {}
        # Call those that need preinitialization (e.g. for wrappers)
        SpecialFunctions.__init__(ctx)
        RSCache.__init__(ctx)
        QuadratureMethods.__init__(ctx)
        LaplaceTransformInversionMethods.__init__(ctx)
        CalculusMethods.__init__(ctx)
        MatrixMethods.__init__(ctx)

    def _init_aliases(ctx):
        """
        Bind every alias in ``ctx._aliases`` to the attribute it names.

        Aliases whose target attribute does not exist are silently skipped.
        """
        for alias, value in ctx._aliases.items():
            try:
                setattr(ctx, alias, getattr(ctx, value))
            except AttributeError:
                pass

    # When True, sum_accurately/mul_accurately never retry at a higher
    # precision.
    _fixed_precision = False

    # XXX: verbosity flag, off by default.
    verbose = False

    def warn(ctx, msg):
        """Print ``msg`` to standard output prefixed with ``Warning:``."""
        print("Warning:", msg)

    def bad_domain(ctx, msg):
        """Raise ``ValueError(msg)``."""
        raise ValueError(msg)

    def _re(ctx, x):
        """Return ``x.real`` if ``x`` has it, otherwise ``x`` itself."""
        if hasattr(x, "real"):
            return x.real
        return x

    def _im(ctx, x):
        """Return ``x.imag`` if ``x`` has it, otherwise ``ctx.zero``."""
        if hasattr(x, "imag"):
            return x.imag
        return ctx.zero

    def _as_points(ctx, x):
        """Identity in the base context; returns ``x`` unchanged."""
        return x

    # The generic arithmetic helpers below convert their operands with
    # ctx.convert and use the ordinary Python operators. Extra keyword
    # arguments are accepted for interface compatibility and ignored here.

    def fneg(ctx, x, **kwargs):
        """Return ``-x`` after conversion."""
        return -ctx.convert(x)

    def fadd(ctx, x, y, **kwargs):
        """Return ``x + y`` after conversion."""
        return ctx.convert(x)+ctx.convert(y)

    def fsub(ctx, x, y, **kwargs):
        """Return ``x - y`` after conversion."""
        return ctx.convert(x)-ctx.convert(y)

    def fmul(ctx, x, y, **kwargs):
        """Return ``x * y`` after conversion."""
        return ctx.convert(x)*ctx.convert(y)

    def fdiv(ctx, x, y, **kwargs):
        """Return ``x / y`` after conversion."""
        return ctx.convert(x)/ctx.convert(y)

    def fsum(ctx, args, absolute=False, squared=False):
        """
        Sum the iterable ``args`` starting from ``ctx.zero``.

        With ``absolute=True`` the absolute values are summed; with
        ``squared=True`` the squares are summed; with both, ``|x|**2``.
        """
        if absolute:
            if squared:
                return sum((abs(x)**2 for x in args), ctx.zero)
            return sum((abs(x) for x in args), ctx.zero)
        if squared:
            return sum((x**2 for x in args), ctx.zero)
        return sum(args, ctx.zero)

    def fdot(ctx, xs, ys=None, conjugate=False):
        """
        Return the dot product ``sum(x*y)`` starting from ``ctx.zero``.

        If ``ys`` is None, ``xs`` must be an iterable of ``(x, y)`` pairs;
        otherwise ``xs`` and ``ys`` are paired with ``zip``. With
        ``conjugate=True`` each ``y`` is conjugated with ``ctx.conj``.
        """
        if ys is not None:
            xs = zip(xs, ys)
        if conjugate:
            cf = ctx.conj
            return sum((x*cf(y) for (x,y) in xs), ctx.zero)
        else:
            return sum((x*y for (x,y) in xs), ctx.zero)

    def fprod(ctx, args):
        """Return the product of the iterable ``args``, starting from ``ctx.one``."""
        prod = ctx.one
        for arg in args:
            prod *= arg
        return prod

    def nprint(ctx, x, n=6, **kwargs):
        """
        Equivalent to ``print(nstr(x, n))``.
        """
        print(ctx.nstr(x, n, **kwargs))

    def chop(ctx, x, tol=None):
        """
        Chops off small real or imaginary parts, or converts
        numbers close to zero to exact zeros. The input can be a
        single number or an iterable::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> chop(5+1e-10j, tol=1e-9)
            mpf('5.0')
            >>> nprint(chop([1.0, 1e-20, 3+1e-18j, -4, 2]))
            [1.0, 0.0, 3.0, -4.0, 2.0]

        The tolerance defaults to ``100*eps``.
        """
        if tol is None:
            tol = 100*ctx.eps
        try:
            x = ctx.convert(x)
            absx = abs(x)
            if absx < tol:
                return ctx.zero
            if ctx._is_complex_type(x):
                # The parts are compared against a tolerance relative to |x|,
                # but never smaller than the absolute tolerance.
                part_tol = max(tol, absx*tol)
                if abs(x.imag) < part_tol:
                    return x.real
                if abs(x.real) < part_tol:
                    return ctx.mpc(0, x.imag)
        except TypeError:
            # x could not be handled as a scalar: recurse into matrices and
            # other iterables (the latter are returned as lists).
            if isinstance(x, ctx.matrix):
                return x.apply(lambda a: ctx.chop(a, tol))
            if hasattr(x, "__iter__"):
                return [ctx.chop(a, tol) for a in x]
        return x

    def almosteq(ctx, s, t, rel_eps=None, abs_eps=None):
        r"""
        Determine whether the difference between `s` and `t` is smaller
        than a given epsilon, either relatively or absolutely.

        Both a maximum relative difference and a maximum difference
        ('epsilons') may be specified. The absolute difference is
        defined as `|s-t|` and the relative difference is defined
        as `|s-t|/\max(|s|, |t|)`.

        If only one epsilon is given, both are set to the same value.
        If none is given, both epsilons are set to `2^{-p+m}` where
        `p` is the current working precision and `m` is a small
        integer. The default setting typically allows :func:`~mpmath.almosteq`
        to be used to check for mathematical equality
        in the presence of small rounding errors.

        **Examples**

            >>> from mpmath import *
            >>> mp.dps = 15
            >>> almosteq(3.141592653589793, 3.141592653589790)
            True
            >>> almosteq(3.141592653589793, 3.141592653589700)
            False
            >>> almosteq(3.141592653589793, 3.141592653589700, 1e-10)
            True
            >>> almosteq(1e-20, 2e-20)
            True
            >>> almosteq(1e-20, 2e-20, rel_eps=0, abs_eps=0)
            False

        """
        t = ctx.convert(t)
        if abs_eps is None and rel_eps is None:
            rel_eps = abs_eps = ctx.ldexp(1, -ctx.prec+4)
        if abs_eps is None:
            abs_eps = rel_eps
        elif rel_eps is None:
            rel_eps = abs_eps
        diff = abs(s-t)
        if diff <= abs_eps:
            return True
        # Relative error with respect to the larger magnitude. Here
        # diff > abs_eps, so s and t cannot both be zero.
        abss = abs(s)
        abst = abs(t)
        if abss < abst:
            err = diff/abst
        else:
            err = diff/abss
        return err <= rel_eps

    def arange(ctx, *args):
        r"""
        This is a generalized version of Python's built-in :func:`range`
        function that accepts fractional endpoints and step sizes and
        returns a list of ``mpf`` instances. Like :func:`range`,
        :func:`~mpmath.arange` can be called with 1, 2 or 3 arguments:

        ``arange(b)``
            `[0, 1, 2, \ldots, x]`
        ``arange(a, b)``
            `[a, a+1, a+2, \ldots, x]`
        ``arange(a, b, h)``
            `[a, a+h, a+2h, \ldots, x]`

        where `b-1 \le x < b` (in the third case, `b-h \le x < b`).

        Like Python's :func:`range`, the endpoint is not included. To
        produce ranges where the endpoint is included, :func:`~mpmath.linspace`
        is more convenient.

        **Examples**

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> arange(4)
            [mpf('0.0'), mpf('1.0'), mpf('2.0'), mpf('3.0')]
            >>> arange(1, 2, 0.25)
            [mpf('1.0'), mpf('1.25'), mpf('1.5'), mpf('1.75')]
            >>> arange(1, -1, -0.75)
            [mpf('1.0'), mpf('0.25'), mpf('-0.5')]

        """
        if not len(args) <= 3:
            raise TypeError('arange expected at most 3 arguments, got %i'
                            % len(args))
        if not len(args) >= 1:
            raise TypeError('arange expected at least 1 argument, got %i'
                            % len(args))
        # set default
        a = 0
        dt = 1
        # interpret arguments
        if len(args) == 1:
            b = args[0]
        elif len(args) >= 2:
            a = args[0]
            b = args[1]
        if len(args) == 3:
            dt = args[2]
        a, b, dt = ctx.mpf(a), ctx.mpf(b), ctx.mpf(dt)
        # Checked explicitly rather than with ``assert`` so the guard is not
        # stripped under ``python -O`` (which would leave an infinite loop);
        # AssertionError is kept for backward compatibility.
        if a + dt == a:
            raise AssertionError('dt is too small and would cause an infinite loop')
        # Pick the comparison matching the direction of the step; a step
        # pointing away from b yields an empty list.
        if a > b:
            if dt > 0:
                return []
            op = gt
        else:
            if dt < 0:
                return []
            op = lt
        # Each point is computed as a + dt*i rather than by repeatedly adding
        # dt, presumably to avoid accumulating rounding error.
        result = []
        i = 0
        while 1:
            t = a + dt*i
            i += 1
            if op(t, b):
                result.append(t)
            else:
                break
        return result

    def linspace(ctx, *args, **kwargs):
        """
        ``linspace(a, b, n)`` returns a list of `n` evenly spaced
        samples from `a` to `b`. The syntax ``linspace(mpi(a,b), n)``
        is also valid.

        This function is often more convenient than :func:`~mpmath.arange`
        for partitioning an interval into subintervals, since
        the endpoint is included::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = False
            >>> linspace(1, 4, 4)
            [mpf('1.0'), mpf('2.0'), mpf('3.0'), mpf('4.0')]

        You may also provide the keyword argument ``endpoint=False``::

            >>> linspace(1, 4, 4, endpoint=False)
            [mpf('1.0'), mpf('1.75'), mpf('2.5'), mpf('3.25')]

        """
        if len(args) == 3:
            a = ctx.mpf(args[0])
            b = ctx.mpf(args[1])
            n = int(args[2])
        elif len(args) == 2:
            assert hasattr(args[0], '_mpi_')
            a = args[0].a
            b = args[0].b
            n = int(args[1])
        else:
            raise TypeError('linspace expected 2 or 3 arguments, got %i'
                            % len(args))
        if n < 1:
            raise ValueError('n must be greater than 0')
        if kwargs.get('endpoint', True):
            if n == 1:
                return [ctx.mpf(a)]
            step = (b - a) / ctx.mpf(n - 1)
            y = [i*step + a for i in xrange(n)]
            # Force the last sample to be exactly b, independent of rounding
            # in i*step + a.
            y[-1] = b
        else:
            step = (b - a) / ctx.mpf(n)
            y = [i*step + a for i in xrange(n)]
        return y

    def cos_sin(ctx, z, **kwargs):
        """Return the pair ``(cos(z), sin(z))``; kwargs are passed to both."""
        return ctx.cos(z, **kwargs), ctx.sin(z, **kwargs)

    def cospi_sinpi(ctx, z, **kwargs):
        """Return the pair ``(cospi(z), sinpi(z))``; kwargs are passed to both."""
        return ctx.cospi(z, **kwargs), ctx.sinpi(z, **kwargs)

    def _default_hyper_maxprec(ctx, p):
        """
        Return ``int(1000 * p**0.25 + 4*p)`` for precision ``p``.

        Presumably the default precision cap for hypergeometric evaluation
        (inferred from the name; confirm against callers).
        """
        return int(1000 * p**0.25 + 4*p)

    # Number-theory helpers from libmp exposed on the context. They are
    # wrapped in staticmethod so that they are not bound to ctx.
    _gcd = staticmethod(libmp.gcd)
    list_primes = staticmethod(libmp.list_primes)
    isprime = staticmethod(libmp.isprime)
    bernfrac = staticmethod(libmp.bernfrac)
    moebius = staticmethod(libmp.moebius)
    _ifac = staticmethod(libmp.ifac)
    _eulernum = staticmethod(libmp.eulernum)
    _stirling1 = staticmethod(libmp.stirling1)
    _stirling2 = staticmethod(libmp.stirling2)

    def sum_accurately(ctx, terms, check_step=1):
        """
        Sum the terms produced by ``terms()``, raising the working precision
        while cancellation is detected.

        ``terms`` is a zero-argument callable returning an iterable of terms;
        it is called afresh on every attempt. Every ``check_step``-th term, if
        nonzero, is used to track the largest term magnitude and the sum's
        magnitude. Summation stops early once such a term lies more than
        ``ctx.prec`` bits below the partial sum. If the cancellation
        (largest term magnitude minus sum magnitude) is not below the extra
        precision in use, the sum is recomputed with more guard bits. The
        caller's precision is restored on exit, even on error.
        """
        prec = ctx.prec
        try:
            extraprec = 10
            while 1:
                ctx.prec = prec + extraprec + 5
                max_mag = ctx.ninf
                # Initialized so that an empty sequence (or one whose checked
                # terms are all zero) cannot leave sum_mag unbound. Then
                # ninf - ninf is NaN (assuming IEEE-like infinities), and the
                # NaN test below ends the loop.
                sum_mag = ctx.ninf
                s = ctx.zero
                for k, term in enumerate(terms()):
                    s += term
                    if (not k % check_step) and term:
                        term_mag = ctx.mag(term)
                        max_mag = max(max_mag, term_mag)
                        sum_mag = ctx.mag(s)
                        # Remaining terms are assumed negligible.
                        if sum_mag - term_mag > ctx.prec:
                            break
                cancellation = max_mag - sum_mag
                # NaN: no magnitude information was gathered.
                if cancellation != cancellation:
                    break
                if cancellation < extraprec or ctx._fixed_precision:
                    break
                # NOTE(review): if ctx.mag(0) is -inf, a sum that is exactly
                # zero makes cancellation +inf and this loop keeps raising the
                # precision without terminating — confirm callers avoid that.
                extraprec += min(ctx.prec, cancellation)
            return s
        finally:
            ctx.prec = prec

    def mul_accurately(ctx, factors, check_step=1):
        """
        Multiply the factors produced by ``factors()``, raising the working
        precision while cancellation in ``product - 1`` is detected.

        ``factors`` is a zero-argument callable returning an iterable of
        factors; it is called afresh on every attempt. Every
        ``check_step``-th factor is used to track the largest magnitude of
        ``factor - 1`` and the magnitude of ``product - 1``. Multiplication
        stops early once a checked factor is within ``2**-ctx.prec`` of one.
        If the cancellation is not below the extra precision in use, the
        product is recomputed with more guard bits. The caller's precision
        is restored on exit, even on error.
        """
        prec = ctx.prec
        try:
            extraprec = 10
            while 1:
                ctx.prec = prec + extraprec + 5
                max_mag = ctx.ninf
                # Initialized so an empty sequence cannot leave sum_mag
                # unbound; ninf - ninf is then NaN and the loop ends.
                sum_mag = ctx.ninf
                one = ctx.one
                s = one
                for k, factor in enumerate(factors()):
                    s *= factor
                    term = factor - one
                    if not k % check_step:
                        term_mag = ctx.mag(term)
                        max_mag = max(max_mag, term_mag)
                        sum_mag = ctx.mag(s-one)
                        # Remaining factors are assumed to be one to within
                        # the working precision.
                        if -term_mag > ctx.prec:
                            break
                cancellation = max_mag - sum_mag
                # NaN: no magnitude information was gathered.
                if cancellation != cancellation:
                    break
                if cancellation < extraprec or ctx._fixed_precision:
                    break
                extraprec += min(ctx.prec, cancellation)
            return s
        finally:
            ctx.prec = prec

    def power(ctx, x, y):
        r"""Converts `x` and `y` to mpmath numbers and evaluates
        `x^y = \exp(y \log(x))`::

            >>> from mpmath import *
            >>> mp.dps = 30; mp.pretty = True
            >>> power(2, 0.5)
            1.41421356237309504880168872421

        This shows the leading few digits of a large Mersenne prime
        (performing the exact calculation ``2**43112609-1`` and
        displaying the result in Python would be very slow)::

            >>> power(2, 43112609)-1
            3.16470269330255923143453723949e+12978188
        """
        return ctx.convert(x) ** ctx.convert(y)

    def _zeta_int(ctx, n):
        """Return ``ctx.zeta(n)``; generic fallback for integer arguments."""
        return ctx.zeta(n)

    def maxcalls(ctx, f, N):
        """
        Return a wrapped copy of *f* that raises ``NoConvergence`` when *f*
        has been called more than *N* times::

            >>> from mpmath import *
            >>> mp.dps = 15
            >>> f = maxcalls(sin, 10)
            >>> print(sum(f(n) for n in range(10)))
            1.95520948210738
            >>> f(10) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
              ...
            NoConvergence: maxcalls: function evaluated 10 times

        """
        # One-element list so the nested function can update the count
        # without ``nonlocal``.
        counter = [0]
        def f_maxcalls_wrapped(*args, **kwargs):
            counter[0] += 1
            if counter[0] > N:
                raise ctx.NoConvergence("maxcalls: function evaluated %i times" % N)
            return f(*args, **kwargs)
        return f_maxcalls_wrapped

    def memoize(ctx, f):
        """
        Return a wrapped copy of *f* that caches computed values, i.e.
        a memoized copy of *f*. Values are only reused if the cached precision
        is equal to or higher than the working precision::

            >>> from mpmath import *
            >>> mp.dps = 15; mp.pretty = True
            >>> f = memoize(maxcalls(sin, 1))
            >>> f(2)
            0.909297426825682
            >>> f(2)
            0.909297426825682
            >>> mp.dps = 25
            >>> f(2) # doctest: +IGNORE_EXCEPTION_DETAIL
            Traceback (most recent call last):
              ...
            NoConvergence: maxcalls: function evaluated 1 times

        Arguments must be hashable. Keyword arguments are matched regardless
        of the order in which they are passed. The wrapper copies *f*'s
        ``__name__`` and ``__doc__`` when *f* has them.
        """
        f_cache = {}
        def f_cached(*args, **kwargs):
            if kwargs:
                # Sorted so that f(x=1, y=2) and f(y=2, x=1) share an entry;
                # keyword names are unique, so only names are ever compared.
                key = args, tuple(sorted(kwargs.items()))
            else:
                key = args
            prec = ctx.prec
            if key in f_cache:
                cprec, cvalue = f_cache[key]
                if cprec >= prec:
                    # Unary plus presumably re-rounds the cached value to the
                    # current working precision.
                    return +cvalue
            value = f(*args, **kwargs)
            f_cache[key] = (prec, value)
            return value
        # Some callables (e.g. functools.partial objects) have no __name__,
        # so copy metadata only when present.
        f_cached.__name__ = getattr(f, '__name__', f_cached.__name__)
        f_cached.__doc__ = getattr(f, '__doc__', None)
        return f_cached
495