1"""
2Test cdflib functions versus mpmath, if available.
3
4The following functions still need tests:
5
6- ncfdtr
7- ncfdtri
8- ncfdtridfn
9- ncfdtridfd
10- ncfdtrinc
11- nbdtrik
12- nbdtrin
13- nrdtrimn
14- nrdtrisd
15- pdtrik
16- nctdtr
17- nctdtrit
18- nctdtridf
19- nctdtrinc
20
21"""
22import itertools
23
24import numpy as np
25from numpy.testing import assert_equal, assert_allclose
26import pytest
27
28import scipy.special as sp
29from scipy.special._testutils import (
30    MissingModule, check_version, FuncData)
31from scipy.special._mptestutils import (
32    Arg, IntArg, get_args, mpf2float, assert_mpmath_equal)
33
34try:
35    import mpmath
36except ImportError:
37    mpmath = MissingModule('mpmath')
38
39
class ProbArg:
    """Argument spec producing probabilities on the closed interval [0, 1].

    Sample points are log-clustered towards both 0 and 1, with a sparser
    linear grid in the middle of the interval.
    """

    def __init__(self):
        # ``a``/``b`` mirror the endpoint attributes of ``Arg`` et al., so
        # code that reads ``spec.a``/``spec.b`` (e.g. endpoint filtering)
        # also works with this spec.
        self.a = 0
        self.b = 1

    def values(self, n):
        """Return a sorted array of roughly ``n`` unique probabilities."""
        per_region = max(1, n // 3)
        # Log-spaced from 1e-30 up to 0.3.
        near_zero = np.logspace(-30, np.log10(0.3), per_region)
        # Linear grid strictly inside (0.3, 0.7); the leading 0.3 is dropped.
        middle = np.linspace(0.3, 0.7, per_region + 1, endpoint=False)[1:]
        # 1 minus log-spaced values from 0.3 down to 1e-15, i.e. from 0.7
        # up to 1 - 1e-15.
        near_one = 1 - np.logspace(np.log10(0.3), -15, per_region)
        return np.unique(np.concatenate((near_zero, middle, near_one)))
55
56
class EndpointFilter:
    """Parameter filter that rejects values lying close to either endpoint.

    A value ``x`` counts as close to an endpoint ``e`` when
    ``|x - e| < rtol*|e| + atol``.  Calling the filter returns a boolean
    mask that is True for values to keep.
    """

    def __init__(self, a, b, rtol, atol):
        # The two endpoints of the argument range being filtered.
        self.a = a
        self.b = b
        self.rtol = rtol
        self.atol = atol

    def _near(self, x, endpoint):
        """Return a mask that is True where ``x`` is within tolerance of ``endpoint``."""
        return np.abs(x - endpoint) < self.rtol*np.abs(endpoint) + self.atol

    def __call__(self, x):
        close = self._near(x, self.a) | self._near(x, self.b)
        # asarray keeps the result an ndarray even for scalar input.
        return np.asarray(np.logical_not(close))
68
69
class _CDFData:
    """Round-trip check of a scipy function against an mpmath function.

    For each sampled argument tuple, the output of the first function is
    substituted into position ``index`` of the tuple and passed to the
    second function.  The final result is compared with the sampled value
    originally at ``index``.

    Parameters
    ----------
    spfunc, mpfunc : callable
        The scipy function under test and the mpmath reference function.
    index : int
        Argument position that receives the intermediate result.
    argspec : list or other spec accepted by ``get_args``
        Describes the grid of arguments to sample.
    spfunc_first : bool
        If True, evaluate ``spfunc`` first and feed its result to
        ``mpfunc``; otherwise evaluate ``mpfunc`` first.
    dps : int
        mpmath decimal precision used while evaluating ``mpfunc``.
    n : int
        Sample-size parameter passed to ``get_args``.
    rtol, atol : float or None
        Tolerances forwarded to ``FuncData``.
    endpt_rtol, endpt_atol : float, list or None
        Tolerances for discarding sample points near the endpoints
        (``spec.a``/``spec.b``) of each argument spec.  A scalar is repeated
        for every argument; a list gives one value (or None) per argument.
        Only used when ``argspec`` is a list.
    """

    def __init__(self, spfunc, mpfunc, index, argspec, spfunc_first=True,
                 dps=20, n=5000, rtol=None, atol=None,
                 endpt_rtol=None, endpt_atol=None):
        self.spfunc = spfunc
        self.mpfunc = mpfunc
        self.index = index
        self.argspec = argspec
        self.spfunc_first = spfunc_first
        self.dps = dps
        self.n = n
        self.rtol = rtol
        self.atol = atol

        use_endpoint_filter = isinstance(argspec, list) and (
            endpt_rtol is not None or endpt_atol is not None)
        if use_endpoint_filter:
            nargs = len(self.argspec)
            self.endpt_rtol = self._per_argument(endpt_rtol, nargs)
            self.endpt_atol = self._per_argument(endpt_atol, nargs)
        else:
            self.endpt_rtol = None
            self.endpt_atol = None

    @staticmethod
    def _per_argument(tol, nargs):
        """Return ``tol`` unchanged if it is a list, else repeat it ``nargs`` times."""
        return tol if isinstance(tol, list) else [tol]*nargs

    def idmap(self, *args):
        """Compose the two functions; ideally returns ``args[self.index]``."""
        args = list(args)
        if self.spfunc_first:
            sp_value = self.spfunc(*args)
            if np.isnan(sp_value):
                return np.nan
            args[self.index] = sp_value
            with mpmath.workdps(self.dps):
                # Imaginary parts are spurious; keep only the real part.
                return mpf2float(self.mpfunc(*args).real)

        with mpmath.workdps(self.dps):
            mp_value = mpf2float(self.mpfunc(*args).real)
        args[self.index] = mp_value
        return self.spfunc(*args)

    def get_param_filter(self):
        """Build one ``EndpointFilter`` (or None) per argument, or return None."""
        if self.endpt_rtol is None and self.endpt_atol is None:
            return None

        def make_filter(rtol, atol, spec):
            if rtol is None and atol is None:
                return None  # no endpoint exclusion for this argument
            # A missing tolerance component contributes nothing (0.0).
            return EndpointFilter(spec.a, spec.b,
                                  0.0 if rtol is None else rtol,
                                  0.0 if atol is None else atol)

        return [make_filter(rtol, atol, spec)
                for rtol, atol, spec in zip(self.endpt_rtol, self.endpt_atol,
                                            self.argspec)]

    def check(self):
        """Sample the arguments and assert that the round trip reproduces them."""
        points = get_args(self.argspec, self.n)
        param_filter = self.get_param_filter()
        nrows, ncols = points.shape[0], points.shape[1]
        # The extra last column holds the expected output: the sampled
        # value at ``index`` itself.
        expected = points[:, self.index].reshape(nrows, 1)
        data = np.hstack((points, expected))
        FuncData(self.idmap, data,
                 param_columns=tuple(range(ncols)), result_columns=ncols,
                 rtol=self.rtol, atol=self.atol, vectorized=False,
                 param_filter=param_filter).check()
148
149
def _assert_inverts(*a, **kw):
    """Assert that a scipy function and an mpmath function round-trip.

    All positional and keyword arguments are forwarded unchanged to
    ``_CDFData``, whose ``check`` method performs the comparison.
    """
    _CDFData(*a, **kw).check()
153
154
def _binomial_cdf(k, n, p):
    """Binomial CDF P(X <= k) for X ~ Binomial(n, p), evaluated with mpmath.

    Uses the identity P(X <= k) = I_{1-p}(n - k, k + 1), where ``I`` is the
    regularized incomplete beta function.  ``k`` is not rounded, so a
    non-integer ``k`` gives the continuous extension of that identity.

    Returns an ``mpf`` at the current mpmath working precision.
    """
    k, n, p = mpmath.mpf(k), mpmath.mpf(n), mpmath.mpf(p)
    if k < 0:
        # Below the support {0, ..., n}.
        return mpmath.mpf(0)
    elif k >= n:
        # The whole support lies at or below k (and the beta identity would
        # degenerate with first shape parameter n - k <= 0).
        return mpmath.mpf(1)

    # Compute 1 - p exactly so that information in tiny p is not rounded away.
    onemp = mpmath.fsub(1, p, exact=True)
    # k == 0 is handled here and correctly gives I_{1-p}(n, 1) = (1 - p)**n.
    # The previous ``k <= 0`` guard wrongly returned 0 for this case.
    return mpmath.betainc(n - k, k + 1, x2=onemp, regularized=True)
164
165
def _f_cdf(dfn, dfd, x):
    """CDF of the F distribution with ``dfn``/``dfd`` degrees of freedom.

    Evaluated with mpmath as I_u(dfn/2, dfd/2) with
    u = dfn*x/(dfn*x + dfd), where ``I`` is the regularized incomplete beta
    function.  Negative ``x`` lies outside the support and gives 0.
    """
    if x < 0:
        return mpmath.mpf(0)
    m, d, q = (mpmath.mpf(v) for v in (dfn, dfd, x))
    scaled = m*q
    upper = scaled/(scaled + d)
    return mpmath.betainc(m/2, d/2, x2=upper, regularized=True)
173
174
def _student_t_cdf(df, t, dps=None):
    """CDF of Student's t distribution with ``df`` degrees of freedom.

    Evaluated with mpmath via the hypergeometric representation

        1/2 + t*Gamma((df+1)/2)/(sqrt(pi*df)*Gamma(df/2))
              * 2F1(1/2, (df+1)/2; 3/2; -t**2/df)

    at ``dps`` decimal digits (default: the current mpmath precision).
    """
    working_dps = mpmath.mp.dps if dps is None else dps
    with mpmath.workdps(working_dps):
        nu, x = mpmath.mpf(df), mpmath.mpf(t)
        half_nu_plus = 0.5*(nu + 1)
        series = mpmath.hyp2f1(0.5, half_nu_plus, 1.5, -x**2/nu)
        numer = x*mpmath.gamma(half_nu_plus)
        denom = mpmath.sqrt(mpmath.pi*nu)*mpmath.gamma(0.5*nu)
        return 0.5 + series*numer/denom
184
185
def _noncentral_chi_pdf(t, df, nc):
    """Density of the noncentral chi-squared distribution at ``t``.

    Implements
        1/2 * exp(-(t + nc)/2) * (t/nc)**(df/4 - 1/2) * I_{df/2 - 1}(sqrt(nc*t))
    where ``I`` is the modified Bessel function of the first kind.
    """
    bessel = mpmath.besseli(df/2 - 1, mpmath.sqrt(nc*t))
    weight = mpmath.exp(-(t + nc)/2)*(t/nc)**(df/4 - 1/2)/2
    return bessel*weight
190
191
def _noncentral_chi_cdf(x, df, nc, dps=None):
    """CDF of the noncentral chi-squared distribution, by mpmath quadrature.

    Integrates ``_noncentral_chi_pdf`` over ``[0, x]`` at ``dps`` decimal
    digits (default: the current mpmath precision).

    For ``x <= 0`` this returns exactly 0 without integrating, matching how
    ``_f_cdf`` and ``_binomial_cdf`` treat arguments below the support.  The
    density used here is only real-valued for ``t >= 0``.  Returning 0 at
    ``x == 0`` assumes ``df > 0``, i.e. no point mass at zero.
    """
    if dps is None:
        dps = mpmath.mp.dps
    x, df, nc = mpmath.mpf(x), mpmath.mpf(df), mpmath.mpf(nc)
    if x <= 0:
        # Avoid integrating a complex-valued density over a negative or
        # degenerate interval.
        return mpmath.mpf(0)
    with mpmath.workdps(dps):
        res = mpmath.quad(lambda t: _noncentral_chi_pdf(t, df, nc), [0, x])
        return res
199
200
def _tukey_lmbda_quantile(p, lmbda):
    """Quantile function of the Tukey lambda distribution.

    Computes (p**lmbda - (1 - p)**lmbda)/lmbda.  Only valid for
    ``lmbda != 0``; the ``lmbda == 0`` case is the logistic distribution
    and is not handled here.
    """
    lower = p**lmbda
    upper = (1 - p)**lmbda
    return (lower - upper)/lmbda
204
205
@pytest.mark.slow
@check_version(mpmath, '0.19')
class TestCDFlib:
    """Checks of cdflib functions against mpmath reference CDFs.

    Most tests call ``_assert_inverts``, where ``index`` is the argument
    position of the mpmath CDF that receives the scipy result and
    ``argspec`` describes the grid of arguments sampled for the first
    function evaluated.
    """

    @staticmethod
    def _beta_cdf(a, b, x):
        # Regularized incomplete beta function I_x(a, b).
        return mpmath.betainc(a, b, x2=x, regularized=True)

    @staticmethod
    def _gamma_cdf(a, b, x):
        # Regularized lower incomplete gamma P(b, a*x); reference for the
        # gdtr* inverses.
        return mpmath.gammainc(b, b=a*x, regularized=True)

    @staticmethod
    def _chi2_cdf(v, x):
        # Chi-squared CDF with v degrees of freedom: P(v/2, x/2).
        return mpmath.gammainc(v/2, b=x/2, regularized=True)

    @pytest.mark.xfail(run=False)
    def test_bdtrik(self):
        _assert_inverts(
            spfunc=sp.bdtrik, mpfunc=_binomial_cdf, index=0,
            argspec=[ProbArg(), IntArg(1, 1000), ProbArg()],
            rtol=1e-4)

    def test_bdtrin(self):
        _assert_inverts(
            spfunc=sp.bdtrin, mpfunc=_binomial_cdf, index=1,
            argspec=[IntArg(1, 1000), ProbArg(), ProbArg()],
            rtol=1e-4, endpt_atol=[None, None, 1e-6])

    def test_btdtria(self):
        _assert_inverts(
            spfunc=sp.btdtria, mpfunc=self._beta_cdf, index=0,
            argspec=[ProbArg(), Arg(0, 1e2, inclusive_a=False),
                     Arg(0, 1, inclusive_a=False, inclusive_b=False)],
            rtol=1e-6)

    def test_btdtrib(self):
        # Use small values of a or mpmath doesn't converge
        _assert_inverts(
            spfunc=sp.btdtrib, mpfunc=self._beta_cdf, index=1,
            argspec=[Arg(0, 1e2, inclusive_a=False), ProbArg(),
                     Arg(0, 1, inclusive_a=False, inclusive_b=False)],
            rtol=1e-7, endpt_atol=[None, 1e-18, 1e-15])

    @pytest.mark.xfail(run=False)
    def test_fdtridfd(self):
        _assert_inverts(
            spfunc=sp.fdtridfd, mpfunc=_f_cdf, index=1,
            argspec=[IntArg(1, 100), ProbArg(),
                     Arg(0, 100, inclusive_a=False)],
            rtol=1e-7)

    def test_gdtria(self):
        _assert_inverts(
            spfunc=sp.gdtria, mpfunc=self._gamma_cdf, index=0,
            argspec=[ProbArg(), Arg(0, 1e3, inclusive_a=False),
                     Arg(0, 1e4, inclusive_a=False)],
            rtol=1e-7, endpt_atol=[None, 1e-7, 1e-10])

    def test_gdtrib(self):
        # Use small values of a and x or mpmath doesn't converge
        _assert_inverts(
            spfunc=sp.gdtrib, mpfunc=self._gamma_cdf, index=1,
            argspec=[Arg(0, 1e2, inclusive_a=False), ProbArg(),
                     Arg(0, 1e3, inclusive_a=False)],
            rtol=1e-5)

    def test_gdtrix(self):
        _assert_inverts(
            spfunc=sp.gdtrix, mpfunc=self._gamma_cdf, index=2,
            argspec=[Arg(0, 1e3, inclusive_a=False),
                     Arg(0, 1e3, inclusive_a=False), ProbArg()],
            rtol=1e-7, endpt_atol=[None, 1e-7, 1e-10])

    def test_stdtr(self):
        # Ideally the left endpoint for Arg() should be 0.
        assert_mpmath_equal(
            sp.stdtr,
            _student_t_cdf,
            [IntArg(1, 100), Arg(1e-10, np.inf)], rtol=1e-7)

    @pytest.mark.xfail(run=False)
    def test_stdtridf(self):
        _assert_inverts(
            spfunc=sp.stdtridf, mpfunc=_student_t_cdf, index=0,
            argspec=[ProbArg(), Arg()], rtol=1e-7)

    def test_stdtrit(self):
        _assert_inverts(
            spfunc=sp.stdtrit, mpfunc=_student_t_cdf, index=1,
            argspec=[IntArg(1, 100), ProbArg()],
            rtol=1e-7, endpt_atol=[None, 1e-10])

    def test_chdtriv(self):
        _assert_inverts(
            spfunc=sp.chdtriv, mpfunc=self._chi2_cdf, index=0,
            argspec=[ProbArg(), IntArg(1, 100)], rtol=1e-4)

    @pytest.mark.xfail(run=False)
    def test_chndtridf(self):
        # Use a larger atol since mpmath is doing numerical integration
        _assert_inverts(
            spfunc=sp.chndtridf, mpfunc=_noncentral_chi_cdf, index=1,
            argspec=[Arg(0, 100, inclusive_a=False), ProbArg(),
                     Arg(0, 100, inclusive_a=False)],
            n=1000, rtol=1e-4, atol=1e-15)

    @pytest.mark.xfail(run=False)
    def test_chndtrinc(self):
        # Use a larger atol since mpmath is doing numerical integration
        _assert_inverts(
            spfunc=sp.chndtrinc, mpfunc=_noncentral_chi_cdf, index=2,
            argspec=[Arg(0, 100, inclusive_a=False), IntArg(1, 100),
                     ProbArg()],
            n=1000, rtol=1e-4, atol=1e-15)

    def test_chndtrix(self):
        # Use a larger atol since mpmath is doing numerical integration
        _assert_inverts(
            spfunc=sp.chndtrix, mpfunc=_noncentral_chi_cdf, index=0,
            argspec=[ProbArg(), IntArg(1, 100),
                     Arg(0, 100, inclusive_a=False)],
            n=1000, rtol=1e-4, atol=1e-15,
            endpt_atol=[1e-6, None, None])

    def test_tklmbda_zero_shape(self):
        # When lmbda = 0 the CDF has a simple closed form (logistic CDF)
        one = mpmath.mpf(1)
        assert_mpmath_equal(
            lambda x: sp.tklmbda(x, 0),
            lambda x: one/(mpmath.exp(-x) + one),
            [Arg()], rtol=1e-7)

    def test_tklmbda_neg_shape(self):
        _assert_inverts(
            spfunc=sp.tklmbda, mpfunc=_tukey_lmbda_quantile, index=0,
            argspec=[ProbArg(), Arg(-25, 0, inclusive_b=False)],
            spfunc_first=False, rtol=1e-5,
            endpt_atol=[1e-9, 1e-5])

    @pytest.mark.xfail(run=False)
    def test_tklmbda_pos_shape(self):
        _assert_inverts(
            spfunc=sp.tklmbda, mpfunc=_tukey_lmbda_quantile, index=0,
            argspec=[ProbArg(), Arg(0, 100, inclusive_a=False)],
            spfunc_first=False, rtol=1e-5)
352
353
def test_nonfinite():
    """NaN inputs must give NaN; other non-finite inputs must not crash.

    Each listed function is called with every combination of a random
    finite value, nan, +inf and -inf in each argument position.
    """
    # (function name in scipy.special, number of arguments)
    funcs = [
        ("btdtria", 3),
        ("btdtrib", 3),
        ("bdtrik", 3),
        ("bdtrin", 3),
        ("chdtriv", 2),
        ("chndtr", 3),
        ("chndtrix", 3),
        ("chndtridf", 3),
        ("chndtrinc", 3),
        ("fdtridfd", 3),
        ("ncfdtr", 4),
        ("ncfdtri", 4),
        ("ncfdtridfn", 4),
        ("ncfdtridfd", 4),
        ("ncfdtrinc", 4),
        ("gdtrix", 3),
        ("gdtrib", 3),
        ("gdtria", 3),
        ("nbdtrik", 3),
        ("nbdtrin", 3),
        ("nrdtrimn", 3),
        ("nrdtrisd", 3),
        ("pdtrik", 2),
        ("stdtr", 2),
        ("stdtrit", 2),
        ("stdtridf", 2),
        ("nctdtr", 3),
        ("nctdtrit", 3),
        ("nctdtridf", 3),
        ("nctdtrinc", 3),
        ("tklmbda", 2),
    ]

    # Use a private generator instead of np.random.seed(1) so this test does
    # not reset NumPy's global RNG state for other tests.  RandomState(1)
    # yields the same stream as seed(1) followed by np.random.rand, so the
    # sampled inputs are unchanged.
    rng = np.random.RandomState(1)

    for name, numargs in funcs:
        func = getattr(sp, name)

        args_choices = [(float(x), np.nan, np.inf, -np.inf)
                        for x in rng.rand(numargs)]

        for args in itertools.product(*args_choices):
            # Every call must return (no exception, no hang) ...
            res = func(*args)

            # ... and NaN inputs should result in a NaN output.
            if any(np.isnan(x) for x in args):
                assert_equal(res, np.nan)
407
408
def test_chndtrix_gh2158():
    """Regression test for gh-2158: chndtrix used to blow up on these inputs."""
    ncp = np.arange(20.)+1e-6
    computed = sp.chndtrix(0.999999, 2, ncp)

    # Reference quantiles generated in R:
    #   options(digits=16)
    #   ncp <- seq(0, 19) + 1e-6
    #   print(qchisq(0.999999, df = 2, ncp = ncp))
    expected = [27.63103493142305, 35.25728589950540, 39.97396073236288,
                43.88033702110538, 47.35206403482798, 50.54112500166103,
                53.52720257322766, 56.35830042867810, 59.06600769498512,
                61.67243118946381, 64.19376191277179, 66.64228141346548,
                69.02756927200180, 71.35726934749408, 73.63759723904816,
                75.87368842650227, 78.06984431185720, 80.22971052389806,
                82.35640899964173, 84.45263768373256]
    assert_allclose(computed, expected)
425