1from .ctx_base import StandardBaseContext
2
3import math
4import cmath
5from . import math2
6
7from . import function_docs
8
9from .libmp import mpf_bernoulli, to_float, int_types
10from . import libmp
11
12class FPContext(StandardBaseContext):
13    """
14    Context for fast low-precision arithmetic (53-bit precision, giving at most
15    about 15-digit accuracy), using Python's builtin float and complex.
16    """
17
18    def __init__(ctx):
19        StandardBaseContext.__init__(ctx)
20
21        # Override SpecialFunctions implementation
22        ctx.loggamma = math2.loggamma
23        ctx._bernoulli_cache = {}
24        ctx.pretty = False
25
26        ctx._init_aliases()
27
28    _mpq = lambda cls, x: float(x[0])/x[1]
29
30    NoConvergence = libmp.NoConvergence
31
32    def _get_prec(ctx): return 53
33    def _set_prec(ctx, p): return
34    def _get_dps(ctx): return 15
35    def _set_dps(ctx, p): return
36
37    _fixed_precision = True
38
39    prec = property(_get_prec, _set_prec)
40    dps = property(_get_dps, _set_dps)
41
42    zero = 0.0
43    one = 1.0
44    eps = math2.EPS
45    inf = math2.INF
46    ninf = math2.NINF
47    nan = math2.NAN
48    j = 1j
49
50    # Called by SpecialFunctions.__init__()
51    @classmethod
52    def _wrap_specfun(cls, name, f, wrap):
53        if wrap:
54            def f_wrapped(ctx, *args, **kwargs):
55                convert = ctx.convert
56                args = [convert(a) for a in args]
57                return f(ctx, *args, **kwargs)
58        else:
59            f_wrapped = f
60        f_wrapped.__doc__ = function_docs.__dict__.get(name, f.__doc__)
61        setattr(cls, name, f_wrapped)
62
63    def bernoulli(ctx, n):
64        cache = ctx._bernoulli_cache
65        if n in cache:
66            return cache[n]
67        cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
68        return cache[n]
69
70    pi = math2.pi
71    e = math2.e
72    euler = math2.euler
73    sqrt2 = 1.4142135623730950488
74    sqrt5 = 2.2360679774997896964
75    phi = 1.6180339887498948482
76    ln2 = 0.69314718055994530942
77    ln10 = 2.302585092994045684
78    euler = 0.57721566490153286061
79    catalan = 0.91596559417721901505
80    khinchin = 2.6854520010653064453
81    apery = 1.2020569031595942854
82    glaisher = 1.2824271291006226369
83
84    absmin = absmax = abs
85
86    def is_special(ctx, x):
87        return x - x != 0.0
88
89    def isnan(ctx, x):
90        return x != x
91
92    def isinf(ctx, x):
93        return abs(x) == math2.INF
94
95    def isnormal(ctx, x):
96        if x:
97            return x - x == 0.0
98        return False
99
100    def isnpint(ctx, x):
101        if type(x) is complex:
102            if x.imag:
103                return False
104            x = x.real
105        return x <= 0.0 and round(x) == x
106
107    mpf = float
108    mpc = complex
109
110    def convert(ctx, x):
111        try:
112            return float(x)
113        except:
114            return complex(x)
115
116    power = staticmethod(math2.pow)
117    sqrt = staticmethod(math2.sqrt)
118    exp = staticmethod(math2.exp)
119    ln = log = staticmethod(math2.log)
120    cos = staticmethod(math2.cos)
121    sin = staticmethod(math2.sin)
122    tan = staticmethod(math2.tan)
123    cos_sin = staticmethod(math2.cos_sin)
124    acos = staticmethod(math2.acos)
125    asin = staticmethod(math2.asin)
126    atan = staticmethod(math2.atan)
127    cosh = staticmethod(math2.cosh)
128    sinh = staticmethod(math2.sinh)
129    tanh = staticmethod(math2.tanh)
130    gamma = staticmethod(math2.gamma)
131    rgamma = staticmethod(math2.rgamma)
132    fac = factorial = staticmethod(math2.factorial)
133    floor = staticmethod(math2.floor)
134    ceil = staticmethod(math2.ceil)
135    cospi = staticmethod(math2.cospi)
136    sinpi = staticmethod(math2.sinpi)
137    cbrt = staticmethod(math2.cbrt)
138    _nthroot = staticmethod(math2.nthroot)
139    _ei = staticmethod(math2.ei)
140    _e1 = staticmethod(math2.e1)
141    _zeta = _zeta_int = staticmethod(math2.zeta)
142
143    # XXX: math2
144    def arg(ctx, z):
145        z = complex(z)
146        return math.atan2(z.imag, z.real)
147
148    def expj(ctx, x):
149        return ctx.exp(ctx.j*x)
150
151    def expjpi(ctx, x):
152        return ctx.exp(ctx.j*ctx.pi*x)
153
154    ldexp = math.ldexp
155    frexp = math.frexp
156
157    def mag(ctx, z):
158        if z:
159            return ctx.frexp(abs(z))[1]
160        return ctx.ninf
161
162    def isint(ctx, z):
163        if hasattr(z, "imag"):   # float/int don't have .real/.imag in py2.5
164            if z.imag:
165                return False
166            z = z.real
167        try:
168            return z == int(z)
169        except:
170            return False
171
172    def nint_distance(ctx, z):
173        if hasattr(z, "imag"):   # float/int don't have .real/.imag in py2.5
174            n = round(z.real)
175        else:
176            n = round(z)
177        if n == z:
178            return n, ctx.ninf
179        return n, ctx.mag(abs(z-n))
180
181    def _convert_param(ctx, z):
182        if type(z) is tuple:
183            p, q = z
184            return ctx.mpf(p) / q, 'R'
185        if hasattr(z, "imag"):    # float/int don't have .real/.imag in py2.5
186            intz = int(z.real)
187        else:
188            intz = int(z)
189        if z == intz:
190            return intz, 'Z'
191        return z, 'R'
192
193    def _is_real_type(ctx, z):
194        return isinstance(z, float) or isinstance(z, int_types)
195
196    def _is_complex_type(ctx, z):
197        return isinstance(z, complex)
198
199    def hypsum(ctx, p, q, types, coeffs, z, maxterms=6000, **kwargs):
200        coeffs = list(coeffs)
201        num = range(p)
202        den = range(p,p+q)
203        tol = ctx.eps
204        s = t = 1.0
205        k = 0
206        while 1:
207            for i in num: t *= (coeffs[i]+k)
208            for i in den: t /= (coeffs[i]+k)
209            k += 1; t /= k; t *= z; s += t
210            if abs(t) < tol:
211                return s
212            if k > maxterms:
213                raise ctx.NoConvergence
214
215    def atan2(ctx, x, y):
216        return math.atan2(x, y)
217
218    def psi(ctx, m, z):
219        m = int(m)
220        if m == 0:
221            return ctx.digamma(z)
222        return (-1)**(m+1) * ctx.fac(m) * ctx.zeta(m+1, z)
223
224    digamma = staticmethod(math2.digamma)
225
226    def harmonic(ctx, x):
227        x = ctx.convert(x)
228        if x == 0 or x == 1:
229            return x
230        return ctx.digamma(x+1) + ctx.euler
231
232    nstr = str
233
234    def to_fixed(ctx, x, prec):
235        return int(math.ldexp(x, prec))
236
237    def rand(ctx):
238        import random
239        return random.random()
240
241    _erf = staticmethod(math2.erf)
242    _erfc = staticmethod(math2.erfc)
243
244    def sum_accurately(ctx, terms, check_step=1):
245        s = ctx.zero
246        k = 0
247        for term in terms():
248            s += term
249            if (not k % check_step) and term:
250                if abs(term) <= 1e-18*abs(s):
251                    break
252            k += 1
253        return s
254