1from ..libmp.backend import xrange
2from .functions import defun, defun_wrapped
3
def _partition_poles(ctx, args):
    """Split *args* into ``(regular, poles)`` according to ``ctx.isnpint``."""
    regular = []
    poles = []
    for x in args:
        # Explicit branch instead of indexing a list with the predicate's
        # result, so any truthy/falsy return value is handled.
        if ctx.isnpint(x):
            poles.append(x)
        else:
            regular.append(x)
    return regular, poles

@defun
def gammaprod(ctx, a, b, _infsign=False):
    """
    Evaluate prod(gamma(x) for x in a) / prod(gamma(x) for x in b).

    Arguments for which ``ctx.isnpint`` is true are treated as poles of
    the gamma function and are handled separately:

    * more poles in the denominator than in the numerator -> ``ctx.zero``
    * more poles in the numerator than in the denominator -> ``ctx.inf``
      (or a signed infinity if ``_infsign`` is true)
    * equally many poles -> the poles are cancelled pairwise using
      lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)

    ``a`` and ``b`` are iterables of numbers; each element is passed
    through ``ctx.convert``.
    """
    a = [ctx.convert(x) for x in a]
    b = [ctx.convert(x) for x in b]
    regular_num, poles_num = _partition_poles(ctx, a)
    regular_den, poles_den = _partition_poles(ctx, b)
    # One more pole in numerator or denominator gives 0 or inf
    if len(poles_num) < len(poles_den):
        return ctx.zero
    if len(poles_num) > len(poles_den):
        if not _infsign:
            return ctx.inf
        # Get correct sign of infinity for x+h, h -> 0 from above
        # XXX: hack, this should be done properly
        # Each pole is nudged off the singularity: by a relative eps when
        # it is nonzero, by an absolute eps when it is zero. The sign of the
        # resulting finite quotient is then attached to infinity.
        # NOTE(review): for a negative pole x, x*(1+eps) lies further from
        # zero than x -- confirm this is the intended side of the pole.
        eps = ctx.eps
        shifted_num = [x*(1+eps) if x else x+eps for x in poles_num]
        shifted_den = [x*(1+eps) if x else x+eps for x in poles_den]
        finite = ctx.gammaprod(shifted_num+regular_num, shifted_den+regular_den)
        return ctx.sign(finite) * ctx.inf
    # All poles cancel
    # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
    p = ctx.one
    orig = ctx.prec
    try:
        # Work with 15 guard bits; the finally clause restores the caller's
        # precision even if a gamma evaluation raises.
        ctx.prec = orig + 15
        while poles_num:
            i = poles_num.pop()
            j = poles_den.pop()
            p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
        for x in regular_num:
            p *= ctx.gamma(x)
        for x in regular_den:
            p /= ctx.gamma(x)
    finally:
        ctx.prec = orig
    # Unary plus re-evaluates p after the precision has been restored.
    return +p
40
@defun
def beta(ctx, x, y):
    """
    Evaluate gamma(x)*gamma(y)/gamma(x+y) via ``ctx.gammaprod``.

    If either argument is infinite the result is decided directly:
    with one argument equal to ``+inf`` and the other (call it ``y``)
    having zero imaginary part, the result is ``nan`` for ``y == -inf``,
    zero for ``y > 0``, ``nan`` for an integer ``y``, and a signed infinity
    (with the sign of ``gamma(y)``) for other ``y < 0``. Every remaining
    infinite case yields ``nan``.
    """
    x = ctx.convert(x)
    y = ctx.convert(y)
    # Move an infinite argument into x so only x needs to be inspected.
    if ctx.isinf(y):
        x, y = y, x

    if not ctx.isinf(x):
        # Finite case. The sum is formed at twice the working precision
        # (presumably so that poles of gamma(x+y) are detected reliably).
        total = ctx.fadd(x, y, prec=2*ctx.prec)
        return ctx.gammaprod([x, y], [total])

    limit_is_decidable = x == ctx.inf and not ctx._im(y)
    if not limit_is_decidable:
        return ctx.nan

    # The order of the following tests matters: positive integers must be
    # caught by the y > 0 test before the integer test.
    if y == ctx.ninf:
        result = ctx.nan
    elif y > 0:
        result = ctx.zero
    elif ctx.isint(y):
        result = ctx.nan
    elif y < 0:
        result = ctx.sign(ctx.gamma(y)) * ctx.inf
    else:
        result = ctx.nan
    return result
60
@defun
def binomial(ctx, n, k):
    """
    Evaluate gamma(n+1) / (gamma(k+1) * gamma(n-k+1)) via ``ctx.gammaprod``.

    The three gamma arguments are formed at twice the working precision
    before being handed to ``gammaprod``.
    """
    wide = 2*ctx.prec
    n_plus_1 = ctx.fadd(n, 1, prec=wide)
    k_plus_1 = ctx.fadd(k, 1, prec=wide)
    # (n + 1) - k, reusing the already widened n + 1
    n_minus_k_plus_1 = ctx.fsub(n_plus_1, k, prec=wide)
    numerator = [n_plus_1]
    denominator = [k_plus_1, n_minus_k_plus_1]
    return ctx.gammaprod(numerator, denominator)
67
@defun
def rf(ctx, x, n):
    """
    Rising factorial, evaluated as gamma(x+n) / gamma(x) via
    ``ctx.gammaprod``.

    The sum x + n is formed at twice the working precision.
    """
    shifted = ctx.fadd(x, n, prec=2*ctx.prec)
    numerator = [shifted]
    denominator = [x]
    return ctx.gammaprod(numerator, denominator)
72
@defun
def ff(ctx, x, n):
    """
    Falling factorial, evaluated as gamma(x+1) / gamma(x-n+1) via
    ``ctx.gammaprod``.

    Both gamma arguments are formed at twice the working precision.
    """
    wide = 2*ctx.prec
    top = ctx.fadd(x, 1, prec=wide)
    # (x - n) + 1, each step carried out at the widened precision
    difference = ctx.fsub(x, n, prec=wide)
    bottom = ctx.fadd(difference, 1, prec=wide)
    return ctx.gammaprod([top], [bottom])
78
@defun_wrapped
def fac2(ctx, x):
    """
    Double factorial, evaluated with the closed form

        2**(x/2) * (pi/2)**((cos(pi*x)-1)/4) * gamma(x/2+1)

    ``+inf`` is returned unchanged; any other infinite input gives ``nan``.
    """
    if not ctx.isinf(x):
        half = x/2
        power_of_two = 2**half
        correction = (ctx.pi/2)**((ctx.cospi(x)-1)/4)
        return power_of_two*correction*ctx.gamma(half+1)
    # Infinite input: only +inf has a meaningful value here.
    if x == ctx.inf:
        return x
    return ctx.nan
86
@defun_wrapped
def barnesg(ctx, z):
    """
    Evaluate the Barnes G-function at z.

    Special values: ``+inf`` is returned unchanged, other infinities give
    ``nan``, ``nan`` is returned unchanged, and real nonpositive integers
    give zero (as ``z*0``).

    For Re(z) far to the left (below ``-ctx.dps``) a reflection formula
    relating G(z) to G(2-z) is used. Otherwise z is shifted to the right
    with the recurrence G(z) = G(z+1)/gamma(z) until an asymptotic series
    for log G can be summed.
    """
    # --- special values ---
    if ctx.isinf(z):
        return z if z == ctx.inf else ctx.nan
    if ctx.isnan(z):
        return z
    z_is_real = not ctx._im(z)
    if z_is_real and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
        return z*0

    # Account for size (would not be needed if computing log(G))
    magnitude = abs(z)
    if magnitude > 5:
        ctx.dps += 2*ctx.log(magnitude, 2)

    # --- reflection formula ---
    if ctx.re(z) < -ctx.dps:
        w = 1-z
        two_pi = 2*ctx.pi
        u = ctx.expjpi(2*w)
        exponent = (ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) -
                    ctx.j*ctx.polylog(2, u)/two_pi)
        reflected = ctx.barnesg(2-z)*ctx.exp(exponent)/two_pi**w
        # Discard the imaginary part when the input was of real type.
        if ctx._is_real_type(z):
            reflected = ctx._re(reflected)
        return reflected

    # --- recurrence + asymptotic expansion ---
    # Estimate terms for asymptotic expansion
    # TODO: the term count below is a crude heuristic and should be fixed
    nterms = ctx.dps // 2 + 5
    # Accumulates the product of 1/gamma(z+k) picked up while shifting z.
    prefactor = 1
    while abs(z) < nterms or ctx.re(z) < 1:
        prefactor /= ctx.gamma(z)
        z += 1
    # The series below is written in terms of z - 1.
    z -= 1

    log_g = ctx.mpf(1)/12
    log_g -= ctx.log(ctx.glaisher)
    log_g += z*ctx.log(2*ctx.pi)/2
    log_g += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
    log_g -= 3*z**2/4

    z_squared = z**2
    z_power = z_squared
    for k in xrange(1, nterms+1):
        term = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z_power)
        # Stop once a term drops below the working epsilon. If the loop
        # runs to completion instead, the series did not reach that
        # tolerance within nterms terms (this is not reported).
        if abs(term) < ctx.eps:
            break
        log_g += term
        z_power *= z_squared
    return prefactor*ctx.exp(log_g)
135
@defun
def superfac(ctx, z):
    """Superfactorial of z, evaluated as ``ctx.barnesg(z + 2)``."""
    shifted = z+2
    return ctx.barnesg(shifted)
139
@defun_wrapped
def hyperfac(ctx, z):
    """
    Hyperfactorial of z, evaluated as
    exp(z*loggamma(z+1)) / barnesg(z+1).

    ``+inf`` is returned unchanged. For a real negative integer z = n the
    value is obtained from ``hyperfac(-n-1)`` with a sign that depends on
    the parity of ``(n+1)//2``; if the input is of complex type the result
    is returned as a complex value.
    """
    if z == ctx.inf:
        return z

    # XXX: estimate needed extra bits accurately
    extra = 4*int(ctx.log(abs(z),2)) if abs(z) > 5 else 0
    ctx.prec += extra

    is_negative_integer = (not ctx._im(z) and ctx._re(z) < 0
                           and ctx.isint(ctx._re(z)))
    if is_negative_integer:
        n = int(ctx.re(z))
        mirrored = ctx.hyperfac(-n-1)
        if ((n+1)//2) & 1:
            mirrored = -mirrored
        if ctx._is_complex_type(z):
            return mirrored + 0j
        return mirrored

    shifted = z+1
    # exp(z*loggamma(z+1)) is used rather than gamma(z+1)**z, which was
    # found to give the wrong branch cut.
    numerator = ctx.exp(z*ctx.loggamma(shifted))
    # Drop the extra bits again before the Barnes G evaluation.
    ctx.prec -= extra
    return numerator / ctx.barnesg(shifted)
166
167'''
168@defun
169def psi0(ctx, z):
170    """Shortcut for psi(0,z) (the digamma function)"""
171    return ctx.psi(0, z)
172
173@defun
174def psi1(ctx, z):
175    """Shortcut for psi(1,z) (the trigamma function)"""
176    return ctx.psi(1, z)
177
178@defun
179def psi2(ctx, z):
180    """Shortcut for psi(2,z) (the tetragamma function)"""
181    return ctx.psi(2, z)
182
183@defun
184def psi3(ctx, z):
185    """Shortcut for psi(3,z) (the pentagamma function)"""
186    return ctx.psi(3, z)
187'''
188