1from .functions import defun, defun_wrapped
2
def _hermite_param(ctx, n, z, parabolic_cylinder):
    """
    Combined calculation of the Hermite polynomial H_n(z) (and its
    generalization to complex n) and the parabolic cylinder
    function D.

    Returns a tuple of one or two terms in the format consumed by
    ``ctx.hypercomb``; each term is a 7-tuple
    ``(bases, exponents, gamma_numer, gamma_denom, hyp_a, hyp_b, hyp_z)``.
    The terms sum to H_n(z) when ``parabolic_cylinder`` is false. When it
    is true, they sum to D_n(z) = 2^(-n/2) exp(-z^2/4) H_n(z/sqrt(2)).
    All lists are freshly built on every call, so callers (e.g. pcfv)
    may append further factors to them.
    """
    # ntyp (the type tag from _convert_param) is not used here.
    n, ntyp = ctx._convert_param(n)
    z = ctx.convert(z)
    # q = -1/2 (exact rational): q*n = -n/2 and q*(n-1) = (1-n)/2.
    q = -ctx.mpq_1_2
    # For re(z) > 0, 2F0 -- http://functions.wolfram.com/
    #     HypergeometricFunctions/HermiteHGeneral/06/02/0009/
    # Otherwise, there is a reflection formula
    # 2F0 + http://functions.wolfram.com/HypergeometricFunctions/
    #           HermiteHGeneral/16/01/01/0006/
    #
    # TODO:
    # An alternative would be to use
    # http://functions.wolfram.com/HypergeometricFunctions/
    #     HermiteHGeneral/06/02/0006/
    #
    # Also, the 1F1 expansion
    # http://functions.wolfram.com/HypergeometricFunctions/
    #     HermiteHGeneral/26/01/02/0001/
    # should probably be used for tiny z
    if not z:
        # z = 0: H_n(0) = 2^n * sqrt(pi) / Gamma((1-n)/2).
        # For D_n the power of 2 is lowered by n/2 (to n/2). The factor
        # exp(-z^2/4) equals 1 here, so nothing else is needed.
        T1 = [2, ctx.pi], [n, 0.5], [], [q*(n-1)], [], [], 0
        if parabolic_cylinder:
            T1[1][0] += q*n
        return T1,
    # The single 2F0 term is used when:
    #   - n is a nonnegative integer (one of -n/2, (1-n)/2 is then a
    #     nonpositive integer, so the series terminates),
    #   - re(z) > 0, or
    #   - z lies on the positive imaginary axis.
    # Otherwise the reflection formula below adds a 1F1 term.
    can_use_2f0 = ctx.isnpint(-n) or ctx.re(z) > 0 or \
        (ctx.re(z) == 0 and ctx.im(z) > 0)
    # NOTE(review): generous working precision (4x + 20 bits) for forming
    # the series arguments; the rationale is not documented here.
    # Presumably it guards the accuracy of w^2 and 1/w^2.
    expprec = ctx.prec*4 + 20
    if parabolic_cylinder:
        # u = -z^2/4 (exponent of the D_n prefactor), w = z/sqrt(2)
        u = ctx.fmul(ctx.fmul(z,z,prec=expprec), -0.25, exact=True)
        w = ctx.fmul(z, ctx.sqrt(0.5,prec=expprec), prec=expprec)
    else:
        w = z
    # w2 = w^2, rw2 = 1/w^2, nrw2 = -1/w^2, nw = -w
    w2 = ctx.fmul(w, w, prec=expprec)
    rw2 = ctx.fdiv(1, w2, prec=expprec)
    nrw2 = ctx.fneg(rw2, exact=True)
    nw = ctx.fneg(w, exact=True)
    if can_use_2f0:
        # (2w)^n * 2F0(-n/2, (1-n)/2; ; -1/w^2)
        T1 = [2, w], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
        terms = [T1]
    else:
        # Reflection formula:
        #   T1 = 2^n (-w)^n * 2F0(-n/2, (1-n)/2; ; -1/w^2)
        #   T2 = 2^(n+2) sqrt(pi) (-w) / Gamma(-n/2)
        #        * 1F1((1-n)/2; 3/2; w^2)
        T1 = [2, nw], [n, n], [], [], [q*n, q*(n-1)], [], nrw2
        T2 = [2, ctx.pi, nw], [n+2, 0.5, 1], [], [q*n], [q*(n-1)], [1-q], w2
        terms = [T1,T2]
    # Multiply by prefactor for D_n
    # (lower the power of base 2 by n/2 and append exp(-z^2/4)^1)
    if parabolic_cylinder:
        expu = ctx.exp(u)
        for i in range(len(terms)):
            terms[i][1][0] += q*n
            terms[i][0].append(expu)
            terms[i][1].append(1)
    return tuple(terms)
59
@defun
def hermite(ctx, n, z, **kwargs):
    # Hermite function H_n(z), valid also for non-integer or complex n.
    # The series terms come from _hermite_param with the parabolic
    # cylinder prefactor switched off.
    def terms():
        return _hermite_param(ctx, n, z, 0)
    return ctx.hypercomb(terms, [], **kwargs)
63
@defun
def pcfd(ctx, n, z, **kwargs):
    r"""
    Gives the parabolic cylinder function in Whittaker's notation
    `D_n(z) = U(-n-1/2, z)` (see :func:`~mpmath.pcfu`).
    It solves the differential equation

    .. math ::

        y'' + \left(n + \frac{1}{2} - \frac{1}{4} z^2\right) y = 0,

    and can be represented in terms of Hermite polynomials
    (see :func:`~mpmath.hermite`) as

    .. math ::

        D_n(z) = 2^{-n/2} e^{-z^2/4} H_n\left(\frac{z}{\sqrt{2}}\right).

    **Plots**

    .. literalinclude :: /plots/pcfd.py
    .. image :: /plots/pcfd.png

    **Examples**

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> pcfd(0,0); pcfd(1,0); pcfd(2,0); pcfd(3,0)
        1.0
        0.0
        -1.0
        0.0
        >>> pcfd(4,0); pcfd(-3,0)
        3.0
        0.6266570686577501256039413
        >>> pcfd('1/2', 2+3j)
        (-5.363331161232920734849056 - 3.858877821790010714163487j)
        >>> pcfd(2, -10)
        1.374906442631438038871515e-9

    Verifying the differential equation::

        >>> n = mpf(2.5)
        >>> y = lambda z: pcfd(n,z)
        >>> z = 1.75
        >>> chop(diff(y,z,2) + (n+0.5-0.25*z**2)*y(z))
        0.0

    Rational Taylor series expansion when `n` is an integer::

        >>> taylor(lambda z: pcfd(5,z), 0, 7)
        [0.0, 15.0, 0.0, -13.75, 0.0, 3.96875, 0.0, -0.6015625]

    """
    # Same series as hermite(), with the 2^(-n/2) exp(-z^2/4) prefactor
    # folded in by _hermite_param.
    def terms():
        return _hermite_param(ctx, n, z, 1)
    return ctx.hypercomb(terms, [], **kwargs)
119
@defun
def pcfu(ctx, a, z, **kwargs):
    r"""
    Gives the parabolic cylinder function `U(a,z)`, which may be
    defined for `\Re(z) > 0` in terms of the confluent
    U-function (see :func:`~mpmath.hyperu`) by

    .. math ::

        U(a,z) = 2^{-\frac{1}{4}-\frac{a}{2}} e^{-\frac{1}{4} z^2}
            U\left(\frac{a}{2}+\frac{1}{4},
            \frac{1}{2}, \frac{1}{2}z^2\right)

    or, for arbitrary `z`,

    .. math ::

        e^{-\frac{1}{4}z^2} U(a,z) =
            U(a,0) \,_1F_1\left(-\tfrac{a}{2}+\tfrac{1}{4};
            \tfrac{1}{2}; -\tfrac{1}{2}z^2\right) +
            U'(a,0) z \,_1F_1\left(-\tfrac{a}{2}+\tfrac{3}{4};
            \tfrac{3}{2}; -\tfrac{1}{2}z^2\right).

    It is evaluated as `U(a,z) = D_{-a-1/2}(z)` (see :func:`~mpmath.pcfd`).

    **Examples**

    Connection to other functions::

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> z = mpf(3)
        >>> pcfu(0.5,z)
        0.03210358129311151450551963
        >>> sqrt(pi/2)*exp(z**2/4)*erfc(z/sqrt(2))
        0.03210358129311151450551963
        >>> pcfu(0.5,-z)
        23.75012332835297233711255
        >>> sqrt(pi/2)*exp(z**2/4)*erfc(-z/sqrt(2))
        23.75012332835297233711255

    """
    n, _ = ctx._convert_param(a)
    # Whittaker's D-function with order -a-1/2; forward the evaluation
    # options (previously accepted but silently ignored).
    return ctx.pcfd(-n-ctx.mpq_1_2, z, **kwargs)
166
@defun
def pcfv(ctx, a, z, **kwargs):
    r"""
    Gives the parabolic cylinder function `V(a,z)`, which can be
    represented in terms of :func:`~mpmath.pcfu` as

    .. math ::

        V(a,z) = \frac{\Gamma(a+\tfrac{1}{2}) (U(a,-z)+\sin(\pi a) U(a,z))}{\pi}.

    **Examples**

    Wronskian relation between `U` and `V`::

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> a, z = 2, 3
        >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
        0.7978845608028653558798921
        >>> sqrt(2/pi)
        0.7978845608028653558798921
        >>> a, z = 2.5, 3
        >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
        0.7978845608028653558798921
        >>> a, z = 0.25, -1
        >>> pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z)
        0.7978845608028653558798921
        >>> a, z = 2+1j, 2+3j
        >>> chop(pcfu(a,z)*diff(pcfv,(a,z),(0,1))-diff(pcfu,(a,z),(0,1))*pcfv(a,z))
        0.7978845608028653558798921

    """
    n, ntype = ctx._convert_param(a)
    z = ctx.convert(z)
    q = ctx.mpq_1_2
    r = ctx.mpq_1_4
    if ntype == 'Q' and ctx.isint(n*2):
        # Faster for half-integers: combine the Hermite/D-function series
        # for D_{-a-1/2}(z) and D_{a-1/2}(-i z) with extra constant factors.
        def h():
            jz = ctx.fmul(z, -1j, exact=True)
            T1terms = _hermite_param(ctx, -n-q, z, 1)
            T2terms = _hermite_param(ctx, n-q, jz, 1)
            for T in T1terms:
                T[0].append(1j)
                T[1].append(1)
                T[3].append(q-n)
            u = ctx.expjpi((q*n-r)) * ctx.sqrt(2/ctx.pi)
            for T in T2terms:
                T[0].append(u)
                T[1].append(1)
            return T1terms + T2terms
        v = ctx.hypercomb(h, [], **kwargs)
        if ctx._is_real_type(n) and ctx._is_real_type(z):
            v = ctx._re(v)
        return v
    else:
        def h(n):
            # w = -z^2/4 (exponent of the common prefactor), u = z^2/2
            w = ctx.square_exp_arg(z, -0.25)
            u = ctx.square_exp_arg(z, 0.5)
            # Compute the exponential once and share it between both terms.
            e = ctx.exp(w)
            l = [ctx.pi, q, e]
            Y1 = l, [-q, n*q+r, 1], [r-q*n], [], [q*n+r], [q], u
            Y2 = l + [z], [-q, n*q-r, 1, 1], [1-r-q*n], [], [q*n+1-r], [1+q], u
            c, s = ctx.cospi_sinpi(r+q*n)
            Y1[0].append(s)
            Y2[0].append(c)
            for Y in (Y1, Y2):
                Y[1].append(1)
                Y[3].append(q-n)
            return Y1, Y2
        return ctx.hypercomb(h, [n], **kwargs)
238
239
@defun
def pcfw(ctx, a, z, **kwargs):
    r"""
    Gives the parabolic cylinder function `W(a,z)` defined in (DLMF 12.14).

    **Examples**

    Value at the origin::

        >>> from mpmath import *
        >>> mp.dps = 25; mp.pretty = True
        >>> a = mpf(0.25)
        >>> pcfw(a,0)
        0.9722833245718180765617104
        >>> power(2,-0.75)*sqrt(abs(gamma(0.25+0.5j*a)/gamma(0.75+0.5j*a)))
        0.9722833245718180765617104
        >>> diff(pcfw,(a,0),(0,1))
        -0.5142533944210078966003624
        >>> -power(2,-0.25)*sqrt(abs(gamma(0.75+0.5j*a)/gamma(0.25+0.5j*a)))
        -0.5142533944210078966003624

    """
    n, _ = ctx._convert_param(a)
    z = ctx.convert(z)
    def terms():
        # Everything is computed inside the generator because
        # sum_accurately presumably re-invokes it at higher precision.
        #
        # phi2 = arg Gamma(1/2 + i*a), evaluated through loggamma.
        # Presumably this gives a phase continuous in a rather than the
        # principal value of arg().
        phi2 = (ctx.loggamma(0.5+ctx.j*n) - ctx.loggamma(0.5-ctx.j*n))/2j
        rho = ctx.pi/8 + 0.5*phi2
        # k = sqrt(1 + e^(2*pi*a)) - e^(pi*a). Computing this difference
        # directly cancels catastrophically for large real a (both terms
        # are ~e^(pi*a) while k ~ e^(-pi*a)/2).
        # With s = sqrt(1 + t^2), t = e^(pi*a), we have (s + t)(s - t) = 1.
        # So take the direct difference only when it is the larger of the
        # two factors (no cancellation); otherwise take the reciprocal of
        # the sum.
        t = ctx.exp(ctx.pi*n)
        s = ctx.sqrt(1 + t*t)
        plus = s + t
        minus = s - t
        if abs(plus) >= abs(minus):
            k = 1/plus
        else:
            k = minus
        C = ctx.sqrt(k/2) * ctx.exp(0.25*ctx.pi*n)
        yield C * ctx.expj(rho) * ctx.pcfu(ctx.j*n, z*ctx.expjpi(-0.25), **kwargs)
        yield C * ctx.expj(-rho) * ctx.pcfu(-ctx.j*n, z*ctx.expjpi(0.25), **kwargs)
    v = ctx.sum_accurately(terms)
    if ctx._is_real_type(n) and ctx._is_real_type(z):
        v = ctx._re(v)
    return v
277
278"""
279Even/odd PCFs. Useful?
280
281@defun
282def pcfy1(ctx, a, z, **kwargs):
283    a, _ = ctx._convert_param(n)
284    z = ctx.convert(z)
285    def h():
286        w = ctx.square_exp_arg(z)
287        w1 = ctx.fmul(w, -0.25, exact=True)
288        w2 = ctx.fmul(w, 0.5, exact=True)
289        e = ctx.exp(w1)
290        return [e], [1], [], [], [ctx.mpq_1_2*a+ctx.mpq_1_4], [ctx.mpq_1_2], w2
291    return ctx.hypercomb(h, [], **kwargs)
292
293@defun
294def pcfy2(ctx, a, z, **kwargs):
295    a, _ = ctx._convert_param(n)
296    z = ctx.convert(z)
297    def h():
298        w = ctx.square_exp_arg(z)
299        w1 = ctx.fmul(w, -0.25, exact=True)
300        w2 = ctx.fmul(w, 0.5, exact=True)
301        e = ctx.exp(w1)
302        return [e, z], [1, 1], [], [], [ctx.mpq_1_2*a+ctx.mpq_3_4], \
303            [ctx.mpq_3_2], w2
304    return ctx.hypercomb(h, [], **kwargs)
305"""
306
@defun_wrapped
def gegenbauer(ctx, n, a, z, **kwargs):
    # Gegenbauer (ultraspherical) function, as the single hypercomb term
    #   Gamma(n+2a) / (Gamma(n+1) Gamma(2a)) * 2F1(-n, n+2a; a+1/2; (1-z)/2).
    def series(n, a):
        twice_a = 2*a
        term = ([], [], [n+twice_a], [n+1, twice_a], [-n, n+twice_a],
                [a+0.5], 0.5*(1-z))
        return [term]
    # Special cases: a+0.5, a*2 poles
    if ctx.isnpint(a):
        return 0*(z+n)
    if not ctx.isnpint(a+0.5):
        # Generic case: let hypercomb perturb n if a gamma pole is hit.
        return ctx.hypercomb(lambda n: series(n, a), [n], **kwargs)
    # a+1/2 is a nonpositive integer (pole of the 2F1 lower parameter):
    # perturb a instead.
    # TODO: something else is required here
    # E.g.: gegenbauer(-2, -0.5, 3) == -12
    if ctx.isnpint(n+1):
        raise NotImplementedError("Gegenbauer function with two limits")
    return ctx.hypercomb(lambda a: series(n, a), [a], **kwargs)
327
@defun_wrapped
def jacobi(ctx, n, a, b, x, **kwargs):
    # Jacobi function P_n^(a,b)(x) via a 2F1 expansion about x = 1,
    # falling back to an expansion about x = -1 when a is a
    # nonpositive integer.
    if ctx.isnpint(a):
        if ctx.isint(b):
            # XXX: determine appropriate limit
            return ctx.binomial(n+a,n) * ctx.hyp2f1(-n,1+n+a+b,a+1,(1-x)/2, **kwargs)
        def about_minus_one(n, a):
            # Gamma(-b) / (Gamma(n+1) Gamma(-b-n))
            #   * 2F1(-n, a+b+n+1; b+1; (1+x)/2)
            return (([], [], [-b], [n+1, -b-n], [-n, a+b+n+1], [b+1], (x+1)*0.5),)
        return ctx.hypercomb(about_minus_one, [n, a], **kwargs)
    def about_plus_one(n):
        # Gamma(a+n+1) / (Gamma(n+1) Gamma(a+1))
        #   * 2F1(-n, a+b+n+1; a+1; (1-x)/2)
        return (([], [], [a+n+1], [n+1, a+1], [-n, a+b+n+1], [a+1], (1-x)*0.5),)
    return ctx.hypercomb(about_plus_one, [n], **kwargs)
340
@defun_wrapped
def laguerre(ctx, n, a, z, **kwargs):
    # Generalized Laguerre function, as the single hypercomb term
    #   Gamma(a+n+1) / (Gamma(a+1) Gamma(n+1)) * 1F1(-n; a+1; z).
    # XXX: limits, poles (an earlier draft returned 0*(a+z) when n is a
    # negative integer; this is not enabled).
    def term(a):
        gamma_numer = [a+n+1]
        gamma_denom = [a+1, n+1]
        return (([], [], gamma_numer, gamma_denom, [-n], [a+1], z),)
    return ctx.hypercomb(term, [a], **kwargs)
349
@defun_wrapped
def legendre(ctx, n, x, **kwargs):
    # Legendre function P_n(x) = 2F1(-n, n+1; 1; (1-x)/2).
    if ctx.isint(n):
        n = int(n)
        # Accuracy near zeros: for odd n >= 0, or even n < 0 (using
        # P_n = P_{-n-1}), P_n is an odd function with a zero at x = 0.
        if (n + (n < 0)) & 1:
            if not x:
                return x
            mag = ctx.mag(x)
            if mag < -2*ctx.prec-10:
                # |x| is so tiny that the cubic and higher terms (relative
                # size ~ n^2 x^2) are negligible, so P_n(x) = P_n'(0) * x.
                # Reduce to odd degree d >= 1. Then P_d'(0) = d * P_{d-1}(0),
                # from (x^2-1) P_d'(x) = d (x P_d(x) - P_{d-1}(x)) at x = 0.
                # (Returning x itself is only correct for d = 1; e.g.
                # P_3(x) ~ -1.5 x.)
                d = n if n >= 0 else -n-1
                return d * ctx.legendre(d-1, 0) * x
            if mag < -5:
                # Add working precision to cover the cancellation in the
                # series near the zero.
                # NOTE(review): relies on the defun_wrapped wrapper to
                # restore ctx.prec afterwards -- confirm.
                ctx.prec += -mag
    return ctx.hyp2f1(-n,n+1,1,(1-x)/2, **kwargs)
364
@defun
def legenp(ctx, n, m, z, type=2, **kwargs):
    # Legendre function, 1st kind: associated Legendre function P_n^m(z),
    # evaluated as the single hypercomb term
    #   (1+z)^(m/2) (1-z)^(-m/2) / Gamma(1-m) * 2F1(-n, n+1; 1-m; (1-z)/2)
    # for type=2. For type=3 the factors are (z+1)^(m/2) (z-1)^(-m/2).
    # m = 0 is delegated to legendre() for any `type`. Otherwise a
    # ValueError is raised unless type is 2 or 3.
    n = ctx.convert(n)
    m = ctx.convert(m)
    # Convert z up front (as legenq does) so that inputs such as strings
    # work in the arithmetic on z below.
    z = ctx.convert(z)
    # Faster
    if not m:
        return ctx.legendre(n, z, **kwargs)
    if type not in (2, 3):
        raise ValueError("requires type=2 or type=3")
    # TODO: correct evaluation at singularities
    def h(n, m):
        g = m*0.5
        if type == 2:
            bases = [1+z, 1-z]
        else:
            bases = [z+1, z-1]
        T = bases, [g, -g], [], [1-m], [-n, n+1], [1-m], 0.5*(1-z)
        return (T,)
    return ctx.hypercomb(h, [n, m], **kwargs)
387
@defun
def legenq(ctx, n, m, z, type=2, **kwargs):
    # Legendre function, 2nd kind
    #
    # Evaluates Q_n^m(z) through ctx.hypercomb as combinations of
    # 2F1(-n, n+1; 1-m; w) and 2F1(-n, n+1; 1+m; w) with w = (1-z)/2,
    # or (type=3, |z| > 1) a single 2F1 in 1/z^2.
    # `type` selects the branch convention: type=2 uses the factors
    # (1+z), (1-z); type=3 uses (1+z), (z-1). Any other `type` raises
    # ValueError.
    n = ctx.convert(n)
    m = ctx.convert(m)
    z = ctx.convert(z)
    if z in (1, -1):
        # Singular points z = +-1: NaN is returned unconditionally.
        # (Commented-out alternatives distinguishing integer m are kept.)
        #if ctx.isint(m):
        #    return ctx.nan
        #return ctx.inf  # unsigned
        return ctx.nan
    if type == 2:
        def h(n, m):
            # With s = 2 sin(pi m)/pi, c = cos(pi m), u = m/2:
            #   T1 = (c/s) (1+z)^u (1-z)^(-u) / Gamma(1-m)
            #        * 2F1(-n, n+1; 1-m; w)
            #   T2 = -(1/s) (1+z)^(-u) (1-z)^u
            #        * Gamma(n+m+1) / (Gamma(n-m+1) Gamma(m+1))
            #        * 2F1(-n, n+1; m+1; w)
            cos, sin = ctx.cospi_sinpi(m)
            s = 2 * sin / ctx.pi
            c = cos
            a = 1+z
            b = 1-z
            u = m/2
            w = (1-z)/2
            T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
                [-n, n+1], [1-m], w
            T2 = [-s, a, b], [-1, -u, u], [n+m+1], [n-m+1, m+1], \
                [-n, n+1], [m+1], w
            return T1, T2
        return ctx.hypercomb(h, [n, m], **kwargs)
    if type == 3:
        # The following is faster when there only is a single series
        # Note: not valid for -1 < z < 0 (?)
        if abs(z) > 1:
            def h(n, m):
                # e^(i pi m) 2^(-n-1) sqrt(pi) z^(-n-m-1)
                #   (z-1)^(m/2) (z+1)^(m/2) Gamma(n+m+1) / Gamma(n+3/2)
                #   * 2F1((n+m+2)/2, (n+m+1)/2; n+3/2; 1/z^2)
                T1 = [ctx.expjpi(m), 2, ctx.pi, z, z-1, z+1], \
                     [1, -n-1, 0.5, -n-m-1, 0.5*m, 0.5*m], \
                     [n+m+1], [n+1.5], \
                     [0.5*(2+n+m), 0.5*(1+n+m)], [n+1.5], z**(-2)
                return [T1]
            return ctx.hypercomb(h, [n, m], **kwargs)
        else:
            # not valid for 1 < z < inf ?
            def h(n, m):
                # Same structure as type=2, but with c = e^(i pi m), the
                # factor (z-1) in place of (1-z), and c also multiplying T2.
                s = 2 * ctx.sinpi(m) / ctx.pi
                c = ctx.expjpi(m)
                a = 1+z
                b = z-1
                u = m/2
                w = (1-z)/2
                T1 = [s, c, a, b], [-1, 1, u, -u], [], [1-m], \
                    [-n, n+1], [1-m], w
                T2 = [-s, c, a, b], [-1, 1, -u, u], [n+m+1], [n-m+1, m+1], \
                    [-n, n+1], [m+1], w
                return T1, T2
            return ctx.hypercomb(h, [n, m], **kwargs)
    raise ValueError("requires type=2 or type=3")
441
@defun_wrapped
def chebyt(ctx, n, x, **kwargs):
    # Chebyshev function of the first kind,
    #   T_n(x) = 2F1(-n, n; 1/2; (1-x)/2),
    # where the tuple (1, 2) passes the exact rational parameter 1/2.
    if not x:
        # T_n(0) vanishes exactly for odd integer n; return zero directly.
        if ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
            return x * 0
    return ctx.hyp2f1(-n, n, (1, 2), (1-x)/2, **kwargs)
447
@defun_wrapped
def chebyu(ctx, n, x, **kwargs):
    # Chebyshev function of the second kind,
    #   U_n(x) = (n+1) * 2F1(-n, n+2; 3/2; (1-x)/2),
    # where the tuple (3, 2) passes the exact rational parameter 3/2.
    if not x:
        # U_n(0) vanishes exactly for odd integer n; return zero directly.
        if ctx.isint(n) and int(ctx._re(n)) % 2 == 1:
            return x * 0
    series = ctx.hyp2f1(-n, n+2, (3, 2), (1-x)/2, **kwargs)
    return (n+1) * series
453
@defun
def spherharm(ctx, l, m, theta, phi, **kwargs):
    # Spherical harmonic Y_l^m(theta, phi) of degree l and order m,
    # evaluated through ctx.hypercomb. l and m need not be integers
    # (general branch below).
    l = ctx.convert(l)
    m = ctx.convert(m)
    theta = ctx.convert(theta)
    phi = ctx.convert(phi)
    l_isint = ctx.isint(l)
    l_natural = l_isint and l >= 0
    m_isint = ctx.isint(m)
    if l_isint and l < 0 and m_isint:
        # Negative integer degree: evaluate with degree -(l+1) >= 0 instead.
        return ctx.spherharm(-(l+1), m, theta, phi, **kwargs)
    if theta == 0 and m_isint and m < 0:
        # At theta = 0 with negative integer m, return a complex zero.
        # NOTE(review): presumably avoids 0 raised to a negative power in
        # the formulas below -- confirm.
        return ctx.zero * 1j
    if l_natural and m_isint:
        if abs(m) > l:
            # |m| > l: the harmonic vanishes identically.
            return ctx.zero * 1j
        # http://functions.wolfram.com/Polynomials/
        #     SphericalHarmonicY/26/01/02/0004/
        def h(l,m):
            # Term: (-1)^p e^(i m phi)
            #   * sqrt((2l+1) (l+|m|)! / (pi (l-|m|)!))
            #   * (sin^2 theta)^(|m|/2) / |m|! * 2^(-|m|-1)
            #   * 2F1(|m|-l, l+|m|+1; |m|+1; sin^2(theta/2)),
            # where p = 0.5*m*(sign(m)+1) equals m for m > 0 and 0 otherwise.
            absm = abs(m)
            C = [-1, ctx.expj(m*phi),
                 (2*l+1)*ctx.fac(l+absm)/ctx.pi/ctx.fac(l-absm),
                 ctx.sin(theta)**2,
                 ctx.fac(absm), 2]
            P = [0.5*m*(ctx.sign(m)+1), 1, 0.5, 0.5*absm, -1, -absm-1]
            return ((C, P, [], [], [absm-l, l+absm+1], [absm+1],
                ctx.sin(0.5*theta)**2),)
    else:
        # http://functions.wolfram.com/HypergeometricFunctions/
        #     SphericalHarmonicYGeneral/26/01/02/0001/
        def h(l,m):
            if ctx.isnpint(l-m+1) or ctx.isnpint(l+m+1) or ctx.isnpint(1-m):
                # NOTE(review): a zero base raised to power -1. Presumably
                # this deliberately signals a singularity, so that
                # hypercomb perturbs l and m and evaluates the limit.
                # Confirm against hypercomb.
                return (([0], [-1], [], [], [], [], 0),)
            # Term: 0.5 e^(i m phi) sqrt((2l+1)/pi)
            #   * sqrt(Gamma(l-m+1) / Gamma(l+m+1))
            #   * (cos^2(theta/2))^(m/2) (sin^2(theta/2))^(-m/2) / Gamma(1-m)
            #   * 2F1(-l, l+1; 1-m; sin^2(theta/2))
            cos, sin = ctx.cos_sin(0.5*theta)
            C = [0.5*ctx.expj(m*phi), (2*l+1)/ctx.pi,
                 ctx.gamma(l-m+1), ctx.gamma(l+m+1),
                 cos**2, sin**2]
            P = [1, 0.5, 0.5, -0.5, 0.5*m, -0.5*m]
            return ((C, P, [], [1-m], [-l,l+1], [1-m], sin**2),)
    return ctx.hypercomb(h, [l,m], **kwargs)
494