1"""Hypergeometric and Meijer G-functions."""
2
3import functools
4import math
5
6import mpmath
7
8from ...core import (Derivative, Dummy, Expr, Function, I, Integer, Mod, Mul,
9                     Ne, Rational, Tuple, oo, pi, zoo)
10from ...core.function import ArgumentIndexError
11from .. import (acosh, acoth, asin, asinh, atan, atanh, cos, cosh, exp, log,
12                sin, sinh, sqrt)
13
14
class TupleArg(Tuple):
    """Tuple of hyper/meijerg parameters that supports taking limits."""

    def limit(self, x, xlim, dir='+'):
        """Return a TupleArg whose entries are the limits of ours as x -> xlim."""
        limits = (entry.limit(x, xlim, dir) for entry in self.args)
        return self.func(*limits)
21
22
# TODO should __new__ accept **options?
# TODO should constructors check if parameters are sensible?
25
26
def _prep_tuple(v):
    """
    Convert the iterable ``v`` into a TupleArg of unpolarified entries.

    Both hypergeometric and Meijer G-functions are unbranched in their
    parameters, so any polar information attached to a parameter can be
    dropped.

    Examples
    ========

    >>> _prep_tuple([1, 2, 3])
    (1, 2, 3)
    >>> _prep_tuple((4, 5))
    (4, 5)
    >>> _prep_tuple((7, 8, 9))
    (7, 8, 9)

    """
    from .. import unpolarify
    entries = map(unpolarify, v)
    return TupleArg(*entries)
45
46
class TupleParametersBase(Function):
    """Base class that takes care of differentiation, when some of
    the arguments are actually tuples.

    Subclasses provide ``_diffargs`` (the flat sequence of parameters) and an
    ``fdiff`` that accepts either ``3`` (derivative with respect to the
    argument ``args[2]``) or a pair ``(1, i)`` (derivative with respect to
    the ``i``-th entry of ``_diffargs``).

    """

    # This is not deduced automatically since there are Tuples as arguments.
    is_commutative = True

    def _eval_derivative(self, s):
        """Differentiate with respect to ``s`` using the chain rule.

        Parameter contributions are only computed when one of the parameter
        tuples (``args[0]`` or ``args[1]``) contains ``s``.  If a required
        partial derivative is unavailable (``fdiff`` raises
        ``ArgumentIndexError`` or ``NotImplementedError``), an unevaluated
        ``Derivative`` is returned instead.

        """
        try:
            res = 0
            if self.args[0].has(s) or self.args[1].has(s):
                # Iterate the parameters once; ``_diffargs`` builds a new
                # Tuple on every access, so re-indexing it is wasteful.
                for i, param in enumerate(self._diffargs):
                    m = param.diff(s)
                    if m != 0:
                        res += self.fdiff((1, i))*m
            return res + self.fdiff(3)*self.args[2].diff(s)
        except (ArgumentIndexError, NotImplementedError):
            return Derivative(self, s)

    @property
    def is_number(self):
        """Returns True if 'self' has no free symbols."""
        return not self.free_symbols
72
73
class hyper(TupleParametersBase):
    r"""
    The (generalized) hypergeometric function is defined by a series where
    the ratios of successive terms are a rational function of the summation
    index. When convergent, it is continued analytically to the largest
    possible domain.

    The hypergeometric function depends on two vectors of parameters, called
    the numerator parameters `a_p`, and the denominator parameters
    `b_q`. It also has an argument `z`. The series definition is

    .. math ::
        {}_pF_q\left(\begin{matrix} a_1, \ldots, a_p \\ b_1, \ldots, b_q \end{matrix}
                     \middle| z \right)
        = \sum_{n=0}^\infty \frac{(a_1)_n \ldots (a_p)_n}{(b_1)_n \ldots (b_q)_n}
                            \frac{z^n}{n!},

    where `(a)_n = (a)(a+1)\ldots(a+n-1)` denotes the rising factorial.

    If one of the `b_q` is a non-positive integer then the series is
    undefined unless one of the `a_p` is a larger (i.e. smaller in
    magnitude) non-positive integer. If none of the `b_q` is a
    non-positive integer and one of the `a_p` is a non-positive
    integer, then the series reduces to a polynomial. To simplify the
    following discussion, we assume that none of the `a_p` or
    `b_q` is a non-positive integer. For more details, see the
    references.

    The series converges for all `z` if `p \le q`, and thus
    defines an entire single-valued function in this case. If `p =
    q+1` the series converges for `|z| < 1`, and can be continued
    analytically to the complex plane cut along `[1, \infty)`. If
    `p > q+1` the series is divergent for all `z \ne 0`.

    Note: The hypergeometric function constructor currently does *not* check
    if the parameters actually yield a well-defined function.

    Examples
    ========

    The parameters `a_p` and `b_q` can be passed as arbitrary
    iterables, for example:

    >>> hyper((1, 2, 3), [3, 4], x)
    hyper((1, 2, 3), (3, 4), x)

    There is also pretty printing (it looks better using unicode):

    >>> pprint(hyper((1, 2, 3), [3, 4], x), use_unicode=False)
      _
     |_  /1, 2, 3 |  \
     |   |        | x|
    3  2 \  3, 4  |  /

    The parameters must always be iterables, even if they are vectors of
    length one or zero:

    >>> hyper([1], [], x)
    hyper((1,), (), x)

    But of course they may be variables (but if they depend on x then you
    should not expect much implemented functionality):

    >>> hyper([n, a], [n**2], x)
    hyper((n, a), (n**2,), x)

    The hypergeometric function generalizes many named special functions.
    The function hyperexpand() tries to express a hypergeometric function
    using named special functions.
    For example:

    >>> hyperexpand(hyper([], [], x))
    E**x

    You can also use expand_func:

    >>> expand_func(x*hyper([1, 1], [2], -x))
    log(x + 1)

    More examples:

    >>> hyperexpand(hyper([], [Rational(1, 2)], -x**2/4))
    cos(x)
    >>> hyperexpand(x*hyper([Rational(1, 2), Rational(1, 2)], [Rational(3, 2)], x**2))
    asin(x)

    We can also sometimes hyperexpand parametric functions:

    >>> hyperexpand(hyper([-a], [], x))
    (-x + 1)**a

    See Also
    ========

    diofant.simplify.hyperexpand
    diofant.functions.special.gamma_functions.gamma
    diofant.functions.special.hyper.meijerg

    References
    ==========

    * Luke, Y. L. (1969), The Special Functions and Their Approximations,
      Volume 1
    * https://en.wikipedia.org/wiki/Generalized_hypergeometric_function

    """

    def __new__(cls, ap, bq, z):
        # TODO should we check convergence conditions?
        numer = _prep_tuple(ap)
        denom = _prep_tuple(bq)
        return Function.__new__(cls, numer, denom, z)

    @classmethod
    def eval(cls, ap, bq, z):
        from .. import unpolarify
        # Only for p <= q (entire function) is the branch of z irrelevant.
        if len(ap) > len(bq):
            return
        plain_z = unpolarify(z)
        if plain_z != z:
            return hyper(ap, bq, plain_z)

    def fdiff(self, argindex=3):
        if argindex != 3:
            raise ArgumentIndexError(self, argindex)
        # Differentiating in z shifts every parameter up by one and
        # multiplies by prod(ap)/prod(bq).
        shifted_ap = Tuple(*(a + 1 for a in self.ap))
        shifted_bq = Tuple(*(b + 1 for b in self.bq))
        ratio = Mul(*self.ap)/Mul(*self.bq)
        return ratio*hyper(shifted_ap, shifted_bq, self.argument)

    def _eval_expand_func(self, **hints):
        from ...simplify import hyperexpand
        from .gamma_functions import gamma
        is_2f1_at_one = (len(self.ap) == 2 and len(self.bq) == 1 and
                         self.argument == 1)
        if not is_2f1_at_one:
            return hyperexpand(self)
        # Gauss's summation formula for 2F1(a, b; c; 1).
        # NOTE(review): applied without checking re(c - a - b) > 0.
        a, b = self.ap
        c, = self.bq
        return gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b)

    def _eval_rewrite_as_Sum(self, ap, bq, z):
        from ...concrete import Sum
        from .. import Piecewise, RisingFactorial, factorial
        n = Dummy('n', integer=True)
        numer = Mul(*(RisingFactorial(a, n) for a in ap))
        denom = Mul(*(RisingFactorial(b, n) for b in bq))
        term = numer/denom*z**n/factorial(n)
        series = Sum(term, (n, 0, oo))
        # Fall back to the function itself wherever the series diverges.
        return Piecewise((series, self.convergence_statement), (self, True))

    @property
    def argument(self):
        """The variable ``z`` of the hypergeometric function."""
        return self.args[2]

    @property
    def ap(self):
        """Tuple of numerator parameters."""
        numer = self.args[0]
        return Tuple(*numer)

    @property
    def bq(self):
        """Tuple of denominator parameters."""
        denom = self.args[1]
        return Tuple(*denom)

    @property
    def _diffargs(self):
        return self.ap + self.bq

    @property
    def eta(self):
        """Return ``sum(ap) - sum(bq)``, which governs convergence on ``|z| = 1``."""
        return sum(self.ap) - sum(self.bq)

    @property
    def radius_of_convergence(self):
        """
        Compute the radius of convergence of the defining series.

        Note that even if this is not oo, the function may still be evaluated
        outside of the radius of convergence by analytic continuation. But if
        this is zero, then the function is not actually defined anywhere else.

        >>> hyper((1, 2), [3], z).radius_of_convergence
        1
        >>> hyper((1, 2, 3), [4], z).radius_of_convergence
        0
        >>> hyper((1, 2), (3, 4), z).radius_of_convergence
        oo

        """
        params = self.ap + self.bq
        if any(p.is_integer and p.is_nonpositive for p in params):
            aints = [p for p in self.ap if p.is_Integer and p.is_nonpositive]
            bints = [p for p in self.bq if p.is_Integer and p.is_nonpositive]
            if len(aints) < len(bints):
                return Integer(0)
            popped = False
            for b in bints:
                # Consume numerator integers until one cancels b.
                while aints:
                    if aints.pop() >= b:
                        break
                    popped = True
                else:
                    # Nothing cancelled b: the series is undefined.
                    return Integer(0)
            if aints or popped:
                # Non-positive numerator parameters remain, so the series
                # terminates: it is a polynomial.
                return oo
        if len(self.ap) == len(self.bq) + 1:
            return Integer(1)
        if len(self.ap) <= len(self.bq):
            return oo
        return Integer(0)

    @property
    def convergence_statement(self):
        """Return a condition on z under which the series converges."""
        from ...logic import And, Or
        from .. import re
        R = self.radius_of_convergence
        if R == 0:
            return False
        if R == oo:
            return True
        # The special functions and their approximations, page 44
        z = self.argument
        re_eta = re(self.eta)
        return Or(And(re_eta < 0, abs(z) <= 1),
                  And(0 <= re_eta, re_eta < 1, abs(z) <= 1, Ne(z, 1)),
                  And(re_eta >= 1, abs(z) < 1))

    def _eval_simplify(self, ratio, measure):
        from ...simplify import hyperexpand
        return hyperexpand(self)

    def _eval_evalf(self, prec):
        def to_mp(expr):
            return expr._to_mpmath(prec)

        z = to_mp(self.argument)
        numer = [to_mp(a) for a in self.ap]
        denom = [to_mp(b) for b in self.bq]
        with mpmath.workprec(prec):
            value = mpmath.hyper(numer, denom, z, eliminate=False)
        return Expr._from_mpmath(value, prec)
317
318
class meijerg(TupleParametersBase):
    r"""
    The Meijer G-function is defined by a Mellin-Barnes type integral that
    resembles an inverse Mellin transform. It generalizes the hypergeometric
    functions.

    The Meijer G-function depends on four sets of parameters. There are
    "*numerator parameters*"
    `a_1, \ldots, a_n` and `a_{n+1}, \ldots, a_p`, and there are
    "*denominator parameters*"
    `b_1, \ldots, b_m` and `b_{m+1}, \ldots, b_q`.
    Confusingly, it is traditionally denoted as follows (note the position
    of `m`, `n`, `p`, `q`, and how they relate to the lengths of the four
    parameter vectors):

    .. math ::
        G_{p,q}^{m,n} \left(\begin{matrix}a_1, \ldots, a_n & a_{n+1}, \ldots, a_p \\
                                        b_1, \ldots, b_m & b_{m+1}, \ldots, b_q
                          \end{matrix} \middle| z \right).

    However, in diofant the four parameter vectors are always available
    separately (see examples), so that there is no need to keep track of the
    decorating sub- and super-scripts on the G symbol.

    The G function is defined as the following integral:

    .. math ::
         \frac{1}{2 \pi i} \int_L \frac{\prod_{j=1}^m \Gamma(b_j - s)
         \prod_{j=1}^n \Gamma(1 - a_j + s)}{\prod_{j=m+1}^q \Gamma(1- b_j +s)
         \prod_{j=n+1}^p \Gamma(a_j - s)} z^s \mathrm{d}s,

    where `\Gamma(z)` is the gamma function. There are three possible
    contours which we will not describe in detail here (see the references).
    If the integral converges along more than one of them the definitions
    agree. The contours all separate the poles of `\Gamma(1-a_j+s)`
    from the poles of `\Gamma(b_k-s)`, so in particular the G function
    is undefined if `a_j - b_k \in \mathbb{Z}_{>0}` for some
    `j \le n` and `k \le m`.

    The conditions under which one of the contours yields a convergent integral
    are complicated and we do not state them here, see the references.

    Note: Currently the Meijer G-function constructor does *not* check any
    convergence conditions.

    Examples
    ========

    You can pass the parameters either as four separate vectors:

    >>> pprint(meijerg([1, 2], [a, 4], [5], [], x), use_unicode=False)
     __1, 2 /1, 2  a, 4 |  \
    /__     |           | x|
    \_|4, 1 \ 5         |  /

    or as two nested vectors:

    >>> pprint(meijerg(([1, 2], [3, 4]), ([5], []), x), use_unicode=False)
     __1, 2 /1, 2  3, 4 |  \
    /__     |           | x|
    \_|4, 1 \ 5         |  /

    As with the hypergeometric function, the parameters may be passed as
    arbitrary iterables. Vectors of length zero and one also have to be
    passed as iterables. The parameters need not be constants, but if they
    depend on the argument then not much implemented functionality should be
    expected.

    All the subvectors of parameters are available:

    >>> g = meijerg([1], [2], [3], [4], x)
    >>> pprint(g, use_unicode=False)
     __1, 1 /1  2 |  \
    /__     |     | x|
    \_|2, 2 \3  4 |  /
    >>> g.an
    (1,)
    >>> g.ap
    (1, 2)
    >>> g.aother
    (2,)
    >>> g.bm
    (3,)
    >>> g.bq
    (3, 4)
    >>> g.bother
    (4,)

    The Meijer G-function generalizes the hypergeometric functions.
    In some cases it can be expressed in terms of hypergeometric functions,
    using Slater's theorem. For example:

    >>> hyperexpand(meijerg([a], [], [c], [b], x), allow_hyper=True)
    x**c*gamma(-a + c + 1)*hyper((-a + c + 1,),
                                 (-b + c + 1,), -x)/gamma(-b + c + 1)

    Thus the Meijer G-function also subsumes many named functions as special
    cases. You can use expand_func or hyperexpand to (try to) rewrite a
    Meijer G-function in terms of named special functions. For example:

    >>> expand_func(meijerg([[], []], [[0], []], -x))
    E**x
    >>> hyperexpand(meijerg([[], []], [[Rational(1, 2)], [0]], (x/2)**2))
    sin(x)/sqrt(pi)

    See Also
    ========

    diofant.functions.special.hyper.hyper
    diofant.simplify.hyperexpand

    References
    ==========

    * Luke, Y. L. (1969), The Special Functions and Their Approximations,
      Volume 1
    * https://en.wikipedia.org/wiki/Meijer_G-function

    """

    def __new__(cls, *args):
        if len(args) == 5:
            args = [(args[0], args[1]), (args[2], args[3]), args[4]]
        if len(args) != 3:
            raise TypeError("args must be either as, as', bs, bs', z or "
                            'as, bs, z')

        def tr(p):
            # Normalize one (first, second) pair of parameter vectors.
            if len(p) != 2:
                raise TypeError('wrong argument')
            return TupleArg(_prep_tuple(p[0]), _prep_tuple(p[1]))

        arg0, arg1 = tr(args[0]), tr(args[1])
        if Tuple(arg0, arg1).has(oo, zoo, -oo):
            raise ValueError('G-function parameters must be finite')

        # The contour must separate the poles of Gamma(1 - a_j + s), j <= n,
        # from those of Gamma(b_k - s), k <= m (see the class docstring).
        if any((a - b).is_integer and (a - b).is_positive
               for a in arg0[0] for b in arg1[0]):
            raise ValueError('no parameter a1, ..., an may differ from '
                             'any b1, ..., bm by a positive integer')

        # TODO should we check convergence conditions?
        return Function.__new__(cls, arg0, arg1, args[2])

    def fdiff(self, argindex=3):
        """Return the partial derivative selected by ``argindex``.

        ``argindex == 3`` differentiates with respect to the argument ``z``.
        Otherwise ``argindex`` must be a pair whose second entry ``i`` selects
        the ``i``-th entry of ``_diffargs`` (``ap + bq``); anything else
        raises ``ArgumentIndexError``.

        """
        if argindex != 3:
            try:
                idx = argindex[1]
            except (TypeError, IndexError):
                # e.g. a plain integer index other than 3
                raise ArgumentIndexError(self, argindex) from None
            return self._diff_wrt_parameter(idx)
        z = self.argument
        if len(self.an) >= 1:
            # Expressed through a G-function with a_1 lowered by one.
            a = list(self.an)
            a[0] -= 1
            G = meijerg(a, self.aother, self.bm, self.bother, z)
            return 1/z * ((self.an[0] - 1)*self + G)
        elif len(self.bm) >= 1:
            # Expressed through a G-function with b_1 raised by one.
            b = list(self.bm)
            b[0] += 1
            G = meijerg(self.an, self.aother, b, self.bother, z)
            return 1/z * (self.bm[0]*self - G)
        else:
            return Integer(0)

    def _diff_wrt_parameter(self, idx):
        # Differentiation wrt a parameter can only be done in very special
        # cases. In particular, if we want to differentiate with respect to
        # `a`, all other gamma factors have to reduce to rational functions.
        #
        # Let MT denote mellin transform. Suppose T(-s) is the gamma factor
        # appearing in the definition of G. Then
        #
        #   MT(log(z)G(z)) = d/ds T(s) = d/da T(s) + ...
        #
        # Thus d/da G(z) = log(z)G(z) - ...
        # The ... can be evaluated as a G function under the above conditions,
        # the formula being most easily derived by using
        #
        # d  Gamma(s + n)    Gamma(s + n) / 1    1                1     \
        # -- ------------ =  ------------ | - + ----  + ... + --------- |
        # ds Gamma(s)        Gamma(s)     \ s   s + 1         s + n - 1 /
        #
        # which follows from the difference equation of the digamma function.
        # (There is a similar equation for -n instead of +n).

        def not_expressible():
            return NotImplementedError('Derivative not expressible '
                                       'as G-function?')

        def integer_offset(u, v):
            # Mod() below only establishes that u - v is an integer; counting
            # the shifts requires an explicit Integer.
            diff = (u - v).simplify()
            if not diff.is_Integer:
                raise not_expressible()
            return diff

        # Remove the parameter we differentiate by; idx indexes
        # an + aother + bm + bother (i.e. ``_diffargs``).
        an = list(self.an)
        ap = list(self.aother)
        bm = list(self.bm)
        bq = list(self.bother)
        if idx < len(an):
            an.pop(idx)
        else:
            idx -= len(an)
            if idx < len(ap):
                ap.pop(idx)
            else:
                idx -= len(ap)
                if idx < len(bm):
                    bm.pop(idx)
                else:
                    bq.pop(idx - len(bm))

        # Pair the remaining gamma factors so that each pair is rational.
        pairs1 = []
        pairs2 = []
        for l1, l2, pairs in [(an, bq, pairs1), (ap, bm, pairs2)]:
            while l1:
                x = l1.pop()
                found = None
                for i, candidate in enumerate(l2):
                    if not Mod((x - candidate).simplify(), 1):
                        found = i
                        break
                if found is None:
                    raise not_expressible()
                pairs.append((x, l2.pop(found)))
            if l2:
                # An unpaired gamma factor does not reduce to a rational
                # function (see above), so silently dropping it would give a
                # wrong result.
                raise not_expressible()

        # Now build the result.
        res = log(self.argument)*self

        for a, b in pairs1:
            n = integer_offset(a, b)
            sign, base = 1, b
            if n < 0:
                sign, n, base = -1, -n, a
            for k in range(n):
                res -= sign*meijerg(self.an + (base + k + 1,), self.aother,
                                    self.bm, self.bother + (base + k + 0,),
                                    self.argument)

        for a, b in pairs2:
            n = integer_offset(b, a)
            sign, base = 1, a
            if n < 0:
                sign, n, base = -1, -n, b
            for k in range(n):
                res -= sign*meijerg(self.an, self.aother + (base + k + 1,),
                                    self.bm + (base + k + 0,), self.bother,
                                    self.argument)

        return res

    def get_period(self):
        """
        Return a number P such that G(x*exp(I*P)) == G(x).

        >>> meijerg([1], [], [], [], z).get_period()
        2*pi
        >>> meijerg([pi], [], [], [], z).get_period()
        oo
        >>> meijerg([1, 2], [], [], [], z).get_period()
        oo
        >>> meijerg([1, 1], [2], [1, Rational(1, 2), Rational(1, 3)], [1], z).get_period()
        12*pi

        """
        # This follows from slater's theorem.
        def compute(l):
            # oo if some entry is not rational or two entries differ by an
            # integer; otherwise the lcm of the denominators.
            for i, b in enumerate(l):
                if not b.is_Rational:
                    return oo
                for j in range(i + 1, len(l)):
                    if not Mod((b - l[j]).simplify(), 1):
                        return oo
            return functools.reduce(math.lcm, (x.denominator for x in l), 1)
        beta = compute(self.bm)
        alpha = compute(self.an)
        p, q = len(self.ap), len(self.bq)
        if p == q:
            if oo in (beta, alpha):
                return oo
            return 2*pi*math.lcm(alpha, beta)
        elif p < q:
            return 2*pi*beta
        else:
            return 2*pi*alpha

    def _eval_expand_func(self, **hints):
        from ...simplify import hyperexpand
        return hyperexpand(self)

    def _eval_evalf(self, prec):
        # The default code is insufficient for polar arguments.
        # mpmath provides an optional argument "r", which evaluates
        # G(z**(1/r)). I am not sure what its intended use is, but we hijack it
        # here in the following way: to evaluate at a number z of |argument|
        # less than (say) n*pi, we put r=1/n, compute z' = root(z, n)
        # (carefully so as not to lose the branch information), and evaluate
        # G(z'**(1/r)) = G(z'**n) = G(z).
        from .. import ceiling, exp_polar
        znum = self.argument.evalf(prec, strict=False)
        if znum.has(exp_polar):
            znum, branch = znum.as_coeff_mul(exp_polar)
            if len(branch) != 1:
                return
            branch = branch[0].args[0]/I
        else:
            branch = Integer(0)
        n = ceiling(abs(branch/pi)) + 1
        znum = znum**(Integer(1)/n)*exp(I*branch / n)

        # Convert all args to mpf or mpc
        [z, r, ap, bq] = [arg._to_mpmath(prec)
                          for arg in [znum, 1/n, self.args[0], self.args[1]]]

        with mpmath.workprec(prec):
            v = mpmath.meijerg(ap, bq, z, r)

        return Expr._from_mpmath(v, prec)

    def integrand(self, s):
        """Get the defining integrand D(s)."""
        from .gamma_functions import gamma
        return self.argument**s \
            * Mul(*(gamma(b - s) for b in self.bm)) \
            * Mul(*(gamma(1 - a + s) for a in self.an)) \
            / Mul(*(gamma(1 - b + s) for b in self.bother)) \
            / Mul(*(gamma(a - s) for a in self.aother))

    @property
    def argument(self):
        """Argument of the Meijer G-function."""
        return self.args[2]

    @property
    def an(self):
        """First set of numerator parameters."""
        return Tuple(*self.args[0][0])

    @property
    def ap(self):
        """Combined numerator parameters."""
        return Tuple(*(self.args[0][0] + self.args[0][1]))

    @property
    def aother(self):
        """Second set of numerator parameters."""
        return Tuple(*self.args[0][1])

    @property
    def bm(self):
        """First set of denominator parameters."""
        return Tuple(*self.args[1][0])

    @property
    def bq(self):
        """Combined denominator parameters."""
        return Tuple(*(self.args[1][0] + self.args[1][1]))

    @property
    def bother(self):
        """Second set of denominator parameters."""
        return Tuple(*self.args[1][1])

    @property
    def _diffargs(self):
        return self.ap + self.bq

    @property
    def nu(self):
        """A quantity related to the convergence region of the integral,
        c.f. references.

        """
        return sum(self.bq) - sum(self.ap)

    @property
    def delta(self):
        """A quantity related to the convergence region of the integral,
        c.f. references.

        """
        return len(self.bm) + len(self.an) - Integer(len(self.ap) + len(self.bq))/2
699
700
class HyperRep(Function):
    """
    A base class for "hyper representation functions".

    This is used exclusively in hyperexpand(), but fits more logically here.

    pFq is branched at 1 if p == q+1. For use with slater-expansion, we want
    to define an "analytic continuation" to all polar numbers, which is
    continuous on circles and on the ray t*exp_polar(I*pi). Moreover, we want
    a "nice" expression for the various cases.

    This base class contains the core logic; concrete derived classes only
    supply the actual expressions via the ``_expr_*`` classmethods.

    """

    @classmethod
    def eval(cls, *args):
        from .. import unpolarify
        # Every argument but the last is unpolarified; the last one (the
        # variable) keeps its polar information for the rewrite methods.
        params, tail = args[:-1], args[-1:]
        newargs = tuple(unpolarify(p) for p in params) + tail
        if newargs == args:
            return
        return cls(*newargs)

    @classmethod
    def _expr_small(cls, x):
        """An expression for F(x) which holds for |x| < 1."""
        raise NotImplementedError

    @classmethod
    def _expr_small_minus(cls, x):
        """An expression for F(-x) which holds for |x| < 1."""
        raise NotImplementedError

    @classmethod
    def _expr_big(cls, x, n):
        """An expression for F(exp_polar(2*I*pi*n)*x), |x| > 1."""
        raise NotImplementedError

    @classmethod
    def _expr_big_minus(cls, x, n):
        """An expression for F(exp_polar(2*I*pi*n + pi*I)*x), |x| > 1."""
        raise NotImplementedError

    def _eval_rewrite_as_nonrep(self, *args):
        from .. import Piecewise
        x, n = self.args[-1].extract_branch_factor(allow_half=True)
        params = self.args[:-1] + (x,)
        if n.is_Integer:
            small = self._expr_small(*params)
            big = self._expr_big(*params, n)
        else:
            # A non-integer branch factor selects the ``*_minus`` forms.
            n -= Rational(1, 2)
            small = self._expr_small_minus(*params)
            big = self._expr_big_minus(*params, n)

        if big == small:
            return small
        return Piecewise((big, abs(x) > 1), (small, True))

    def _eval_rewrite_as_nonrepsmall(self, *args):
        x, n = self.args[-1].extract_branch_factor(allow_half=True)
        params = self.args[:-1] + (x,)
        expr_small = (self._expr_small if n.is_Integer
                      else self._expr_small_minus)
        return expr_small(*params)
770
771
class HyperRep_power1(HyperRep):
    """Return a representative for hyper([-a], [], z) == (1 - z)**a."""

    @classmethod
    def _expr_small(cls, a, x):
        return (1 - x)**a

    @classmethod
    def _expr_small_minus(cls, a, x):
        return (1 + x)**a

    @classmethod
    def _expr_big(cls, a, x, n):
        # With an integer exponent the small-|x| form is single-valued.
        if not a.is_integer:
            return (x - 1)**a*exp((2*n - 1)*pi*I*a)
        return cls._expr_small(a, x)

    @classmethod
    def _expr_big_minus(cls, a, x, n):
        if not a.is_integer:
            return (1 + x)**a*exp(2*n*pi*I*a)
        return cls._expr_small_minus(a, x)
794
795
class HyperRep_power2(HyperRep):
    """Return a representative for hyper([a, a - 1/2], [2*a], z)."""

    @classmethod
    def _expr_small(cls, a, x):
        return 2**(2*a - 1)*(1 + sqrt(1 - x))**(1 - 2*a)

    @classmethod
    def _expr_small_minus(cls, a, x):
        return 2**(2*a - 1)*(1 + sqrt(1 + x))**(1 - 2*a)

    @classmethod
    def _expr_big(cls, a, x, n):
        # Odd sheets flip the sign of the square root and reuse the phase
        # of the preceding even sheet.
        if n.is_odd:
            sgn, n = 1, n - 1
        else:
            sgn = -1
        return 2**(2*a - 1)*(1 + sgn*I*sqrt(x - 1))**(1 - 2*a) \
            * exp(-2*n*pi*I*a)

    @classmethod
    def _expr_big_minus(cls, a, x, n):
        sgn = -1 if n.is_odd else 1
        return sgn*2**(2*a - 1)*(sqrt(1 + x) + sgn)**(1 - 2*a)*exp(-2*pi*I*a*n)
822
823
class HyperRep_log1(HyperRep):
    """Represent -z*hyper([1, 1], [2], z) == log(1 - z)."""

    @classmethod
    def _expr_small(cls, x):
        return log(1 - x)

    @classmethod
    def _expr_small_minus(cls, x):
        return log(1 + x)

    @classmethod
    def _expr_big(cls, x, n):
        # Each sheet adds a multiple of 2*pi*I to the principal logarithm.
        shift = (2*n - 1)*pi*I
        return log(x - 1) + shift

    @classmethod
    def _expr_big_minus(cls, x, n):
        shift = 2*n*pi*I
        return log(1 + x) + shift
842
843
class HyperRep_atanh(HyperRep):
    """Represent hyper([1/2, 1], [3/2], z) == atanh(sqrt(z))/sqrt(z).

    All four ``_expr_*`` hooks are classmethods, as declared by ``HyperRep``
    and as in every sibling class, so they can be called on the class itself.

    """

    @classmethod
    def _expr_small(cls, x):
        return atanh(sqrt(x))/sqrt(x)

    @classmethod
    def _expr_small_minus(cls, x):
        return atan(sqrt(x))/sqrt(x)

    @classmethod
    def _expr_big(cls, x, n):
        # The sign of the I*pi/2 term depends on the parity of the sheet n.
        if n.is_even:
            return (acoth(sqrt(x)) + I*pi/2)/sqrt(x)
        else:
            return (acoth(sqrt(x)) - I*pi/2)/sqrt(x)

    @classmethod
    def _expr_big_minus(cls, x, n):
        if n.is_even:
            return atan(sqrt(x))/sqrt(x)
        else:
            return (atan(sqrt(x)) - pi)/sqrt(x)
865
866
class HyperRep_asin1(HyperRep):
    """Represent hyper([1/2, 1/2], [3/2], z) == asin(sqrt(z))/sqrt(z)."""

    @classmethod
    def _expr_small(cls, x):
        root = sqrt(x)
        return asin(root)/root

    @classmethod
    def _expr_small_minus(cls, x):
        root = sqrt(x)
        return asinh(root)/root

    @classmethod
    def _expr_big(cls, x, n):
        root = sqrt(x)
        return Integer(-1)**n*((Rational(1, 2) - n)*pi/root + I*acosh(root)/root)

    @classmethod
    def _expr_big_minus(cls, x, n):
        root = sqrt(x)
        return Integer(-1)**n*(asinh(root)/root + n*pi*I/root)
885
886
class HyperRep_asin2(HyperRep):
    """Represent hyper([1, 1], [3/2], z) == asin(sqrt(z))/sqrt(z)/sqrt(1-z)."""

    # Every form is the corresponding HyperRep_asin1 form divided by the
    # matching HyperRep_power1 form with exponent 1/2.
    # TODO this can be nicer
    @classmethod
    def _expr_small(cls, x):
        numer = HyperRep_asin1._expr_small(x)
        denom = HyperRep_power1._expr_small(Rational(1, 2), x)
        return numer/denom

    @classmethod
    def _expr_small_minus(cls, x):
        numer = HyperRep_asin1._expr_small_minus(x)
        denom = HyperRep_power1._expr_small_minus(Rational(1, 2), x)
        return numer/denom

    @classmethod
    def _expr_big(cls, x, n):
        numer = HyperRep_asin1._expr_big(x, n)
        denom = HyperRep_power1._expr_big(Rational(1, 2), x, n)
        return numer/denom

    @classmethod
    def _expr_big_minus(cls, x, n):
        numer = HyperRep_asin1._expr_big_minus(x, n)
        denom = HyperRep_power1._expr_big_minus(Rational(1, 2), x, n)
        return numer/denom
910
911
class HyperRep_sqrts1(HyperRep):
    """Represent hyper([-a, 1/2 - a], [1/2], z).

    The small-argument form is ((1 - sqrt(z))**(2*a) + (1 + sqrt(z))**(2*a))/2;
    the other hooks give the continuations for argument -z and for other
    branch indices ``n``.

    """

    @classmethod
    def _expr_small(cls, a, z):
        root = sqrt(z)
        return ((1 - root)**(2*a) + (1 + root)**(2*a))/2

    @classmethod
    def _expr_small_minus(cls, a, z):
        return (1 + z)**a*cos(2*a*atan(sqrt(z)))

    @classmethod
    def _expr_big(cls, a, z, n):
        root = sqrt(z)
        # Odd (or parity-unknown) branch index: work with m = n - 1.
        if not n.is_even:
            m = n - 1
            return ((root - 1)**(2*a)*exp(2*pi*I*a*(m + 1)) +
                    (root + 1)**(2*a)*exp(2*pi*I*a*m))/2
        return ((root + 1)**(2*a)*exp(2*pi*I*n*a) +
                (root - 1)**(2*a)*exp(2*pi*I*(n - 1)*a))/2

    @classmethod
    def _expr_big_minus(cls, a, z, n):
        angle = 2*a*atan(sqrt(z))
        # Odd (or parity-unknown) branch indices shift the cosine argument.
        if not n.is_even:
            angle = angle - 2*pi*a
        return (1 + z)**a*exp(2*pi*I*n*a)*cos(angle)
939
940
class HyperRep_sqrts2(HyperRep):
    """Represent
    sqrt(z)/2*[(1-sqrt(z))**2a - (1 + sqrt(z))**2a]
    == -2*z/(2*a+1) d/dz hyper([-a - 1/2, -a], [1/2], z)

    All four representation hooks are classmethods, like those of the other
    ``HyperRep_*`` classes, so they can also be called on the class itself.

    """

    @classmethod
    def _expr_small(cls, a, z):
        return sqrt(z)*((1 - sqrt(z))**(2*a) - (1 + sqrt(z))**(2*a))/2

    @classmethod
    def _expr_small_minus(cls, a, z):
        # With sqrt(-z) = I*sqrt(z), write 1 + I*sqrt(z) in polar form
        # (modulus sqrt(1 + z), angle atan(sqrt(z))), giving a sine form.
        return sqrt(z)*(1 + z)**a*sin(2*a*atan(sqrt(z)))

    @classmethod
    def _expr_big(cls, a, z, n):
        if n.is_even:
            return sqrt(z)/2*((sqrt(z) - 1)**(2*a)*exp(2*pi*I*a*(n - 1)) -
                              (sqrt(z) + 1)**(2*a)*exp(2*pi*I*a*n))
        else:
            # Odd (or parity-unknown) branch index: rewrite in terms of n - 1.
            n -= 1
            return sqrt(z)/2*((sqrt(z) - 1)**(2*a)*exp(2*pi*I*a*(n + 1)) -
                              (sqrt(z) + 1)**(2*a)*exp(2*pi*I*a*n))

    @classmethod
    def _expr_big_minus(cls, a, z, n):
        # Previously an instance method; now a classmethod like its siblings.
        if n.is_even:
            return (1 + z)**a*exp(2*pi*I*n*a)*sqrt(z)*sin(2*a*atan(sqrt(z)))
        else:
            return ((1 + z)**a*exp(2*pi*I*n*a)*sqrt(z)
                    * sin(2*a*atan(sqrt(z)) - 2*pi*a))
972
973
class HyperRep_log2(HyperRep):
    """Represent log(1/2 + sqrt(1 - z)/2) == -z/4*hyper([3/2, 1, 1], [2, 2], z).

    All four representation hooks are classmethods, like those of the other
    ``HyperRep_*`` classes, so they can also be called on the class itself.

    """

    @classmethod
    def _expr_small(cls, x):
        return log(Rational(1, 2) + sqrt(1 - x)/2)

    @classmethod
    def _expr_small_minus(cls, x):
        return log(Rational(1, 2) + sqrt(1 + x)/2)

    @classmethod
    def _expr_big(cls, x, n):
        # Even and odd branch indices differ only in the sign of the
        # I*asin(1/sqrt(x)) term.
        if n.is_even:
            return (n - Rational(1, 2))*pi*I + log(sqrt(x)/2) + I*asin(1/sqrt(x))
        else:
            return (n - Rational(1, 2))*pi*I + log(sqrt(x)/2) - I*asin(1/sqrt(x))

    @classmethod
    def _expr_big_minus(cls, x, n):
        # Previously an instance method; now a classmethod like its siblings.
        if n.is_even:
            return pi*I*n + log(sqrt(1 + x)/2 + Rational(1, 2))
        else:
            return pi*I*n + log(sqrt(1 + x)/2 - Rational(1, 2))
997
998
class HyperRep_cosasin(HyperRep):
    """Represent hyper([a, -a], [1/2], z) == cos(2*a*asin(sqrt(z)))."""

    # Other equivalent closed forms exist (for instance, powers of a sum of
    # square roots); the cos/cosh forms below are the ones chosen here.

    @classmethod
    def _expr_small(cls, a, z):
        angle = 2*a*asin(sqrt(z))
        return cos(angle)

    @classmethod
    def _expr_small_minus(cls, a, z):
        angle = 2*a*asinh(sqrt(z))
        return cosh(angle)

    @classmethod
    def _expr_big(cls, a, z, n):
        # The branch index n contributes the additive shift a*pi*I*(2*n - 1).
        branch_shift = a*pi*I*(2*n - 1)
        return cosh(2*a*acosh(sqrt(z)) + branch_shift)

    @classmethod
    def _expr_big_minus(cls, a, z, n):
        branch_shift = 2*a*pi*I*n
        return cosh(2*a*asinh(sqrt(z)) + branch_shift)
1020
1021
class HyperRep_sinasin(HyperRep):
    """Represent 2*a*z*hyper([1 - a, 1 + a], [3/2], z)
    == sqrt(z)/sqrt(1-z)*sin(2*a*asin(sqrt(z)))

    The argument -z is handled with sqrt(-z) = I*sqrt(z), which turns
    asin/sin into asinh/sinh.

    """

    @classmethod
    def _expr_small(cls, a, z):
        root = sqrt(z)
        return root/sqrt(1 - z)*sin(2*a*asin(root))

    @classmethod
    def _expr_small_minus(cls, a, z):
        root = sqrt(z)
        return -root/sqrt(1 + z)*sinh(2*a*asinh(root))

    @classmethod
    def _expr_big(cls, a, z, n):
        prefactor = -1/sqrt(1 - 1/z)
        return prefactor*sinh(2*a*acosh(sqrt(z)) + a*pi*I*(2*n - 1))

    @classmethod
    def _expr_big_minus(cls, a, z, n):
        prefactor = -1/sqrt(1 + 1/z)
        return prefactor*sinh(2*a*asinh(sqrt(z)) + 2*a*pi*I*n)
1043