1from libc.math cimport fabs, exp, floor, M_PI
2
3import cython
4
5from . cimport sf_error
6from ._cephes cimport expm1, poch
7
8cdef extern from "numpy/npy_math.h":
9    double npy_isnan(double x) nogil
10    double NPY_NAN
11    double NPY_INFINITY
12
13cdef extern from 'specfun_wrappers.h':
14    double hypU_wrap(double, double, double) nogil
15    double hyp1f1_wrap(double, double, double) nogil
16
# Double-precision machine epsilon (2**-52); used as the relative
# convergence threshold for individual series terms below.
DEF EPS = 2.220446049250313e-16
# Largest estimated relative rounding error that
# hyp1f1_series_track_convergence accepts before reporting NO_RESULT.
DEF ACCEPTABLE_RTOL = 1e-7
19
20
@cython.cdivision(True)
cdef inline double hyperu(double a, double b, double x) nogil:
    # Confluent hypergeometric function of the second kind, U(a, b, x).
    #
    # A NaN in any argument yields NaN. Negative x is reported as a DOMAIN
    # error and yields NaN. Positive x is delegated to hypU_wrap. At x == 0
    # the limiting values from the DLMF are used: a SINGULAR error and +inf
    # when b > 1, otherwise the Pochhammer symbol (1 - b + a)_{-a}.
    if npy_isnan(a) or npy_isnan(b) or npy_isnan(x):
        return NPY_NAN

    if x > 0.0:
        return hypU_wrap(a, b, x)

    if x < 0.0:
        sf_error.error("hyperu", sf_error.DOMAIN, NULL)
        return NPY_NAN

    # Only x == 0 (of either sign) reaches this point.
    if b <= 1.0:
        # DLMF 13.2.14-15 and 13.2.19-21
        return poch(1.0 - b + a, -a)

    # DLMF 13.2.16-18: for b > 1 the value at the origin is singular.
    sf_error.error("hyperu", sf_error.SINGULAR, NULL)
    return NPY_INFINITY
40
41
@cython.cdivision(True)
cdef inline double hyp1f1(double a, double b, double x) nogil:
    # Kummer's confluent hypergeometric function 1F1(a; b; x).
    #
    # The order of the checks below matters. NaN inputs are rejected first.
    # Then non-positive integer b (a potential pole) is handled, then
    # parameter combinations with closed forms or terminating series.
    # The plain Maclaurin series is used only where it provably converges
    # geometrically, and every other case goes to hyp1f1_wrap.
    if npy_isnan(a) or npy_isnan(b) or npy_isnan(x):
        return NPY_NAN

    if b == floor(b) and b <= 0:
        # b is a non-positive integer, so the series may have a pole. It is
        # cancelled when a is a negative integer with b <= a: the Pochhammer
        # symbol (a)_n in the numerator then vanishes no later than (b)_n.
        if a == floor(a) and b <= a < 0:
            return hyp1f1_series_track_convergence(a, b, x)
        return NPY_INFINITY

    if x == 0 or a == 0:
        return 1
    if a == -1:
        return 1 - x / b
    if a == b:
        return exp(x)
    if a - b == 1:
        return (1 + x / b) * exp(x)
    if a == 1 and b == 2:
        return expm1(x) / x
    if a == floor(a) and a <= 0:
        # A non-positive integer a makes the series terminate, but the
        # finite sum can still lose accuracy to cancellation.
        return hyp1f1_series_track_convergence(a, b, x)

    # Term k of the series is multiplied by
    #     t_k = (a + k) * x / ((b + k) * (k + 1)),
    # and for b > 0 we have |t_k| < (|a| + 1) * |x| / b. Below 0.9 the
    # series therefore converges at least geometrically.
    if b > 0 and (fabs(a) + 1) * fabs(x) < 0.9 * b:
        return hyp1f1_series(a, b, x)

    return hyp1f1_wrap(a, b, x)
81
82
@cython.cdivision(True)
cdef inline double hyp1f1_series_track_convergence(
    double a,
    double b,
    double x
) nogil:
    # Sum the 1F1(a; b; x) series while guarding against cancellation and
    # slow convergence.
    #
    # At most 1000 terms are summed. On convergence, the rough rounding-error
    # estimate n * EPS * sum(|terms|) must not exceed ACCEPTABLE_RTOL times
    # |result|. Otherwise NO_RESULT is reported and NaN is returned.
    # A vanishing denominator Pochhammer symbol that is not cancelled by
    # the numerator returns NaN without raising an error.
    cdef int n
    cdef double num, den
    cdef double t = 1
    cdef double total = 1
    cdef double abs_total = 1
    for n in range(1000):
        num = a + n
        den = b + n
        if den == 0:
            if num != 0:
                # Uncancelled pole.
                return NPY_NAN
            # Both Pochhammer symbols vanished. Per DLMF 13.2.5 the
            # remaining terms are taken as zero, which ends the summation.
            t = 0
        else:
            t *= num * x / den / (n + 1)
        abs_total += fabs(t)
        total += t
        if fabs(t) <= EPS * fabs(total):
            # Converged; accept only if cancellation did not cost too much
            # relative accuracy.
            if n * EPS * abs_total <= ACCEPTABLE_RTOL * fabs(total):
                return total
            sf_error.error("hyp1f1", sf_error.NO_RESULT, NULL)
            return NPY_NAN

    # The term budget ran out without convergence.
    sf_error.error("hyp1f1", sf_error.NO_RESULT, NULL)
    return NPY_NAN
121
122
@cython.cdivision(True)
cdef inline double hyp1f1_series(double a, double b, double x) nogil:
    # Plain Maclaurin series for 1F1(a; b; x), summed term by term.
    #
    # Returns as soon as a term drops to within EPS of the running sum
    # (relative). If 500 terms are not enough, reports NO_RESULT and returns
    # NaN. No pole or cancellation checks are done here; the caller hyp1f1
    # only uses this for b > 0 where the series converges geometrically.
    cdef int n
    cdef double t = 1
    cdef double total = 1
    for n in range(500):
        t *= (a + n) * x / (b + n) / (n + 1)
        total += t
        if fabs(t) <= EPS * fabs(total):
            return total
    sf_error.error("hyp1f1", sf_error.NO_RESULT, NULL)
    return NPY_NAN
137