1from ..libmp.backend import xrange
2
class SpecialFunctions(object):
    """
    Mixin holding special functions written in terms of high-level code.

    Elementary functions and a few other primitives (for example the gamma
    function and basecase hypergeometric series) are expected to be provided
    by the concrete context as "builtins" or "low-level" functions.
    """
    # Registry filled by the module-level registration decorators:
    # function name -> (function, wrap flag).
    defined_functions = {}

    # The Jacobi theta series converge for |q| < 1; the current
    # implementation throws a ValueError once abs(q) > THETA_Q_LIM.
    THETA_Q_LIM = 1 - 10**-7

    def __init__(self):
        cls = type(self)
        # Install every registered function on the context class.
        for name, (f, wrap) in cls.defined_functions.items():
            cls._wrap_specfun(name, f, wrap)

        # Cached rational constants: (attribute name, numerator, denominator).
        rationals = (
            ('mpq_1', 1, 1),
            ('mpq_0', 0, 1),
            ('mpq_1_2', 1, 2),
            ('mpq_3_2', 3, 2),
            ('mpq_1_4', 1, 4),
            ('mpq_1_16', 1, 16),
            ('mpq_3_16', 3, 16),
            ('mpq_5_2', 5, 2),
            ('mpq_3_4', 3, 4),
            ('mpq_7_4', 7, 4),
            ('mpq_5_4', 5, 4),
            ('mpq_1_3', 1, 3),
            ('mpq_2_3', 2, 3),
            ('mpq_4_3', 4, 3),
            ('mpq_1_6', 1, 6),
            ('mpq_5_6', 5, 6),
            ('mpq_5_3', 5, 3),
        )
        for attr, num, den in rationals:
            setattr(self, attr, self._mpq((num, den)))

        self._misc_const_cache = {}

        # Alternative public names mapped to their canonical function names.
        # (digamma/trigamma/tetragamma/pentagamma are intentionally not
        # aliased to psi0..psi3 here.)
        alias_map = {
            'phase': 'arg',
            'conjugate': 'conj',
            'nthroot': 'root',
            'polygamma': 'psi',
            'hurwitz': 'zeta',
            'fibonacci': 'fib',
            'factorial': 'fac',
        }
        self._aliases.update(alias_map)

        self.zetazero_memoized = self.memoize(self.zetazero)

    # Default installation hook: attach the function unchanged and ignore
    # the wrap flag. Contexts may override this to wrap functions.
    @classmethod
    def _wrap_specfun(cls, name, f, wrap):
        setattr(cls, name, f)

    # Optional fast versions of common functions in common cases.
    # If not overridden, default (generic hypergeometric series)
    # implementations will be used.
    def _besselj(ctx, n, z):
        raise NotImplementedError

    def _erf(ctx, z):
        raise NotImplementedError

    def _erfc(ctx, z):
        raise NotImplementedError

    def _gamma_upper_int(ctx, z, a):
        raise NotImplementedError

    def _expint_int(ctx, n, z):
        raise NotImplementedError

    def _zeta(ctx, s):
        raise NotImplementedError

    def _zetasum_fast(ctx, s, a, n, derivatives, reflect):
        raise NotImplementedError

    def _ei(ctx, z):
        raise NotImplementedError

    def _e1(ctx, z):
        raise NotImplementedError

    def _ci(ctx, z):
        raise NotImplementedError

    def _si(ctx, z):
        raise NotImplementedError

    def _altzeta(ctx, s):
        raise NotImplementedError
80
def _register_specfun(f, wrap):
    # Record f under its own name in the registry that
    # SpecialFunctions.__init__ walks to install functions on the context.
    SpecialFunctions.defined_functions[f.__name__] = (f, wrap)
    return f

def defun_wrapped(f):
    """Register ``f`` with ``wrap=True`` and return it unchanged."""
    return _register_specfun(f, True)

def defun(f):
    """Register ``f`` with ``wrap=False`` and return it unchanged."""
    return _register_specfun(f, False)

def defun_static(f):
    """Attach ``f`` directly as an attribute of SpecialFunctions and return it."""
    setattr(SpecialFunctions, f.__name__, f)
    return f
92
@defun_wrapped
def cot(ctx, z):
    """Cotangent, evaluated as 1/tan(z)."""
    return ctx.one / ctx.tan(z)

@defun_wrapped
def sec(ctx, z):
    """Secant, evaluated as 1/cos(z)."""
    return ctx.one / ctx.cos(z)

@defun_wrapped
def csc(ctx, z):
    """Cosecant, evaluated as 1/sin(z)."""
    return ctx.one / ctx.sin(z)

@defun_wrapped
def coth(ctx, z):
    """Hyperbolic cotangent, evaluated as 1/tanh(z)."""
    return ctx.one / ctx.tanh(z)

@defun_wrapped
def sech(ctx, z):
    """Hyperbolic secant, evaluated as 1/cosh(z)."""
    return ctx.one / ctx.cosh(z)

@defun_wrapped
def csch(ctx, z):
    """Hyperbolic cosecant, evaluated as 1/sinh(z)."""
    return ctx.one / ctx.sinh(z)
110
@defun_wrapped
def acot(ctx, z):
    """Inverse cotangent, atan(1/z); acot(0) is pi/2."""
    if z:
        return ctx.atan(ctx.one / z)
    return ctx.pi * 0.5

@defun_wrapped
def asec(ctx, z):
    """Inverse secant, acos(1/z)."""
    return ctx.acos(ctx.one / z)

@defun_wrapped
def acsc(ctx, z):
    """Inverse cosecant, asin(1/z)."""
    return ctx.asin(ctx.one / z)

@defun_wrapped
def acoth(ctx, z):
    """Inverse hyperbolic cotangent, atanh(1/z); acoth(0) is pi*i/2."""
    if z:
        return ctx.atanh(ctx.one / z)
    return ctx.pi * 0.5j


@defun_wrapped
def asech(ctx, z):
    """Inverse hyperbolic secant, acosh(1/z)."""
    return ctx.acosh(ctx.one / z)

@defun_wrapped
def acsch(ctx, z):
    """Inverse hyperbolic cosecant, asinh(1/z)."""
    return ctx.asinh(ctx.one / z)
137
@defun
def sign(ctx, x):
    """
    Sign of ``x``: +1/-1 for nonzero reals, x/|x| for complex values;
    zero and nan are returned unchanged.
    """
    x = ctx.convert(x)
    if (not x) or ctx.isnan(x):
        return x
    if not ctx._is_real_type(x):
        return x / abs(x)
    return ctx.one if x > 0 else -ctx.one
149
@defun
def agm(ctx, a, b=1):
    """Arithmetic-geometric mean of ``a`` and ``b`` (``b`` defaults to 1)."""
    if b == 1:
        # Dedicated single-argument routine.
        return ctx.agm1(a)
    return ctx._agm(ctx.convert(a), ctx.convert(b))
157
@defun_wrapped
def sinc(ctx, x):
    """Unnormalized sinc, sin(x)/x, with value 1 at x = 0 and 0 at +-inf."""
    if not ctx.isinf(x):
        if x:
            return ctx.sin(x)/x
        return x+1
    return 1/x

@defun_wrapped
def sincpi(ctx, x):
    """Normalized sinc, sin(pi*x)/(pi*x), with value 1 at x = 0 and 0 at +-inf."""
    if not ctx.isinf(x):
        if x:
            return ctx.sinpi(x)/(ctx.pi*x)
        return x+1
    return 1/x
173
# TODO: tests; improve implementation
@defun_wrapped
def expm1(ctx, x):
    """Compute ``exp(x) - 1`` without catastrophic cancellation for small x."""
    if x:
        if ctx.mag(x) < -ctx.prec:
            # exp(x) - 1 = x + x**2/2 + O(x**3); higher terms are negligible.
            return x + 0.5*x**2
        # TODO: accurately eval the smaller of the real/imag parts
        return ctx.sum_accurately(lambda: iter([ctx.exp(x), -1]), 1)
    return ctx.zero
184
@defun_wrapped
def log1p(ctx, x):
    """Compute ``log(1 + x)`` accurately, including for tiny x."""
    if not x:
        return ctx.zero
    tiny = ctx.mag(x) < -ctx.prec
    if tiny:
        # log(1 + x) = x - x**2/2 + O(x**3)
        return x - 0.5*x**2
    # Form 1 + x at doubled precision so the low bits of x survive
    # the addition before the logarithm is taken.
    onep = ctx.fadd(1, x, prec=2*ctx.prec)
    return ctx.log(onep)
192
@defun_wrapped
def powm1(ctx, x, y):
    """
    Compute ``x**y - 1`` accurately, even when ``x**y`` is close to 1.

    The naive difference is returned when it loses only a few bits to
    cancellation. Otherwise a small-argument expansion in ``y*log(x)`` is
    used, or the difference is recomputed with ``ctx.sum_accurately``.
    """
    mag = ctx.mag
    one = ctx.one
    w = x**y - one
    M = mag(w)
    # Only moderate cancellation: the naive difference is good enough.
    if M > -8:
        return w
    # An exactly-zero difference is trusted only in the cases where x**y
    # can be exactly 1: y == 0, or x in {1, -1, i, -i} with integer y.
    if not w:
        if (not y) or (x in (1, -1, 1j, -1j) and ctx.isint(y)):
            return w
    magy = mag(y)
    lnx = ctx.ln(x)
    # Small y*log(x): x^y - 1 = t + t**2/2 + O(t**3) with t = log(x)*y
    if magy + mag(lnx) < -ctx.prec:
        return lnx*y + (lnx*y)**2/2
    # TODO: accurately eval the smaller of the real/imag part
    return ctx.sum_accurately(lambda: iter([x**y, -1]), 1)
214
@defun
def _rootof1(ctx, k, n):
    """
    Return exp(2*pi*i*k/n), the k-th of the n-th roots of unity.

    ``k`` is reduced modulo ``n``; the roots 1, -1, i and -i are
    returned exactly instead of being computed.
    """
    k, n = int(k), int(n)
    k = k % n
    if k == 0:
        return ctx.one
    if 2*k == n:
        return -ctx.one
    if 4*k == n:
        return ctx.j
    if 4*k == 3*n:
        return -ctx.j
    return ctx.expjpi(2*ctx.mpf(k)/n)
229
@defun
def root(ctx, x, n, k=0):
    """
    Return the ``k``-th branch of the ``n``-th root of ``x``.

    ``k = 0`` is the principal root. Other branches are the principal root
    times a root of unity, except that for odd ``n`` and negative real
    ``x`` the branch ``k = (n-1)/2`` is returned as the exact real root.
    """
    n = int(n)
    x = ctx.convert(x)
    if not k:
        return ctx._nthroot(x, n)
    odd_n = n & 1
    # Special case: there is an exact real root
    if odd_n and 2*k == n-1 and not ctx.im(x) and ctx.re(x) < 0:
        return -ctx.root(-x, n)
    # Multiply by root of unity, with a few guard bits.
    saved_prec = ctx.prec
    try:
        ctx.prec += 10
        v = ctx.root(x, n, 0) * ctx._rootof1(k, n)
    finally:
        ctx.prec = saved_prec
    return +v
247
@defun
def unitroots(ctx, n, primitive=False):
    """
    Return the list of ``n``-th roots of unity, or only the primitive ones
    when ``primitive`` is true.

    The roots are computed with 10 extra bits. Each is then passed through
    unary ``+`` after the precision has been restored.
    """
    gcd = ctx._gcd
    saved_prec = ctx.prec
    try:
        ctx.prec += 10
        # TODO: this can be done *much* faster
        roots = [ctx._rootof1(k, n) for k in range(n)
                 if not primitive or gcd(k, n) == 1]
    finally:
        ctx.prec = saved_prec
    return [+r for r in roots]
262
@defun
def arg(ctx, x):
    """Complex argument (phase) of ``x``, computed as atan2(im(x), re(x))."""
    x = ctx.convert(x)
    real_part = ctx._re(x)
    imag_part = ctx._im(x)
    return ctx.atan2(imag_part, real_part)
269
@defun
def fabs(ctx, x):
    """Absolute value of ``x`` after conversion to a context number."""
    value = ctx.convert(x)
    return abs(value)

@defun
def re(ctx, x):
    """Real part of ``x``; values without a ``.real`` attribute are returned as-is."""
    x = ctx.convert(x)
    # py2.5 doesn't have .real/.imag for all numbers
    return getattr(x, "real", x)

@defun
def im(ctx, x):
    """Imaginary part of ``x``; values without an ``.imag`` attribute give zero."""
    x = ctx.convert(x)
    # py2.5 doesn't have .real/.imag for all numbers
    return getattr(x, "imag", ctx.zero)

@defun
def conj(ctx, x):
    """Complex conjugate of ``x``; values lacking ``.conjugate()`` are returned as-is."""
    x = ctx.convert(x)
    try:
        conjugated = x.conjugate()
    except AttributeError:
        conjugated = x
    return conjugated
295
@defun
def polar(ctx, z):
    """Return the polar pair ``(|z|, arg(z))``."""
    modulus = ctx.fabs(z)
    angle = ctx.arg(z)
    return (modulus, angle)

@defun_wrapped
def rect(ctx, r, phi):
    """Convert polar coordinates to a complex number r*(cos(phi) + i*sin(phi))."""
    unit = ctx.mpc(*ctx.cos_sin(phi))
    return r * unit
303
@defun
def log(ctx, x, b=None):
    """
    Logarithm of ``x``: natural when ``b`` is None, otherwise to base ``b``.

    The base-``b`` quotient is formed from logarithms computed with 20 extra bits.
    """
    if b is not None:
        wp = ctx.prec + 20
        return ctx.ln(x, prec=wp) / ctx.ln(b, prec=wp)
    return ctx.ln(x)

@defun
def log10(ctx, x):
    """Base-10 logarithm of ``x``."""
    return ctx.log(x, 10)
314
@defun
def fmod(ctx, x, y):
    """Remainder of ``x`` by ``y`` via the ``%`` operator on context numbers."""
    dividend = ctx.convert(x)
    divisor = ctx.convert(y)
    return dividend % divisor

@defun
def degrees(ctx, x):
    """Convert an angle from radians to degrees."""
    return x / ctx.degree

@defun
def radians(ctx, x):
    """Convert an angle from degrees to radians."""
    return x * ctx.degree
326
def _lambertw_special(ctx, z, k):
    """
    Lambert W for the non-normal arguments zero, +-infinity and anything
    else non-finite (handled by deferring to ``ctx.ln``).
    """
    if not z:
        # W(0,0) = 0; all other branches are singular
        return z if not k else ctx.ninf + z
    if z == ctx.inf:
        return z if k == 0 else z + 2*k*ctx.pi*ctx.j
    if z == ctx.ninf:
        return (-z) + (2*k+1)*ctx.pi*ctx.j
    # Some kind of nan or complex inf/nan?
    return ctx.ln(z)
342
343import math
344import cmath
345
def _lambertw_approx_hybrid(z, k):
    """
    Return a cheap hardware-float starting estimate for W_k(z).

    Only the branches k = 0 and k = -1 are handled; for any other k the
    final return statement references L1/L2 before they are assigned.
    ``z`` may be anything with ``.real``/``.imag`` attributes or anything
    accepted by ``float``. The result is a Python ``float`` or ``complex``
    and is only a rough estimate.
    """
    # imag_sign: +1 / -1 for the sign of a nonzero imaginary part, 0 if
    # z has no (or a zero) imaginary part. It is taken before the float
    # conversion below, which may round a tiny imaginary part to 0.0.
    imag_sign = 0
    if hasattr(z, "imag"):
        x = float(z.real)
        y = z.imag
        if y:
            imag_sign = (-1) ** (y < 0)
        y = float(y)
    else:
        x = float(z)
        y = 0.0
        imag_sign = 0
    # hack to work regardless of whether Python supports -0.0
    if not y:
        y = 0.0
    z = complex(x,y)
    if k == 0:
        # Inside this box around the origin use hard-coded local approximations.
        if -4.0 < y < 4.0 and -1.0 < x < 2.5:
            if imag_sign:
                # Taylor series in upper/lower half-plane
                # (first-order expansions about 0.75+-2.5j and 0.75+-0.5j
                # with hard-coded coefficients)
                if y > 1.00: return (0.876+0.645j) + (0.118-0.174j)*(z-(0.75+2.5j))
                if y > 0.25: return (0.505+0.204j) + (0.375-0.132j)*(z-(0.75+0.5j))
                if y < -1.00: return (0.876-0.645j) + (0.118+0.174j)*(z-(0.75-2.5j))
                if y < -0.25: return (0.505-0.204j) + (0.375+0.132j)*(z-(0.75-0.5j))
            # Taylor series near -1
            if x < -0.5:
                if imag_sign >= 0:
                    return (-0.318+1.34j) + (-0.697-0.593j)*(z+1)
                else:
                    return (-0.318-1.34j) + (-0.697+0.593j)*(z+1)
            # return real type
            # r is -1/e, the branch point; right of it on the real axis
            # the estimate is computed with a float so it stays real.
            r = -0.367879441171442
            if (not imag_sign) and x > r:
                z = x
            # Singularity near -1/e
            # -1 + p - p**2/3 with p = sqrt(2e)*sqrt(z - r):
            # 2.33164398159712 ~ sqrt(2e) and 1.81218788563936 ~ 2e/3.
            if x < -0.2:
                return -1 + 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
            # Taylor series near 0
            if x < 0.5: return z
            # Simple linear approximation
            return 0.2 + 0.3*z
        # Outside the box: fall through to the asymptotic formula at the
        # end, using the real logarithm for positive real z.
        if (not imag_sign) and x > 0.0:
            L1 = math.log(x); L2 = math.log(L1)
        else:
            L1 = cmath.log(z); L2 = cmath.log(L1)
    elif k == -1:
        # return real type
        r = -0.367879441171442
        if (not imag_sign) and r < x < 0.0:
            z = x
        # Near the branch point (approached from the real axis or from
        # above): -1 - p - p**2/3, the other sign of the same series.
        if (imag_sign >= 0) and y < 0.1 and -0.6 < x < -0.2:
            return -1 - 2.33164398159712*(z-r)**0.5 - 1.81218788563936*(z-r)
        # Real z in [-0.2, 0): W_{-1}(z) ~ log(-z) - log(-log(-z)).
        if (not imag_sign) and -0.2 <= x < 0.0:
            L1 = math.log(-x)
            return L1 - math.log(-L1)
        else:
            # L1 approximates log(z) shifted onto branch -1.
            # NOTE(review): for z below the negative real axis whose imaginary
            # part rounded to 0.0, cmath.log returns imag +pi and a shift of
            # pi (not 2*pi) is used; the rationale is not visible here — confirm.
            if imag_sign == -1 and (not y) and x < 0.0:
                L1 = cmath.log(z) - 3.1415926535897932j
            else:
                L1 = cmath.log(z) - 6.2831853071795865j
            L2 = cmath.log(L1)
    # Asymptotic expansion W ~ L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2).
    return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2)
408
def _lambertw_series(ctx, z, k, tol):
    """
    Return rough approximation for W_k(z) from an asymptotic series,
    sufficiently accurate for the Halley iteration to converge to
    the correct value.

    Returns a pair ``(w, done)``. ``done`` is True only when the
    branch-point series below was summed until a term of magnitude
    below ``2**-tol`` was reached, so that no refinement is needed.

    Note: when the branch-point series is used but not summed to
    completion, ``ctx.prec`` is deliberately left raised by
    ``cancellation//2`` bits for the caller's refinement step. The caller
    must restore the precision afterwards.
    """
    magz = ctx.mag(z)
    if (-10 < magz < 900) and (-1000 < k < 1000):
        # Near the branch point at -1/e
        if magz < 1 and abs(z+0.36787944117144) < 0.05:
            if k == 0 or (k == -1 and ctx._im(z) >= 0) or \
                         (k == 1  and ctx._im(z) < 0):
                # Bits lost to cancellation when forming z + 1/e.
                delta = ctx.sum_accurately(lambda: [z, ctx.exp(-1)])
                cancellation = -ctx.mag(delta)
                # Use series given in Corless et al.
                ctx.prec += cancellation
                try:
                    p = ctx.sqrt(2*(ctx.e*z+1))
                finally:
                    ctx.prec -= cancellation
                u = {0:ctx.mpf(-1), 1:ctx.mpf(1)}
                a = {0:ctx.mpf(2), 1:ctx.mpf(-1)}
                # Branch 0 uses the + root for p, the other branches the - root.
                if k != 0:
                    p = -p
                s = ctx.zero
                # The series converges, so we could use it directly, but unless
                # *extremely* close, it is better to just use the first few
                # terms to get a good approximation for the iteration
                for l in xrange(max(2,cancellation)):
                    if l not in u:
                        # Coefficient recurrence for the branch-point series.
                        a[l] = ctx.fsum(u[j]*u[l+1-j] for j in xrange(2,l))
                        u[l] = (l-1)*(u[l-2]/2+a[l-2]/4)/(l+1)-a[l]/2-u[l-1]/(l+1)
                    term = u[l] * p**l
                    s += term
                    if ctx.mag(term) < -tol:
                        return s, True
                # Not converged: leave extra precision for the caller's iteration.
                ctx.prec += cancellation//2
                return s, False
        # Moderate |z| on branches 0 and -1: a hardware-float estimate suffices.
        if k == 0 or k == -1:
            return _lambertw_approx_hybrid(z, k), False
    if k == 0:
        # Small |z|: W_0(z) = z - z**2 + O(z**3)
        if magz < -1:
            return z*(1-z), False
        L1 = ctx.ln(z)
        L2 = ctx.ln(L1)
    elif k == -1 and (not ctx._im(z)) and (-0.36787944117144 < ctx._re(z) < 0):
        # Real z in (-1/e, 0): use real logarithms.
        L1 = ctx.ln(-z)
        return L1 - ctx.ln(-L1), False
    else:
        # This holds both as z -> 0 and z -> inf.
        # Relative error is O(1/log(z)).
        L1 = ctx.ln(z) + 2j*ctx.pi*k
        L2 = ctx.ln(L1)
    return L1 - L2 + L2/L1 + L2*(L2-2)/(2*L1**2), False
462
@defun
def lambertw(ctx, z, k=0):
    """
    Evaluate branch ``k`` of the Lambert W function at ``z``, i.e. a
    solution ``w`` of ``w*exp(w) == z``.

    Inputs for which ``ctx.isnormal`` is false are delegated to
    ``_lambertw_special``. Otherwise a starting value from
    ``_lambertw_series`` is refined by Halley iteration, with the
    precision raised by ``20 + mag(k)`` bits. If the iteration does not
    converge within 100 steps, a warning is issued via ``ctx.warn`` and the
    last iterate is returned.

    The caller's precision is restored on exit, including when an
    exception is raised.
    """
    z = ctx.convert(z)
    k = int(k)
    if not ctx.isnormal(z):
        return _lambertw_special(ctx, z, k)
    prec = ctx.prec
    try:
        ctx.prec += 20 + ctx.mag(k or 1)
        wp = ctx.prec
        tol = wp - 5
        # May additionally raise ctx.prec; undone by the finally clause.
        w, done = _lambertw_series(ctx, z, k, tol)
        if not done:
            # Use Halley iteration to solve w*exp(w) = z
            two = ctx.mpf(2)
            for _ in xrange(100):
                ew = ctx.exp(w)
                wew = w*ew
                wewz = wew-z
                wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
                converged = ctx.mag(wn-w) <= ctx.mag(wn) - tol
                w = wn
                if converged:
                    break
            else:
                # Reached only when the loop never hit `break`. (The previous
                # `i == 100` test could never be true for xrange(100).)
                ctx.warn("Lambert W iteration failed to converge for z = %s" % z)
    finally:
        ctx.prec = prec
    return +w
491
@defun_wrapped
def bell(ctx, n, x=1):
    """
    Evaluate e**(-x) * sum_{k>=0} k**n * x**k / k! (``x`` defaults to 1).

    The k = 0 term is represented by ``sincpi(n)``. Small ``n``,
    ``x == 0`` and non-finite arguments are special-cased.
    """
    x = ctx.convert(x)
    if not n:
        return x if ctx.isnan(x) else type(x)(1)
    nonfinite = ctx.isinf(x) or ctx.isinf(n) or ctx.isnan(x) or ctx.isnan(n)
    if nonfinite:
        return x**n
    if n == 1:
        return x
    if n == 2:
        return x*(x+1)
    if x == 0:
        return ctx.sincpi(n)
    series = _polyexp(ctx, n, x, True)
    return series / ctx.exp(x)
505
def _polyexp(ctx, n, x, extra=False):
    """
    Sum the series sum_{k>=1} k**n * x**k / k! with ``ctx.sum_accurately``.

    When ``extra`` is true, ``ctx.sincpi(n)`` is prepended as the k = 0 term.
    """
    def series():
        if extra:
            yield ctx.sincpi(n)
        index = 1
        power_over_fact = x          # x**index / index!
        while True:
            yield index**n * power_over_fact
            index += 1
            power_over_fact = power_over_fact*x/index
    return ctx.sum_accurately(series, check_step=4)
517
@defun_wrapped
def polyexp(ctx, s, z):
    """
    Polyexponential function sum_{k>=1} k**s * z**k / k!.

    Closed forms are used for s = 0, 1, 2. Non-finite arguments give z**s.
    """
    nonfinite = (ctx.isinf(z) or ctx.isinf(s) or
                 ctx.isnan(z) or ctx.isnan(s))
    if nonfinite:
        return z**s
    if z == 0:
        return z*s
    if s == 0:
        return ctx.expm1(z)
    if s == 1:
        return ctx.exp(z)*z
    if s == 2:
        return ctx.exp(z)*z*(z+1)
    return _polyexp(ctx, s, z)
527
@defun_wrapped
def cyclotomic(ctx, n, z):
    """
    Evaluate the ``n``-th cyclotomic polynomial at ``z`` via the divisor
    product prod_{d | n} (1 - z**d)**moebius(n/d).

    Raises ValueError for negative ``n``.
    """
    n = int(n)
    if n < 0:
        raise ValueError("n cannot be negative")
    one = ctx.one
    if n == 0:
        return one
    if n == 1:
        return z - one
    if n == 2:
        return z + one
    # The divisor product sometimes includes singularities at roots of
    # unity, which have to be cancelled. Matching zeros/poles pairwise,
    # (1-z^a)/(1-z^b) ~ a/b + O(z-1).
    result = one
    zero_degrees = 1
    pole_degrees = 1
    zeros = 0
    poles = 0
    for d in range(1, n+1):
        if n % d:
            continue
        mu = ctx.moebius(n//d)
        # powm1 returns 0 only if the factor really is exactly 0.
        factor = -ctx.powm1(z, d)
        if factor:
            result *= factor**mu
        elif mu == 1:
            zero_degrees *= d
            zeros += 1
        elif mu == -1:
            pole_degrees *= d
            poles += 1
    if zeros:
        if zeros > poles:
            result *= 0
        else:
            result *= zero_degrees
            result /= pole_degrees
    return result
570
def _integer_nth_root(n, k):
    """
    Return ``floor(n ** (1/k))`` computed exactly with integer arithmetic.

    ``n`` must be a non-negative integer and ``k`` a positive integer.
    """
    if n < 2:
        return n
    # 2**ceil(bits/k) is strictly larger than the true root (n < 2**bits);
    # integer Newton steps from above decrease monotonically and stop
    # exactly at the floor of the root.
    x = 1 << ((n.bit_length() + k - 1) // k)
    while True:
        y = ((k - 1) * x + n // x ** (k - 1)) // k
        if y >= x:
            return x
        x = y

@defun
def mangoldt(ctx, n):
    r"""
    Evaluates the von Mangoldt function `\Lambda(n) = \log p`
    if `n = p^k` a power of a prime, and `\Lambda(n) = 0` otherwise.

    **Examples**

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> [mangoldt(n) for n in range(-2,3)]
        [0.0, 0.0, 0.0, 0.0, 0.6931471805599453094172321]
        >>> mangoldt(6)
        0.0
        >>> mangoldt(7)
        1.945910149055313305105353
        >>> mangoldt(8)
        0.6931471805599453094172321
        >>> fsum(mangoldt(n) for n in range(101))
        94.04531122935739224600493
        >>> fsum(mangoldt(n) for n in range(10001))
        10013.39669326311478372032

    """
    n = int(n)
    if n < 2:
        return ctx.zero
    if n % 2 == 0:
        # Must be a power of two
        if n & (n-1) == 0:
            return +ctx.ln2
        else:
            return ctx.zero
    # Look for a small factor; if one divides n, n must be a pure power of it.
    for p in (3,5,7,11,13,17,19,23,29,31):
        if not n % p:
            q, r = n // p, 0
            while q > 1:
                q, r = divmod(q, p)
                if r:
                    return ctx.zero
            return ctx.ln(p)
    if ctx.isprime(n):
        return ctx.ln(n)
    # n is now odd, composite and free of prime factors <= 31, so it can
    # only be a prime power p**k with k >= 2. Exact integer roots avoid
    # float rounding and work for arbitrarily large n.
    k = 2
    while 1:
        p = _integer_nth_root(n, k)
        if p < 2:
            return ctx.zero
        if p ** k == n:
            if ctx.isprime(p):
                return ctx.ln(p)
        k += 1
630
@defun
def stirling1(ctx, n, k, exact=False):
    """Stirling number of the first kind, as an int if ``exact`` else as an mpf."""
    value = ctx._stirling1(int(n), int(k))
    return int(value) if exact else ctx.mpf(value)

@defun
def stirling2(ctx, n, k, exact=False):
    """Stirling number of the second kind, as an int if ``exact`` else as an mpf."""
    value = ctx._stirling2(int(n), int(k))
    return int(value) if exact else ctx.mpf(value)
646