1"""
2The Risch Algorithm for transcendental function integration.
3
4The core algorithms for the Risch algorithm are here.  The subproblem
5algorithms are in the rde.py and prde.py files for the Risch
6Differential Equation solver and the parametric problems solvers,
7respectively.  All important information concerning the differential extension
8for an integrand is stored in a DifferentialExtension object, which in the code
is usually called DE.  Throughout the code and inside the DifferentialExtension
object, the conventions/attribute names are that the base domain is QQ and each
differential extension is x, t0, t1, ..., tn-1 = DE.t. DE.x is the variable of
integration (Dx == 1), DE.D is a list of the derivatives of
x, t0, t1, ..., tn-1 = t, DE.T is the list [x, t0, t1, ..., tn-1], DE.t is the
14outer-most variable of the differential extension at the given level (the level
15can be adjusted using DE.increment_level() and DE.decrement_level()),
16k is the field C(x, t0, ..., tn-2), where C is the constant field.  The
17numerator of a fraction is denoted by a and the denominator by
18d.  If the fraction is named f, fa == numer(f) and fd == denom(f).
19Fractions are returned as tuples (fa, fd).  DE.d and DE.t are used to
20represent the topmost derivation and extension variable, respectively.
21The docstring of a function signifies whether an argument is in k[t], in
22which case it will just return a Poly in t, or in k(t), in which case it
23will return the fraction (fa, fd). Other variable names probably come
24from the names used in Bronstein's book.
25"""
26
27import functools
28import math
29
30from ..abc import z
31from ..core import Dummy, E, Eq, Integer, Lambda, Mul, Pow, Symbol, oo, sympify
32from ..functions import (Piecewise, acos, acot, asin, atan, cos, cosh, cot,
33                         coth, exp, log, sin, sinh, tan, tanh)
34from ..polys import (DomainError, Poly, PolynomialError, RootSum, cancel, gcd,
35                     real_roots, reduced)
36from ..utilities import default_sort_key, numbered_symbols, ordered
37from .heurisch import _symbols
38from .integrals import Integral, integrate
39
40
def integer_powers(exprs):
    """
    Group expressions that are rational multiples of one another.

    Each group is described by a base term together with, for every member
    of the group, the integer by which the base term must be multiplied to
    recover that member.  For example, [x, x/2, x**2 + 1, 2*x/3] can be
    rewritten as [(x/6) * 6, (x/6) * 3, (x**2 + 1) * 1, (x/6) * 4].  The Risch
    algorithm needs this to write exp(x) + exp(x/2) as
    (exp(x/2))**2 + exp(x/2) rather than as exp(x) + sqrt(exp(x)): only the
    transcendental case is implemented, so algebraic extensions cannot be
    integrated.  The integer multiples of a group are the smallest possible
    (their content equals 1).

    Returns a list of tuples ``(base, members)``, sorted by the sort key of
    ``base``, where ``members`` is a list of ``(item, factor)`` pairs such
    that ``item == base*factor`` with ``factor`` an integer.

    An example makes this concrete:

    >>> integer_powers([x, x/2, x**2 + 1, 2*x/3])
    [(x/6, [(x, 6), (x/2, 3), (2*x/3, 4)]), (x**2 + 1, [(x**2 + 1, 1)])]

    Here x, x/2 and 2*x/3 are all rational multiples of x, and x/6 is the
    base term of which each of them is an integer multiple (6, 3 and 4
    respectively).  Nothing else is a rational multiple of x**2 + 1, so it
    forms a group on its own, as (x**2 + 1) * 1.

    """
    # Pass 1: bucket the expressions.  An expression joins the first bucket
    # whose representative it is a rational multiple of (tested with
    # cancel(expr/representative).is_Rational); otherwise it starts a new
    # bucket with itself as representative and multiplier 1.
    buckets = {}
    for expr in exprs:
        for representative, members in buckets.items():
            ratio = cancel(expr/representative)
            if ratio.is_Rational:
                members.append((expr, ratio))
                break
        else:
            buckets[expr] = [(expr, Integer(1))]

    # Pass 2: clear denominators.  Dividing the representative by the lcm of
    # the multipliers' denominators turns every multiplier into an integer.
    rescaled = {}
    for representative, members in buckets.items():
        denominators = [ratio.as_numer_denom()[1] for _, ratio in members]
        scale = functools.reduce(math.lcm, denominators)
        rescaled[representative/scale] = [(expr, ratio*scale)
                                          for expr, ratio in members]

    return sorted(rescaled.items(), key=lambda pair: pair[0].sort_key())
101
102
103class DifferentialExtension:
104    """
105    A container for all the information relating to a differential extension.
106
107    The attributes of this object are (see also the docstring of __init__):
108
109    - f: The original (Expr) integrand.
110    - x: The variable of integration.
111    - T: List of variables in the extension.
112    - D: List of derivations in the extension; corresponds to the elements of T.
113    - fa: Poly of the numerator of the integrand.
114    - fd: Poly of the denominator of the integrand.
115    - Tfuncs: Lambda() representations of each element of T (except for x).
116      For back-substitution after integration.
117    - backsubs: A (possibly empty) list of further substitutions to be made on
118      the final integral to make it look more like the integrand.
119    - E_K: List of the positions of the exponential extensions in T.
120    - E_args: The arguments of each of the exponentials in E_K.
121    - L_K: List of the positions of the logarithmic extensions in T.
122    - L_args: The arguments of each of the logarithms in L_K.
123    (See the docstrings of is_deriv_k() and is_log_deriv_k_t_radical() for
124    more information on E_K, E_args, L_K, and L_args)
125    - cases: List of string representations of the cases of T.
126    - t: The top level extension variable, as defined by the current level
127      (see level below).
    - d: The top level extension derivation, as defined by the current
      level (see level below).
130    - case: The string representation of the case of self.d.
131    (Note that self.T and self.D will always contain the complete extension,
132    regardless of the level.  Therefore, you should ALWAYS use DE.t and DE.d
133    instead of DE.T[-1] and DE.D[-1].  If you want to have a list of the
134    derivations or variables only up to the current level, use
135    DE.D[:len(DE.D) + DE.level + 1] and DE.T[:len(DE.T) + DE.level + 1].  Note
136    that, in particular, the derivation() function does this.)
137
138    The following are also attributes, but will probably not be useful other
139    than in internal use:
140    - newf: Expr form of fa/fd.
141    - level: The number (between -1 and -len(self.T)) such that
142      self.T[self.level] == self.t and self.D[self.level] == self.d.
143      Use the methods self.increment_level() and self.decrement_level() to change
144      the current level.
145
146    """
147
    def __init__(self, f=None, x=None, handle_first='log', dummy=True, extension=None, rewrite_complex=False):
        """
        Tries to build a transcendental extension tower from f with respect to x.

        If it is successful, creates a DifferentialExtension object with, among
        others, the attributes fa, fd, D, T, Tfuncs, and backsubs such that
        fa and fd are Polys in T[-1] with rational coefficients in T[:-1],
        fa/fd == f, and D[i] is a Poly in T[i] with rational coefficients in
        T[:i] representing the derivative of T[i] for each i from 1 to len(T).
        Tfuncs is a list of Lambda objects for back replacing the functions
        after integrating.  Lambda() is only used (instead of lambda) to make
        them easier to test and debug. Note that Tfuncs corresponds to the
        elements of T, except for T[0] == x, but they should be back-substituted
        in reverse order.  backsubs is a (possibly empty) back-substitution list
        that should be applied on the completed integral to make it look more
        like the original integrand.

        If it is unsuccessful, it raises NotImplementedError.

        You can also create an object by manually setting the attributes as a
        dictionary to the extension keyword argument.  You must include at least
        D.  Warning, any attribute that is not given will be set to None. The
        attributes T, t, d, cases, case, x, and level are set automatically and
        do not need to be given.  The functions in the Risch Algorithm will NOT
        check to see if an attribute is None before using it.  This also does not
        check to see if the extension is valid (non-algebraic) or even if it is
        self-consistent.  Therefore, this should only be used for
        testing/debugging purposes.

        Parameters
        ==========

        f, x : the integrand and the variable of integration.  Both are
            required unless ``extension`` is given.
        handle_first : ``'log'`` or ``'exp'``.  The chosen kind of extension
            is attempted on every pass of the construction loop; the other
            kind is attempted only when the most recent attempt of the chosen
            kind reported no new extension.
        dummy : forwarded to reset(), _exp_part() and _log_part(); when true
            the generated symbols are Dummy instances.
        extension : a dict of attribute names to values that is copied onto
            the instance verbatim; must contain the key ``'D'``.
        rewrite_complex : when true, sin, cos, cot, tan, sinh, cosh, coth and
            tanh are rewritten in terms of exp, and asin, acos, acot and atan
            in terms of log, before the tower is built.  When false, an
            integrand containing sin, cos, tan, atan, asin or acos of
            something depending on x is rejected.

        Raises
        ======

        ValueError : if ``extension`` lacks ``'D'``, if neither an extension
            nor both f and x are given, or if ``handle_first`` is invalid.
        NotImplementedError : if no elementary transcendental extension could
            be found, if an algebraic extension would be needed, or if
            trigonometric functions are present and ``rewrite_complex`` is
            false.

        """
        # XXX: If you need to debug this function, set the break point here

        # Manual mode: copy the supplied attributes and derive the rest.
        if extension:
            if 'D' not in extension:
                raise ValueError('At least the key D must be included with '
                                 'the extension flag to DifferentialExtension.')
            for attr in extension:
                setattr(self, attr, extension[attr])

            self._auto_attrs()

            return
        elif f is None or x is None:
            raise ValueError('Either both f and x or a manual extension must '
                             'be given.')

        from .prde import is_deriv_k

        if handle_first not in ['log', 'exp']:
            raise ValueError(f"handle_first must be 'log' or 'exp', not {handle_first!s}.")

        # f will be the original function, self.f might change if we reset
        # (e.g., we pull out a constant from an exponential)
        self.f = f
        self.x = x
        self.reset(dummy=dummy)
        # Both flags start out True so that the first pass of the loop below
        # cannot hit the "no new extension" give-up test.
        exp_new_extension, log_new_extension = True, True
        if rewrite_complex:
            rewritables = {
                (sin, cos, cot, tan, sinh, cosh, coth, tanh): exp,
                (asin, acos, acot, atan): log,
            }
            # rewrite the trigonometric components
            for candidates, rule in rewritables.items():
                self.newf = self.newf.rewrite(candidates, rule)
        else:
            if any(i.has(x) for i in self.f.atoms(sin, cos, tan, atan, asin, acos)):
                raise NotImplementedError('Trigonometric extensions are not '
                                          'supported (yet!)')

        def update(seq, atoms, func):
            # Keep the members of seq that still occur in atoms, and add the
            # atoms that were not in seq and satisfy func.  The result is a
            # list built from a set, so its order is unspecified.
            s = set(seq)
            new = atoms - s
            s = atoms.intersection(s)
            s.update(list(filter(func, new)))
            return list(s)

        exps = set()
        pows = set()
        numpows = set()
        sympows = set()
        logs = set()
        symlogs = set()

        # Each pass tries to adjoin more monomials until newf is a rational
        # function of the variables collected in self.T.
        while True:
            if self.newf.is_rational_function(*self.T):
                break

            if not exp_new_extension and not log_new_extension:
                # We couldn't find a new extension on the last pass, so I guess
                # we can't do it.
                raise NotImplementedError("Couldn't find an elementary "
                                          f'transcendental extension for {f!s}.  Try using a '
                                          'manual extension with the extension flag.')

            # Pre-preparsing.
            #################
            # Get all exp arguments, so we can avoid ahead of time doing
            # something like t1 = exp(x), t2 = exp(x/2) == sqrt(t1).

            # Things like sqrt(exp(x)) do not automatically simplify to
            # exp(x/2), so they will be viewed as algebraic.  The easiest way
            # to handle this is to convert all instances of (a**b)**Rational
            # to a**(Rational*b) before doing anything else.  Note that the
            # _exp_part code can generate terms of this form, so we do need to
            # do this at each pass (or else modify it to not do that).

            ratpows = [i for i in self.newf.atoms(Pow)
                       if (i.base.is_Pow and i.exp.is_Rational)]

            ratpows_repl = [
                (i, i.base.base**(i.exp*i.base.exp)) for i in ratpows]
            # Record the inverse replacement so the result can be put back in
            # the form of the integrand.
            self.backsubs += [(j, i) for i, j in ratpows_repl]
            self.newf = self.newf.xreplace(dict(ratpows_repl))

            # To make the process deterministic, the args are sorted
            # so that functions with smaller op-counts are processed first.
            # Ties are broken with the default_sort_key.

            # XXX Although the method is deterministic no additional work
            # has been done to guarantee that the simplest solution is
            # returned and that it would be affected by using different
            # variables. Though it is possible that this is the case
            # one should know that it has not been done intentionally so
            # further improvements may be possible.

            # TODO: This probably doesn't need to be completely recomputed at
            # each pass.
            # exps: base-E powers; pows: all other powers; in both cases the
            # exponent must be rational in, and depend on, the current T.
            exps = update(exps, {a for a in self.newf.atoms(Pow) if a.base is E},
                          lambda i: i.exp.is_rational_function(*self.T) and
                          i.exp.has(*self.T))
            pows = update(pows, {a for a in self.newf.atoms(Pow) if a.base is not E},
                          lambda i: i.exp.is_rational_function(*self.T) and
                          i.exp.has(*self.T))
            # numpows: powers whose base is free of T; sympows: the remaining
            # powers with a base rational in T and a non-Integer exponent.
            numpows = update(numpows, set(pows),
                             lambda i: not i.base.has(*self.T))
            sympows = update(sympows, set(pows) - set(numpows),
                             lambda i: i.base.is_rational_function(*self.T) and
                             not i.exp.is_Integer)

            # The easiest way to deal with non-base E powers is to convert them
            # into base E, integrate, and then convert back.
            for i in ordered(pows):
                old = i
                new = exp(i.exp*log(i.base))
                # If exp is ever changed to automatically reduce exp(x*log(2))
                # to 2**x, then this will break.  The solution is to not change
                # exp to do that :)
                if i in sympows:
                    if i.exp.is_Rational:
                        raise NotImplementedError('Algebraic extensions are '
                                                  f'not supported ({i!s}).')
                    # We can add a**b only if log(a) in the extension, because
                    # a**b == exp(b*log(a)).
                    basea, based = frac_in(i.base, self.t)
                    A = is_deriv_k(basea, based, self)
                    if A is None:
                        # Nonelementary monomial (so far)

                        # TODO: Would there ever be any benefit from just
                        # adding log(base) as a new monomial?
                        # ANSWER: Yes, otherwise we can't integrate x**x (or
                        # rather prove that it has no elementary integral)
                        # without first manually rewriting it as exp(x*log(x))
                        self.newf = self.newf.xreplace({old: new})
                        self.backsubs += [(new, old)]
                        log_new_extension = self._log_part([log(i.base)],
                                                           dummy=dummy)
                        # newf and T have changed, so refresh the exponentials.
                        exps = update(exps, {a for a in self.newf.atoms(Pow)
                                             if a.base is E},
                                      lambda i: (i.exp.is_rational_function(*self.T) and
                                                 i.exp.has(*self.T)))
                        continue
                    _, u, const = A
                    newterm = exp(i.exp*(log(const) + u))
                    # Under the current implementation, exp kills terms
                    # only if they are of the form a*log(x), where a is a
                    # Number.  This case should have already been killed by the
                    # above tests.  Again, if this changes to kill more than
                    # that, this will break, which maybe is a sign that you
                    # shouldn't be changing that.  Actually, if anything, this
                    # auto-simplification should be removed.  See
                    # http://groups.google.com/group/sympy/browse_thread/thread/a61d48235f16867f

                    self.newf = self.newf.xreplace({i: newterm})

                elif i not in numpows:
                    continue
                else:
                    # i in numpows
                    newterm = new
                # TODO: Just put it in self.Tfuncs
                self.backsubs.append((new, old))
                self.newf = self.newf.xreplace({old: newterm})
                exps.append(newterm)

            atoms = self.newf.atoms(log)
            # logs: logarithms whose argument is rational in, and depends on,
            # the current T.  symlogs: logarithms of a non-base-E power with a
            # non-Integer exponent and a base rational in T.
            logs = update(logs, atoms,
                          lambda i: i.args[0].is_rational_function(*self.T) and
                          i.args[0].has(*self.T))
            symlogs = update(symlogs, atoms,
                             lambda i: i.has(*self.T) and i.args[0].is_Pow and
                             i.args[0].base.is_rational_function(*self.T) and
                             not i.args[0].base is E and
                             not i.args[0].exp.is_Integer)

            # We can handle things like log(x**y) by converting it to y*log(x)
            # This will fix not only symbolic exponents of the argument, but any
            # non-Integer exponent, like log(sqrt(x)).  The exponent can also
            # depend on x, like log(x**x).
            for i in ordered(symlogs):
                # Unlike in the exponential case above, we do not ever
                # potentially add new monomials (above we had to add log(a)).
                # Therefore, there is no need to run any is_deriv functions
                # here.  Just convert log(a**b) to b*log(a) and let
                # log_new_extension() handle it from there.
                lbase = log(i.args[0].base)
                logs.append(lbase)
                new = i.args[0].exp*lbase
                self.newf = self.newf.xreplace({i: new})
                self.backsubs.append((new, i))

            # remove any duplicates
            logs = sorted(set(logs), key=default_sort_key)

            if handle_first == 'exp' or not log_new_extension:
                exp_new_extension = self._exp_part(exps, dummy=dummy)
                if exp_new_extension is None:
                    # reset and restart
                    # (the rewritten integrand becomes the new starting point)
                    self.f = self.newf
                    self.reset(dummy=dummy)
                    exp_new_extension = True
                    continue

            if handle_first == 'log' or not exp_new_extension:
                log_new_extension = self._log_part(logs, dummy=dummy)

        # newf is now rational in T: split it into numerator and denominator
        # Polys in the top variable and fill in the derived attributes.
        self.fa, self.fd = frac_in(self.newf, self.t)
        self._auto_attrs()
387
388    def __getattr__(self, attr):
389        # Avoid AttributeErrors when debugging
390        if attr not in ('f', 'x', 'T', 'D', 'fa', 'fd', 'Tfuncs', 'backsubs',
391                        'E_K', 'E_args', 'L_K', 'L_args', 'cases', 'case', 't',
392                        'd', 'newf', 'level', 'ts'):
393            raise AttributeError(f'{self!r} has no attribute {attr!r}')
394
395    def _auto_attrs(self):
396        """Set attributes that are generated automatically."""
397        if not self.T:
398            # i.e., when using the extension flag and T isn't given
399            self.T = [i.gen for i in self.D]
400        if not self.x:
401            self.x = self.T[0]
402        self.cases = [get_case(d, t) for d, t in zip(self.D, self.T)]
403        self.level = -1
404        self.t = self.T[self.level]
405        self.d = self.D[self.level]
406        self.case = self.cases[self.level]
407
    def _exp_part(self, exps, dummy=True):
        """
        Try to build an exponential extension.

        Returns True if there was a new extension, False if there was no new
        extension but it was able to rewrite the given exponentials in terms
        of the existing extension, and None if the entire extension building
        process should be restarted.  If the process fails because there is no
        way around an algebraic extension (e.g., exp(log(x)/2)), it will raise
        NotImplementedError.

        ``exps`` is an iterable of powers; only their ``.exp`` attribute (the
        argument of the exponential) is used.  ``dummy`` selects whether the
        bound variable of each new Tfuncs Lambda is a Dummy or a plain Symbol.

        As side effects this rewrites self.newf and, for each new monomial,
        updates self.t and appends to self.T, self.D, self.E_args, self.E_K
        and self.Tfuncs.

        """
        from .prde import is_log_deriv_k_t_radical

        new_extension = False
        restart = False
        expargs = [i.exp for i in exps]
        # Group arguments that are rational multiples of each other, so that
        # each group needs at most one new monomial.
        ip = integer_powers(expargs)
        for arg, others in ip:
            # Minimize potential problems with algebraic substitution
            others.sort(key=lambda i: i[1])

            arga, argd = frac_in(arg, self.t)
            # A is None when no relation with the existing tower was found.
            A = is_log_deriv_k_t_radical(arga, argd, self)

            if A is not None:
                ans, u, n, const = A
                # if n is 1 or -1, it's algebraic, but we can handle it
                if n == -1:
                    # This probably will never happen, because
                    # Rational.as_numer_denom() returns the negative term in
                    # the numerator.  But in case that changes, reduce it to
                    # n == 1.
                    n = 1
                    u **= -1
                    const *= -1
                    ans = [(i, -j) for i, j in ans]

                if n == 1:
                    # Example: exp(x + x**2) over QQ(x, exp(x), exp(x**2))
                    self.newf = self.newf.xreplace({exp(arg): exp(const)*Mul(*[
                        u**power for u, power in ans])})
                    self.newf = self.newf.xreplace({exp(p*exparg):
                                                    exp(const*p) * Mul(*[u**power for u, power in ans])
                                                    for exparg, p in others})
                    # TODO: Add something to backsubs to put exp(const*p)
                    # back together.

                    continue

                else:
                    # Bad news: we have an algebraic radical.  But maybe we
                    # could still avoid it by choosing a different extension.
                    # For example, integer_powers() won't handle exp(x/2 + 1)
                    # over QQ(x, exp(x)), but if we pull out the exp(1), it
                    # will.  Or maybe we have exp(x + x**2/2), over
                    # QQ(x, exp(x), exp(x**2)), which is exp(x)*sqrt(exp(x**2)),
                    # but if we use QQ(x, exp(x), exp(x**2/2)), then they will
                    # all work.
                    #
                    # So here is what we do: If there is a non-zero const, pull
                    # it out and retry.  Also, if len(ans) > 1, then rewrite
                    # exp(arg) as the product of exponentials from ans, and
                    # retry that.  If const == 0 and len(ans) == 1, then we
                    # assume that it would have been handled by either
                    # integer_powers() or n == 1 above if it could be handled,
                    # so we give up at that point.  For example, you can never
                    # handle exp(log(x)/2) because it equals sqrt(x).

                    if const or len(ans) > 1:
                        rad = Mul(*[term**(power/n) for term, power in ans])
                        self.newf = self.newf.xreplace({exp(p*exparg):
                                                        exp(const*p)*rad for exparg, p in others})
                        # Substitute the monomials back by their defining
                        # functions of x (the zip stops at the shorter
                        # sequence, so x itself is untouched) before the
                        # caller restarts the construction.
                        self.newf = self.newf.xreplace(dict(zip(reversed(self.T),
                                                                reversed([f(self.x) for f in self.Tfuncs]))))
                        restart = True
                        break
                    else:
                        # TODO: give algebraic dependence in error string
                        raise NotImplementedError('Cannot integrate over '
                                                  'algebraic extensions.')

            else:
                # No relation found: adjoin a new monomial t standing for
                # exp(arg), with derivative D(arg)*t.
                arga, argd = frac_in(arg, self.t)
                # Quotient rule: D(arga/argd) == darga/dargd.
                darga = (argd*derivation(Poly(arga, self.t), self) -
                         arga*derivation(Poly(argd, self.t), self))
                dargd = argd**2
                darga, dargd = darga.cancel(dargd, include=True)
                darg = darga.as_expr()/dargd.as_expr()
                self.t = next(self.ts)
                self.T.append(self.t)
                self.E_args.append(arg)
                self.E_K.append(len(self.T) - 1)
                self.D.append(darg.as_poly(self.t, expand=False)*Poly(self.t,
                                                                      self.t, expand=False))
                if dummy:
                    i = Dummy('i')
                else:
                    i = Symbol('i')
                self.Tfuncs = self.Tfuncs + [Lambda(i, exp(arg.subs({self.x: i})))]
                # Every exponential of the group becomes an integer power of
                # the new monomial.
                self.newf = self.newf.xreplace(
                    {exp(exparg): self.t**p for exparg, p in others})
                new_extension = True

        # When a restart was requested, fall through and return None.
        if not restart:
            return new_extension
514
515    def _log_part(self, logs, dummy=True):
516        """
517        Try to build a logarithmic extension.
518
519        Returns True if there was a new extension and False if there was no new
520        extension but it was able to rewrite the given logarithms in terms
521        of the existing extension.  Unlike with exponential extensions, there
522        is no way that a logarithm is not transcendental over and cannot be
523        rewritten in terms of an already existing extension in a non-algebraic
524        way, so this function does not ever return None or raise
525        NotImplementedError.
526
527        """
528        from .prde import is_deriv_k
529
530        new_extension = False
531        logargs = [i.args[0] for i in logs]
532        for arg in ordered(logargs):
533            # The log case is easier, because whenever a logarithm is algebraic
534            # over the base field, it is of the form a1*t1 + ... an*tn + c,
535            # which is a polynomial, so we can just replace it with that.
536            # In other words, we don't have to worry about radicals.
537            arga, argd = frac_in(arg, self.t)
538            A = is_deriv_k(arga, argd, self)
539            if A is not None:
540                _, u, const = A
541                newterm = log(const) + u
542                self.newf = self.newf.xreplace({log(arg): newterm})
543                continue
544
545            else:
546                arga, argd = frac_in(arg, self.t)
547                darga = (argd*derivation(Poly(arga, self.t), self) -
548                         arga*derivation(Poly(argd, self.t), self))
549                dargd = argd**2
550                darg = darga.as_expr()/dargd.as_expr()
551                self.t = next(self.ts)
552                self.T.append(self.t)
553                self.L_args.append(arg)
554                self.L_K.append(len(self.T) - 1)
555                self.D.append(cancel(darg.as_expr()/arg).as_poly(self.t,
556                                                                 expand=False))
557                if dummy:
558                    i = Dummy('i')
559                else:
560                    i = Symbol('i')
561                self.Tfuncs = self.Tfuncs + [Lambda(i, log(arg.subs({self.x: i})))]
562                self.newf = self.newf.xreplace({log(arg): self.t})
563                new_extension = True
564
565        return new_extension
566
567    @property
568    def _important_attrs(self):
569        """
570        Returns some of the more important attributes of self.
571
572        Used for testing and debugging purposes.
573
574        The attributes are (fa, fd, D, T, Tfuncs, backsubs, E_K, E_args,
575        L_K, L_args).
576
577        """
578        # XXX: This might be easier to read as a dict or something
579        # Maybe a named tuple.
580        return (self.fa, self.fd, self.D, self.T, self.Tfuncs,
581                self.backsubs, self.E_K, self.E_args, self.L_K, self.L_args)
582
583    # TODO: Implement __repr__
584
585    def __str__(self):
586        return str(self._important_attrs)
587
588    def reset(self, dummy=True):
589        """Reset self to an initial state.  Used by __init__."""
590        self.t = self.x
591        self.T = [self.x]
592        self.D = [Poly(1, self.x)]
593        self.level = -1
594        self.L_K, self.E_K, self.L_args, self.E_args = [], [], [], []
595        if dummy:
596            self.ts = numbered_symbols('t', cls=Dummy)
597        else:
598            # For testing
599            self.ts = numbered_symbols('t')
600        # For various things that we change to make things work that we need to
601        # change back when we are done.
602        self.backsubs = []
603        self.Tfuncs = []
604        self.newf = self.f
605
606    def increment_level(self):
607        """
608        Increment the level of self.
609
610        This makes the working differential extension larger.  self.level is
611        given relative to the end of the list (-1, -2, etc.), so we don't need
612        do worry about it when building the extension.
613
614        """
615        if self.level >= -1:
616            raise ValueError('The level of the differential extension cannot '
617                             'be incremented any further.')
618
619        self.level += 1
620        self.t = self.T[self.level]
621        self.d = self.D[self.level]
622        self.case = self.cases[self.level]
623
624    def decrement_level(self):
625        """
626        Decrease the level of self.
627
628        This makes the working differential extension smaller.  self.level is
629        given relative to the end of the list (-1, -2, etc.), so we don't need
630        do worry about it when building the extension.
631
632        """
633        if self.level <= -len(self.T):
634            raise ValueError('The level of the differential extension cannot '
635                             'be decremented any further.')
636
637        self.level -= 1
638        self.t = self.T[self.level]
639        self.d = self.D[self.level]
640        self.case = self.cases[self.level]
641
642
class DecrementLevel:
    """
    A context manager for decrementing the level of a DifferentialExtension.

    On entry the level of the wrapped extension is decremented; on exit it
    is incremented again, whether or not the body raised.  The wrapped
    extension is returned from ``__enter__``, so it can be bound with
    ``with DecrementLevel(DE) as de:``.
    """

    def __init__(self, DE):
        """Remember the extension DE whose level is to be adjusted."""
        self.DE = DE

    def __enter__(self):
        """Decrement the level of the extension and return the extension."""
        self.DE.decrement_level()
        return self.DE

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the level.  Returns None, so exceptions propagate."""
        self.DE.increment_level()
654
655
class NonElementaryIntegralException(Exception):
    """
    Exception used by subroutines within the Risch algorithm to indicate to one
    another that the function being integrated does not have an elementary
    integral in the given differential field.

    The class adds no state or behaviour of its own to Exception.  Within this
    module it is caught by integrate_primitive_polynomial() and
    integrate_hyperexponential_polynomial(), which turn it into a False
    "elementary" flag in their return value.

    """

    # TODO: Rewrite algorithms below to use this (?)

    # TODO: Pass through information about why the integral was nonelementary,
    # and store that in the resulting NonElementaryIntegral somehow.
668
669
def gcdex_diophantine(a, b, c):
    """
    Extended Euclidean Algorithm, Diophantine version.

    Given a, b in K[x] and c in (a, b), the ideal generated by a and b,
    return (s, t) such that s*a + t*b == c and either s == 0 or s.degree()
    < b.degree().

    """
    # Extended Euclidean Algorithm (Diophantine Version) pg. 13
    # TODO: This should go in densetools.py.
    # XXX: Better name?

    cofactor, divisor = a.half_gcdex(b)
    # Inexact division here means c is not in (a, b).
    scale = c.exquo(divisor)
    cofactor = scale*cofactor

    # Replace the cofactor by its remainder modulo b when its degree is not
    # already below deg(b).
    if not (cofactor.is_zero or cofactor.degree() < b.degree()):
        _, cofactor = cofactor.div(b)

    # The second cofactor follows from s*a + t*b == c.
    return cofactor, (c - cofactor*a).exquo(b)
693
694
def frac_in(f, t, **kwargs):
    """
    Returns the tuple (fa, fd), where fa and fd are Polys in t.

    This is a common idiom in the Risch Algorithm functions, so we abstract
    it out here.  f should be a basic expression, a Poly, or a tuple (fa, fd),
    where fa and fd are either basic expressions or Polys, and f == fa/fd.

    The keyword argument ``cancel`` (default False) is consumed here: when
    true, the resulting pair is passed through ``fa.cancel(fd, include=True)``.
    All other **kwargs are applied to Poly (via ``as_poly``).

    Raises ValueError if the numerator or the denominator cannot be turned
    into a Poly in t (``as_poly`` returned None).

    """
    do_cancel = kwargs.pop('cancel', False)
    if isinstance(f, tuple):
        fa, fd = f
        f = fa.as_expr()/fd.as_expr()
    fa, fd = f.as_expr().as_numer_denom()
    fa, fd = fa.as_poly(t, extension=False, **kwargs), fd.as_poly(t, extension=False, **kwargs)
    # Validate before cancelling: calling .cancel() on a failed conversion
    # (None) would otherwise surface as an unrelated AttributeError.
    if fa is None or fd is None:
        raise ValueError(f'Could not turn {f} into a fraction in {t}.')
    if do_cancel:
        fa, fd = fa.cancel(fd, include=True)
    return fa, fd
716
717
def as_poly_1t(p, t, z):
    """
    (Hackish) way to convert an element p of K[t, 1/t] to K[t, z].

    In other words, z == 1/t will be a dummy variable that Poly can handle
    better.

    The result is a Poly in the generators (t, z).  PolynomialError is raised
    when the cancelled denominator of p is not a single term, and a
    DomainError raised while dividing by that denominator is re-raised as
    NotImplementedError.

    See issue sympy/sympy#5131.

    Examples
    ========

    >>> p1 = random_poly(x, 10, -10, 10)
    >>> p2 = random_poly(x, 10, -10, 10)
    >>> p = p1 + p2.subs({x: 1/x})
    >>> as_poly_1t(p, x, z).as_expr().subs({z: 1/x}) == p
    True

    """
    # TODO: Use this on the final result.  That way, we can avoid answers like
    # (...)*exp(-x).
    # Write p as a cancelled fraction pa/pd of Polys in t.
    pa, pd = frac_in(p, t, cancel=True)
    if not pd.is_term:
        # XXX: Is there a better Poly exception that we could raise here?
        # Either way, if you see this (from the Risch Algorithm) it indicates
        # a bug.
        raise PolynomialError(f'{p} is not an element of K[{t}, 1/{t}].')
    d = pd.degree(t)
    # Low-degree slice of the numerator, handled separately below as the
    # "1/t" part.
    # NOTE(review): presumably slice(0, d + 1) keeps the terms of degree
    # 0..d in t -- confirm against Poly.slice().
    one_t_part = pa.slice(0, d + 1)
    r = pd.degree() - pa.degree()
    # Whatever is left of the numerator is divided exactly by pd over a field.
    t_part = pa - one_t_part
    try:
        t_part = t_part.to_field().exquo(pd)
    except DomainError as e:
        # issue sympy/sympy#4950
        raise NotImplementedError(e)
    # Compute the negative degree parts.
    # od is the number of zero coefficients put in front of the reversed
    # coefficient list of one_t_part; it is never negative.
    od = max(-r - one_t_part.degree() if r < 0 and d > 0 else 0, 0)
    one_t_part = Poly([0]*od + list(reversed(one_t_part.rep.all_coeffs())),
                      *one_t_part.gens, domain=one_t_part.domain)
    # When the denominator has larger degree than the numerator, shift by the
    # degree gap r.
    if 0 < r < oo:
        one_t_part *= Poly(t**r, t)

    one_t_part = one_t_part.replace(t, z)  # z will be 1/t
    # Divide by the coefficient of the denominator's t**d term, if nonzero.
    if pd.coeff_monomial((d,)):
        one_t_part *= Poly(1/pd.coeff_monomial((d,)), z, expand=False)
    # Combine both parts as a single Poly in (t, z).
    ans = t_part.as_poly(t, z, expand=False) + one_t_part.as_poly(t, z,
                                                                  expand=False)

    return ans
768
769
def derivation(p, DE, coefficientD=False, basic=False):
    """
    Computes Dp.

    Given the derivation D with D = d/dx and p is a polynomial in t over
    K(x), return Dp.

    If coefficientD is True, it computes the derivation kD
    (kappaD), which is defined as kD(sum(ai*Xi**i, (i, 0, n))) ==
    sum(Dai*Xi**i, (i, 1, n)) (Definition 3.2.2, page 80).  X in this case is
    T[-1], so coefficientD computes the derivative just with respect to T[:-1],
    with T[-1] treated as a constant.

    If basic=True, then returns a Basic expression.  Elements of D can still be
    instances of Poly.

    When coefficientD is True the level of DE is temporarily decremented; it
    is restored before returning, also when an exception propagates out of
    the computation.

    """
    if basic:
        r = 0
    else:
        r = Poly(0, DE.t, field=True)

    # Remember the outer variable: the result is expressed in it even when
    # the level is decremented below.
    t = DE.t
    if coefficientD:
        if DE.level <= -len(DE.T):
            # 'base' case, the answer is 0.
            return r
        DE.decrement_level()

    try:
        # Derivations and variables up to (and including) the current level.
        D = DE.D[:len(DE.D) + DE.level + 1]
        T = DE.T[:len(DE.T) + DE.level + 1]

        # Chain rule: sum over all variables v of Dv * dp/dv.
        for d, v in zip(D, T):
            pv = p.as_poly(v)
            if pv is None or basic:
                pv = p.as_expr()

            if basic:
                r += d.as_expr()*pv.diff(v)
            else:
                r += (d*pv.diff(v)).as_poly(t, field=True)

        if basic:
            r = cancel(r)
    finally:
        # Always undo the temporary level change so that a failure here
        # cannot leave DE pointing at the wrong extension level.
        if coefficientD:
            DE.increment_level()

    return r
818
819
def get_case(d, t):
    """
    Classify the derivation d of the monomial t.

    The result is one of the strings 'exp', 'tan', 'base', 'primitive',
    'other_linear' or 'other_nonlinear'.

    """
    if d.as_expr().has(t):
        # d depends on t: distinguish by divisibility by t and by 1 + t**2,
        # and otherwise by the degree of d in t.
        if d.rem(Poly(t, t)).is_zero:
            return 'exp'
        if d.rem(Poly(1 + t**2, t)).is_zero:
            return 'tan'
        return 'other_nonlinear' if d.degree(t) > 1 else 'other_linear'

    # d is free of t.
    return 'base' if d.is_one else 'primitive'
839
840
def splitfactor(p, DE, coefficientD=False, z=None):
    """
    Splitting factorization.

    Given a derivation D on k[t] and p in k[t], return (p_n, p_s) in
    k[t] x k[t] such that p = p_n*p_s, p_s is special, and each square
    factor of p_n is normal.

    coefficientD is forwarded to derivation().  z, when given, is appended
    to the generators used for the gcd in the case where p is free of DE.t.

    Page. 100

    """
    inverses = [1/gen for gen in DE.T[:DE.level]]
    if z:
        inverses.append(z)

    unit = Poly(1, DE.t, domain=p.domain)
    dp = derivation(p, DE, coefficientD=coefficientD)
    # XXX: Is this right?
    if p.is_zero:
        return p, unit

    if not p.as_expr().has(DE.t):
        # p does not involve t: take the gcd with Dp over the inverse
        # generators instead.
        special = p.as_poly(*inverses).gcd(dp.as_poly(*inverses)).as_poly(DE.t)
        return p.exquo(special), special

    if dp.is_zero:
        return p, unit

    with_dp = p.gcd(dp).to_field()
    with_dt = p.gcd(p.diff(DE.t)).to_field()
    special = with_dp.exquo(with_dt)

    if special.degree(DE.t) == 0:
        return p, unit

    # Split off this special factor and recurse on the cofactor.
    normal, more_special = splitfactor(p.exquo(special), DE,
                                       coefficientD=coefficientD)
    return normal, more_special*special
880
881
def splitfactor_sqf(p, DE, coefficientD=False, z=None, basic=False):
    """
    Splitting Square-free Factorization

    Given a derivation D on k[t] and p in k[t], returns (N1, ..., Nm)
    and (S1, ..., Sm) in k[t]^m such that p =
    (N1*N2**2*...*Nm**m)*(S1*S2**2*...*Sm**m) is a splitting
    factorization of p and the Ni and Si are square-free and coprime.

    Both results are tuples of (factor, multiplicity) pairs; factors equal
    to one are left out.  coefficientD and basic are forwarded to
    derivation().  When z is given it is the only generator used for the
    gcd computations.

    """
    # TODO: This algorithm appears to be faster in every case
    # TODO: Verify this and splitfactor() for multiple extensions
    gens = [1/gen for gen in DE.T[:DE.level]] + DE.T[:DE.level]
    if z:
        gens = [z]

    special_parts = []
    normal_parts = []
    content, factors = p.sqf_list()
    if content != 1:
        # Treat a non-trivial content as one more factor of multiplicity 1.
        factors.insert(0, (Poly(content, DE.t), 1))
    if p.is_zero:
        return ((p, 1),), ()

    for factor, mult in factors:
        factor_k = factor.as_poly(*gens)
        dfactor_k = derivation(factor, DE, coefficientD=coefficientD,
                               basic=basic).as_poly(*gens)
        s_part = factor_k.gcd(dfactor_k).as_poly(DE.t)
        factor = Poly(factor, DE.t)
        s_part = Poly(s_part, DE.t)
        n_part = factor.exquo(s_part)
        if not s_part.is_one:
            special_parts.append((s_part, mult))
        if not n_part.is_one:
            normal_parts.append((n_part, mult))

    return tuple(normal_parts), tuple(special_parts)
918
919
def canonical_representation(a, d, DE):
    """
    Canonical Representation.

    Given a derivation D on k[t] and f = a/d in k(t), return (f_p, f_s,
    f_n) in k[t] x k(t) x k(t) such that f = f_p + f_s + f_n is the
    canonical representation of f (f_p is a polynomial, f_s is reduced
    (has a special denominator), and f_n is simple (has a normal
    denominator).

    f_s and f_n are returned as (numerator, denominator) tuples.

    """
    t = DE.t

    # Make d monic
    monic_scale = Poly(1/d.LC(), t)
    a *= monic_scale
    d *= monic_scale

    poly_part, remainder = a.div(d)
    d_normal, d_special = splitfactor(d, DE)

    # Distribute the proper part over the two coprime denominator factors.
    b, c = gcdex_diophantine(d_normal.as_poly(t), d_special.as_poly(t),
                             remainder.as_poly(t))
    b, c = b.as_poly(t), c.as_poly(t)

    return poly_part, (b, d_special), (c, d_normal)
943
944
def hermite_reduce(a, d, DE):
    """
    Hermite Reduction - Mack's Linear Version.

    Given a derivation D on k(t) and f = a/d in k(t), returns g, h, r in
    k(t) such that f = Dg + h + r, h is simple, and r is reduced.

    Each of g, h and r is returned as a (numerator, denominator) tuple.

    """
    t = DE.t

    # Make d monic
    scale = Poly(1/d.LC(), t)
    a *= scale
    d *= scale

    # Only the part with a normal denominator is reduced below; the
    # polynomial and special parts are folded into r at the end.
    poly_part, (special_num, special_den), (a, d) = \
        canonical_representation(a, d, DE)
    scale = Poly(1/d.LC(), t)
    a *= scale
    d *= scale

    g_num = Poly(0, t)
    g_den = Poly(1, t)

    d_minus = gcd(d, derivation(d, DE)).as_poly(t)
    d_star, _ = d.div(d_minus)

    while d_minus.degree(t) > 0:

        d_minus_prime = derivation(d_minus, DE)
        d_minus_next = gcd(d_minus, d_minus_prime)
        d_minus_star, _ = d_minus.div(d_minus_next)
        coeff, _ = (d_star*d_minus_prime).div(d_minus)

        b, c = gcdex_diophantine(-coeff.as_poly(t), d_minus_star.as_poly(t),
                                 a.as_poly(t))
        b, c = b.as_poly(t), c.as_poly(t)

        db = derivation(b, DE).as_poly(t)
        ratio, _ = d_star.div(d_minus_star)
        a = c.as_poly(t) - (db*ratio).as_poly(t)

        # Accumulate b/d_minus into g and keep g in lowest terms.
        g_num = g_num*d_minus + b*g_den
        g_den = g_den*d_minus
        g_num, g_den = g_num.cancel(g_den, include=True)
        d_minus = d_minus_next

    quotient, rem = a.div(d_star)
    g_num, g_den = g_num.cancel(g_den, include=True)

    h_num, h_den = rem.cancel(d_star, include=True)
    r_num = quotient*special_den + poly_part*special_den + special_num
    r_num, r_den = r_num.cancel(special_den, include=True)

    return (g_num, g_den), (h_num, h_den), (r_num, r_den)
1001
1002
def polynomial_reduce(p, DE):
    """
    Polynomial Reduction.

    Given a derivation D on k(t) and p in k[t] where t is a nonlinear
    monomial over k, return q, r in k[t] such that p = Dq  + r, and
    deg(r) < deg_t(Dt).

    """
    t = DE.t
    q = Poly(0, t)
    while p.degree(t) >= DE.d.degree(t):
        m = p.degree(t) - DE.d.degree(t) + 1
        # Monomial whose derivative cancels the leading term of p.
        power = Poly(t**m, t)
        lead = Poly(p.as_poly(t).LC()/(m*DE.d.LC()), t)
        term = power*lead
        q += term
        p = p - derivation(term, DE)

    return q, p
1021
1022
def laurent_series(a, d, F, n, DE):
    """
    Contribution of F to the full partial fraction decomposition of A/D

    Given a field K of characteristic 0 and A,D,F in K[x] with D monic,
    nonzero, coprime with A, and F the factor of multiplicity n in the square-
    free factorization of D, return the principal parts of the Laurent series of
    A/D at all the zeros of F.

    Returns the tuple (delta_a, delta_d, H_list): a numerator/denominator
    pair of Polys in DE.t plus the list of the intermediate H polynomials.

    NOTE(review): when F has degree 0 the plain integer 0 is returned
    instead of a 3-tuple; recognize_derivative() unpacks the result as a
    sequence, so callers should confirm that case cannot reach them.

    """
    if F.degree() == 0:
        return 0
    # Z = [z, <n more symbols>]; z is the module-level symbol imported from
    # ..abc.
    # NOTE(review): _symbols() is defined outside this chunk; presumably it
    # returns a list of n fresh symbols -- confirm.
    Z = _symbols('z', n)
    Z.insert(0, z)
    delta_a = Poly(0, DE.t)
    delta_d = Poly(1, DE.t)

    # Cofactor of F**n in d.
    E = d.quo(F**n)
    ha, hd = (a, E*Poly(z**n, DE.t))
    dF = derivation(F, DE)
    # B*E == 1 and C*dF == 1 modulo F.
    B, _ = gcdex_diophantine(E, F, Poly(1, DE.t))
    C, _ = gcdex_diophantine(dF, F, Poly(1, DE.t))

    # initialization
    F_store = F
    V, DE_D_list, H_list = [], [], []

    for j in range(n):
        # jth derivative of z would be substituted with dfnth/(j+1) where dfnth =(d^n)f/(dx)^n
        F_store = derivation(F_store, DE)
        v = (F_store.as_expr())/(j + 1)
        V.append(v)
        # Derivation of the auxiliary extension: D(Z[j]) = Z[j + 1].
        DE_D_list.append(Poly(Z[j + 1], Z[j]))

    DE_new = DifferentialExtension(extension={'D': DE_D_list})  # a differential indeterminate
    for j in range(n):
        zEha = Poly(z**(n + j), DE.t)*E**(j + 1)*ha
        zEhd = hd
        # cancel() on a pair is indexed at [1] and [2] for the numerator and
        # denominator; the same call is evaluated twice.
        Pa, Pd = cancel((zEha, zEhd))[1], cancel((zEha, zEhd))[2]
        Q = Pa.quo(Pd)
        # Replace the indeterminates Z[0..j] by the scaled derivatives of F.
        for i in range(j + 1):
            Q = Q.subs({Z[i]: V[i]})
        # Advance ha/hd to the next step, combining the derivation of DE with
        # that of the auxiliary extension DE_new.
        Dha = hd*derivation(ha, DE, basic=True) + ha*derivation(hd, DE, basic=True)
        Dha += hd*derivation(ha, DE_new, basic=True) + ha*derivation(hd, DE_new, basic=True)
        Dhd = Poly(j + 1, DE.t)*hd**2
        ha, hd = Dha, Dhd

        # Part of F that is coprime to Q.
        Ff, _ = F.div(gcd(F, Q))
        F_stara, F_stard = frac_in(Ff, DE.t)
        if F_stara.degree(DE.t) - F_stard.degree(DE.t) > 0:
            QBC = Poly(Q, DE.t)*B**(1 + j)*C**(n + j)
            H = QBC
            # The unreduced polynomial is what gets recorded in H_list.
            H_list.append(H)
            H = (QBC*F_stard).rem(F_stara)
            # Only the real roots of F_stara contribute here.
            alphas = real_roots(F_stara)
            for alpha in list(alphas):
                delta_a = delta_a*Poly((DE.t - alpha)**(n - j), DE.t) + Poly(H.eval(alpha), DE.t)
                delta_d = delta_d*Poly((DE.t - alpha)**(n - j), DE.t)
    return delta_a, delta_d, H_list
1082
1083
def recognize_derivative(a, d, DE, z=None):
    """
    Compute the squarefree factorization of the denominator of f
    and for each Di the polynomial H in K[x] (see Theorem 2.7.1), using the
    LaurentSeries algorithm. Write Di = GiEi where Gj = gcd(Hn, Di) and
    gcd(Ei,Hn) = 1. Since the residues of f at the roots of Gj are all 0, and
    the residue of f at a root alpha of Ei is Hi(a) != 0, f is the derivative of a
    rational function if and only if Ei = 1 for each i, which is equivalent to
    Di | H[-1] for each i.

    Returns True or False.

    """
    a, d = a.cancel(d, include=True)
    _, proper = a.div(d)
    _, special = splitfactor_sqf(d, DE, coefficientD=True, z=z)

    # The running position (starting at 1), not the multiplicity stored in
    # the pair, is what gets passed to laurent_series().
    for position, (factor, _) in enumerate(special, start=1):
        *_, H = laurent_series(proper, d, factor, position, DE)
        # NOTE(review): this is an identity test, not an equality test, so
        # it only passes when gcd(...).as_poly() hands back the very object
        # d.  Confirm whether ``!=`` was intended.
        if gcd(d, H[-1]).as_poly() is not d:
            return False
    return True
1109
1110
def recognize_log_derivative(a, d, DE, z=None):
    """
    There exists a v in K(x)* such that f = dv/v
    where f a rational function if and only if f can be written as f = A/D
    where D is squarefree,deg(A) < deg(D), gcd(A, D) = 1,
    and all the roots of the Rothstein-Trager resultant are integers. In that case,
    any of the Rothstein-Trager, Lazard-Rioboo-Trager or Czichowski algorithm
    produces u in K(x) such that du/dx = uf.

    Returns True or False.  z is the indeterminate of the resultant; a
    Dummy is created when it is not given.

    """
    var = z or Dummy('z')
    num, den = a.cancel(d, include=True)
    _, num = num.div(den)

    # Rothstein-Trager resultant of den and num - z*D(den), as a Poly in z.
    shifted = num - Poly(var, DE.t)*derivation(den, DE)
    resultant = Poly(den.resultant(shifted), var)
    _, special = splitfactor_sqf(resultant, DE, coefficientD=True, z=var)

    # TODO also consider the complex roots
    # in case we have complex roots it should turn the flag false
    return all(
        all(root.is_Integer for root in real_roots(factor.as_poly(var)))
        for factor, _ in special)
1140
1141
def residue_reduce(a, d, DE, z=None, invert=True):
    """
    Lazard-Rioboo-Rothstein-Trager resultant reduction.

    Given a derivation D on k(t) and f in k(t) simple, return g
    elementary over k(t) and a Boolean b in {True, False} such that f -
    Dg in k[t] if b == True or f + h and f + h - Dg do not have an
    elementary integral over k(t) for any h in k<t> (reduced) if b ==
    False.

    Returns (G, b), where G is a tuple of tuples of the form (s_i, S_i),
    such that g = Add(*[RootSum(s_i, lambda z: z*log(S_i(z, t))) for
    S_i, s_i in G]). f - Dg is the remaining integral, which is elementary
    only if b == True, and hence the integral of f is elementary only if
    b == True.

    f - Dg is not calculated in this function because that would require
    explicitly calculating the RootSum.  Use residue_reduce_derivation().

    z is the indeterminate of the resultant (a Dummy is created when it is
    not given).  When invert is True, the coefficients of each S_i are
    multiplied by the inverse of its leading coefficient modulo s_i, so that
    the leading coefficient becomes 1.

    """
    # TODO: Use log_to_atan() from rationaltools.py
    # If r = residue_reduce(...), then the logarithmic part is given by:
    # sum(RootSum(a[0].as_poly(z),
    #             lambda i: i*log(a[1].as_expr()).subs({z: i})).subs({t: log(x)})
    #     for a in r[0])

    z = z or Dummy('z')
    a, d = a.cancel(d, include=True)
    # Work over a field and scale so that d is monic.
    a, d = a.to_field()*(1/d.LC()), d.to_field()*(1/d.LC())
    # Generators below the current level together with their reciprocals;
    # passed to gcd() further down.
    kkinv = [1/x for x in DE.T[:DE.level]] + DE.T[:DE.level]

    if a.is_zero:
        return [], True
    # Only the remainder of a modulo d is used from here on.
    _, a = a.div(d)

    pz = Poly(z, DE.t)

    Dd = derivation(d, DE)
    q = a - pz*Dd

    # Resultant of d and a - z*Dd together with its polynomial remainder
    # sequence; the argument order depends on the degrees of Dd and d.
    if Dd.degree(DE.t) <= d.degree(DE.t):
        r, R = d.resultant(q, includePRS=True)
    else:
        r, R = q.resultant(d, includePRS=True)

    # Index the remainder sequence by degree for lookup by multiplicity.
    R_map, H = {}, []
    for i in R:
        R_map[i.degree()] = i

    r = Poly(r, z)
    Np, Sp = splitfactor_sqf(r, DE, coefficientD=True, z=z)

    for s, i in Sp:
        if i == d.degree(DE.t):
            # Multiplicity equal to deg(d): the log argument is d itself.
            s = Poly(s, z).monic()
            H.append((s, d))
        else:
            h = R_map.get(i)
            if h is None:
                # No remainder of that degree in the sequence; skip it.
                continue
            h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True)

            h_lc_c, h_lc_sqf = h_lc.sqf_list()
            h_lc_sqf.insert(0, (Poly(h_lc_c, DE.t, field=True), 1))

            # Divide out of h what its leading coefficient's factors share
            # with powers of s.  (The loop variable a shadows the
            # numerator, which is no longer needed at this point.)
            for a, j in h_lc_sqf:
                h = Poly(h, DE.t, field=True).exquo(Poly(gcd(a, s**j, *kkinv),
                                                         DE.t))

            s = Poly(s, z).monic()

            if invert:
                # Normalise h: multiply every coefficient by the inverse of
                # the leading coefficient modulo s.
                h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True, expand=False)
                inv, coeffs = h_lc.as_poly(z, field=True).invert(s), [Integer(1)]

                for coeff in h.coeffs()[1:]:
                    L = reduced(inv*coeff, [s])[1]
                    coeffs.append(L.as_expr())

                h = Poly(dict(zip(h.monoms(), coeffs)), DE.t)

            if not s.is_one:
                H.append((s, h))

    # b is True only if no normal factor of the resultant involves t or z.
    b = all(not cancel(i.as_expr()).has(DE.t, z) for i, _ in Np)

    return H, b
1229
1230
def residue_reduce_to_basic(H, DE, z):
    """
    Converts the tuple returned by residue_reduce() into a Basic expression.

    Every entry (s_i, S_i) of H contributes a RootSum over the roots of s_i
    (taken as a Poly in z) of root*log(S_i), where z is replaced by the root
    and the extension variables in DE.T by the functions in DE.Tfuncs
    applied to DE.x.

    """
    # TODO: check what Lambda does with RootOf
    i = Dummy('i')
    back_subs = list(zip(reversed(DE.T),
                         reversed([f(DE.x) for f in DE.Tfuncs])))

    def log_term(entry):
        root_poly = entry[0].as_poly(z)
        summand = i*log(entry[1].as_expr()).subs({z: i}).subs(back_subs)
        return RootSum(root_poly, Lambda(i, summand))

    return sum(log_term(entry) for entry in H)
1242
1243
def residue_reduce_derivation(H, DE, z):
    """
    Computes the derivation of an expression returned by residue_reduce().

    In general, this is a rational function in t, so this returns an
    as_expr() result.

    Every entry (s_i, S_i) of H contributes a RootSum over the roots of s_i
    (taken as a Poly in z) of root*D(S_i)/S_i with z replaced by the root.

    """
    # TODO: verify that this is correct for multiple extensions
    i = Dummy('i')

    def term(entry):
        root_poly = entry[0].as_poly(z)
        log_arg = entry[1]
        ratio = (i*derivation(log_arg, DE).as_expr().subs({z: i})
                 / log_arg.as_expr().subs({z: i}))
        return RootSum(root_poly, Lambda(i, ratio))

    return sympify(sum(term(entry) for entry in H))
1256
1257
def integrate_primitive_polynomial(p, DE):
    """
    Integration of primitive polynomials.

    Given a primitive monomial t over k, and p in k[t], return q in k[t],
    r in k, and a bool b in {True, False} such that r = p - Dq is in k if b is
    True, or r = p - Dq does not have an elementary integral over k(t) if b is
    False.

    """
    from .prde import limited_integrate

    q = Poly(0, DE.t)

    # Remove the leading term of p on every pass until p is free of t.
    while p.as_expr().has(DE.t):
        Dta, Dtb = frac_in(DE.d, DE.T[DE.level - 1])

        # We had better be integrating the lowest extension (x)
        # with ratint().
        with DecrementLevel(DE):
            lead = p.LC()
            aa, ad = frac_in(lead, DE.t)

            try:
                (ba, bd), c = limited_integrate(aa, ad, [(Dta, Dtb)], DE)
                if len(c) != 1:
                    raise ValueError('Length of c should  be 1')
            except NonElementaryIntegralException:
                return q, p, False

        m = p.degree(DE.t)
        q0 = c[0].as_poly(DE.t)*Poly(DE.t**(m + 1)/(m + 1), DE.t) + \
            (ba.as_expr()/bd.as_expr()).as_poly(DE.t)*Poly(DE.t**m, DE.t)

        p = p - derivation(q0, DE)
        q = q + q0

    return q, p, True
1301
1302
def integrate_primitive(a, d, DE, z=None):
    """
    Integration of primitive functions.

    Given a primitive monomial t over k and f in k(t), return g elementary over
    k(t), i in k(t), and b in {True, False} such that i = f - Dg is in k if b
    is True or i = f - Dg does not have an elementary integral over k(t) if b
    is False.

    This function returns a Basic expression for the first argument.  If b is
    True, the second argument is Basic expression in k to recursively integrate.
    If b is False, the second argument is an unevaluated Integral, which has
    been proven to be nonelementary.

    """
    # XXX: a and d must be canceled, or this might return incorrect results
    z = z or Dummy('z')
    back_subs = list(zip(reversed(DE.T),
                         reversed([f(DE.x) for f in DE.Tfuncs])))

    (g1a, g1d), (ha, hd), (ra, rd) = hermite_reduce(a, d, DE)
    g2, b = residue_reduce(ha, hd, DE, z=z)
    if not b:
        # f - Dg1 - Dg2 is what is left over and proven nonelementary.
        f = a.as_expr()/d.as_expr()
        dg1 = (g1d*derivation(g1a, DE) -
               g1a*derivation(g1d, DE)).as_expr()/(g1d**2).as_expr()
        i = cancel(f - dg1 - residue_reduce_derivation(g2, DE, z))
        i = NonElementaryIntegral(cancel(i).subs(back_subs), DE.x)
        return ((g1a.as_expr()/g1d.as_expr()).subs(back_subs) +
                residue_reduce_to_basic(g2, DE, z), i, b)

    # h - Dg2 + r
    p = cancel(ha.as_expr()/hd.as_expr() -
               residue_reduce_derivation(g2, DE, z) +
               ra.as_expr()/rd.as_expr())
    p = p.as_poly(DE.t)

    q, i, b = integrate_primitive_polynomial(p, DE)

    ret = ((g1a.as_expr()/g1d.as_expr() + q.as_expr()).subs(back_subs) +
           residue_reduce_to_basic(g2, DE, z))
    if b:
        i = cancel(i.as_expr())
    else:
        # TODO: This does not do the right thing when b is False
        i = NonElementaryIntegral(cancel(i.as_expr()).subs(back_subs), DE.x)

    return ret, i, b
1348
1349
def integrate_hyperexponential_polynomial(p, DE, z):
    """
    Integration of hyperexponential polynomials.

    Given a hyperexponential monomial t over k and p in k[t, 1/t], return q in
    k[t, 1/t] and a bool b in {True, False} such that p - Dq in k if b is True,
    or p - Dq does not have an elementary integral over k(t) if b is False.

    p is expected as a Poly in t and z, where z stands for 1/t.  q is
    returned as the pair (qa, qd), so the full result is (qa, qd, b).

    """
    from .rde import rischDE

    # Keep the outer monomial: inside the with-block below DE.t refers to
    # the next lower level.
    outer_t = DE.t
    dtt = DE.d.exquo(Poly(outer_t, outer_t))
    qa = Poly(0, outer_t)
    qd = Poly(1, outer_t)
    elementary = True

    if p.is_zero:
        return qa, qd, elementary

    with DecrementLevel(DE):
        for i in range(-p.degree(z), p.degree(outer_t) + 1):
            if i == 0:
                continue

            # If you get AttributeError: 'NoneType' object has no attribute 'nth'
            # then this should really not have expand=False
            # (for i < 0 it shouldn't happen because p is already a Poly in
            # t and z)
            if i < 0:
                coeff = p.as_poly(z, expand=False).coeff_monomial((-i,))
            else:
                coeff = p.as_poly(outer_t, expand=False).coeff_monomial((i,))

            aa, ad = frac_in(coeff, DE.t, field=True)
            aa, ad = aa.cancel(ad, include=True)
            iDta, iDtd = frac_in(Poly(i, outer_t)*dtt, DE.t, field=True)
            try:
                va, vd = rischDE(iDta, iDtd, Poly(aa, DE.t), Poly(ad, DE.t), DE)
                va, vd = frac_in((va, vd), outer_t)
            except NonElementaryIntegralException:
                elementary = False
            else:
                # Add v*t**i to the running fraction qa/qd.
                qa = qa*vd + va*Poly(outer_t**i)*qd
                qd *= vd

    return qa, qd, elementary
1398
1399
def integrate_hyperexponential(a, d, DE, z=None, conds='piecewise'):
    """
    Integration of hyperexponential functions.

    Given a hyperexponential monomial t over k and f in k(t), return g
    elementary over k(t), i in k(t), and a bool b in {True, False} such that
    i = f - Dg is in k if b is True or i = f - Dg does not have an elementary
    integral over k(t) if b is False.

    This function returns a Basic expression for the first argument.  If b is
    True, the second argument is Basic expression in k to recursively integrate.
    If b is False, the second argument is an unevaluated Integral, which has
    been proven to be nonelementary.

    If conds == 'piecewise' and the back-substituted denominator of the
    polynomial part does not contain DE.x, that part is wrapped in a
    Piecewise whose first branch covers the denominator being zero.

    """
    # XXX: a and d must be canceled, or this might return incorrect results
    z = z or Dummy('z')
    # Substitution list mapping the extension variables back to functions of x.
    s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs])))

    g1, h, r = hermite_reduce(a, d, DE)
    g2, b = residue_reduce(h[0], h[1], DE, z=z)
    if not b:
        # f - Dg1 - Dg2 is the remaining, nonelementary integrand.
        i = cancel(a.as_expr()/d.as_expr() - (g1[1]*derivation(g1[0], DE) -
                                              g1[0]*derivation(g1[1], DE)).as_expr()/(g1[1]**2).as_expr() -
                   residue_reduce_derivation(g2, DE, z))
        i = NonElementaryIntegral(cancel(i.subs(s)), DE.x)
        return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) +
                residue_reduce_to_basic(g2, DE, z), i, b)

    # p should be a polynomial in t and 1/t, because Sirr == k[t, 1/t]
    # h - Dg2 + r
    p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2,
                                                                         DE, z) + r[0].as_expr()/r[1].as_expr())
    # Rewrite p as a Poly in t and z, with z standing for 1/t.
    pp = as_poly_1t(p, DE.t, z)

    qa, qd, b = integrate_hyperexponential_polynomial(pp, DE, z)

    # The term of pp free of t and z: what is left to integrate in k.
    i = pp.coeff_monomial(1)

    ret = ((g1[0].as_expr()/g1[1].as_expr()).subs(s)
           + residue_reduce_to_basic(g2, DE, z))

    qas = qa.as_expr().subs(s)
    qds = qd.as_expr().subs(s)
    if conds == 'piecewise' and DE.x not in qds.free_symbols:
        # We have to be careful if the exponent is Integer(0)!

        # XXX: Does qd = 0 always necessarily correspond to the exponential
        # equaling 1?
        # First branch: denominator zero, integrate p - i directly with the
        # monomial set to 1.
        ret += Piecewise(
            (integrate((p - i).subs({DE.t: 1}).subs(s), DE.x), Eq(qds, 0)),
            (qas/qds, True))
    else:
        ret += qas/qds

    if not b:
        # The polynomial part was not fully integrable: report p - D(qa/qd)
        # as the nonelementary remainder.
        i = p - (qd*derivation(qa, DE) - qa*derivation(qd, DE)).as_expr() /\
            (qd**2).as_expr()
        i = NonElementaryIntegral(cancel(i).subs(s), DE.x)
    return ret, i, b
1460
1461
def integrate_hypertangent_polynomial(p, DE):
    """
    Integration of hypertangent polynomials.

    Given a differential field k such that sqrt(-1) is not in k, a
    hypertangent monomial t over k, and p in k[t], return q in k[t] and
    c in k such that p - Dq - c*D(t**2 + 1)/(t**2 + 1) is in k and p -
    Dq does not have an elementary integral over k(t) if Dc != 0.

    """
    # XXX: Make sure that sqrt(-1) is not in k.
    t = DE.t
    quotient, remainder = polynomial_reduce(p, DE)
    # Dt = eta*(t**2 + 1); c is the linear coefficient of the remainder
    # divided by 2*eta.
    eta = DE.d.exquo(Poly(t**2 + 1, t))
    return quotient, Poly(remainder.coeff_monomial((1,))/(2*eta.as_expr()), t)
1477
1478
def integrate_nonlinear_no_specials(a, d, DE, z=None):
    """
    Integration of nonlinear monomials with no specials.

    Given a nonlinear monomial t over k such that Sirr ({p in k[t] | p is
    special, monic, and irreducible}) is empty, and f in k(t), returns g
    elementary over k(t) and a Boolean b in {True, False} such that f - Dg is
    in k if b == True, or f - Dg does not have an elementary integral over k(t)
    if b == False.

    This function is applicable to all nonlinear extensions, but in the case
    where it returns b == False, it will only have proven that the integral of
    f - Dg is nonelementary if Sirr is empty.

    This function returns a Basic expression.

    """
    # TODO: Integral from k?
    # TODO: split out nonelementary integral
    # XXX: a and d must be canceled, or this might not return correct results
    z = z or Dummy('z')
    back_subs = list(zip(reversed(DE.T),
                         reversed([f(DE.x) for f in DE.Tfuncs])))

    (g1a, g1d), (ha, hd), (ra, rd) = hermite_reduce(a, d, DE)
    g2, b = residue_reduce(ha, hd, DE, z=z)
    if not b:
        return ((g1a.as_expr()/g1d.as_expr()).subs(back_subs) +
                residue_reduce_to_basic(g2, DE, z), b)

    # Because f has no specials, this should be a polynomial in t, or else
    # there is a bug.
    p = cancel(ha.as_expr()/hd.as_expr() -
               residue_reduce_derivation(g2, DE, z).as_expr() +
               ra.as_expr()/rd.as_expr()).as_poly(DE.t)
    q1, q2 = polynomial_reduce(p, DE)

    # Elementary only if the reduced remainder no longer involves t.
    b = not q2.as_expr().has(DE.t)

    ret = (cancel(g1a.as_expr()/g1d.as_expr() + q1.as_expr()).subs(back_subs) +
           residue_reduce_to_basic(g2, DE, z))
    return ret, b
1522
1523
class NonElementaryIntegral(Integral):
    """
    Represents a nonelementary Integral.

    If the result of integrate() is an instance of this class, it is
    guaranteed to be nonelementary.  Note that integrate() by default will try
    to find any closed-form solution, even in terms of special functions which
    may themselves not be elementary.  To make integrate() only give
    elementary solutions, or, in the cases where it can prove the integral to
    be nonelementary, instances of this class, use integrate(risch=True).
    In this case, integrate() may raise NotImplementedError if it cannot make
    such a determination.

    integrate() uses the deterministic Risch algorithm to integrate elementary
    functions or prove that they have no elementary integral.  In some cases,
    this algorithm can split an integral into an elementary and nonelementary
    part, so that the result of integrate will be the sum of an elementary
    expression and a NonElementaryIntegral.

    Examples
    ========

    >>> a = integrate(exp(-x**2), x, risch=True)
    >>> a
    Integral(E**(-x**2), x)
    >>> type(a)
    <class 'diofant.integrals.risch.NonElementaryIntegral'>

    >>> expr = (2*log(x)**2 - log(x) - x**2)/(log(x)**3 - x**2*log(x))
    >>> b = integrate(expr, x, risch=True)
    >>> b
    -log(-x + log(x))/2 + log(x + log(x))/2 + Integral(1/log(x), x)
    >>> type(b.atoms(Integral).pop())
    <class 'diofant.integrals.risch.NonElementaryIntegral'>

    """

    # The class body defines nothing beyond the docstring: it is a marker
    # subclass of Integral.

    # TODO: This is useful in and of itself, because isinstance(result,
    # NonElementaryIntegral) will tell if the integral has been proven to be
    # nonelementary. But should we do more?  Perhaps a no-op .doit() if
    # elementary=True?  Or maybe some information on why the integral is
    # nonelementary.
1566
1567
def risch_integrate(f, x, extension=None, handle_first='log',
                    separate_integral=False, rewrite_complex=False,
                    conds='piecewise'):
    r"""
    The Risch Integration Algorithm.

    Only transcendental functions are supported.  Currently, only exponentials
    and logarithms are supported, but support for trigonometric functions is
    forthcoming.

    If this function returns an unevaluated Integral in the result, it means
    that it has proven that integral to be nonelementary.  Any errors will
    result in raising NotImplementedError.  The unevaluated Integral will be
    an instance of NonElementaryIntegral, a subclass of Integral.

    handle_first may be either 'exp' or 'log'.  This changes the order in
    which the extension is built, and may result in a different (but
    equivalent) solution (for an example of this, see issue sympy/sympy#5109).  It is also
    possible that the integral may be computed with one but not the other,
    because not all cases have been implemented yet.  It defaults to 'log' so
    that the outer extension is exponential when possible, because more of the
    exponential case has been implemented.

    If separate_integral is True, the result is returned as a tuple (ans, i),
    where the integral is ans + i, ans is elementary, and i is either a
    NonElementaryIntegral or 0.  This is useful if you want to try further
    integrating the NonElementaryIntegral part using other algorithms to
    possibly get a solution in terms of special functions.  It is False by
    default.

    If extension is given, it is used as the differential extension DE
    instead of building a new DifferentialExtension from f and x; the
    integrand is then taken from its fa and fd attributes.  The object's
    decrement_level() method is called as the algorithm descends through the
    extension tower.

    rewrite_complex is passed through to DifferentialExtension when the
    extension is built here, and conds is passed through to
    integrate_hyperexponential() whenever an exponential extension is
    processed.

    Examples
    ========

    First, we try integrating exp(-x**2). Except for a constant factor of
    2/sqrt(pi), this is the famous error function.

    >>> pprint(risch_integrate(exp(-x**2), x), use_unicode=False)
      /
     |
     |    2
     |  -x
     | E    dx
     |
    /

    The unevaluated Integral in the result means that risch_integrate() has
    proven that exp(-x**2) does not have an elementary anti-derivative.

    In many cases, risch_integrate() can split out the elementary
    anti-derivative part from the nonelementary anti-derivative part.
    For example,

    >>> pprint(risch_integrate((2*log(x)**2 - log(x) - x**2)/(log(x)**3 -
    ...                        x**2*log(x)), x), use_unicode=False)
                                             /
                                            |
      log(-x + log(x))   log(x + log(x))    |   1
    - ---------------- + --------------- +  | ------ dx
             2                  2           | log(x)
                                            |
                                           /

    This means that it has proven that the integral of 1/log(x) is
    nonelementary.  This function is also known as the logarithmic integral,
    and is often denoted as Li(x).

    risch_integrate() currently only accepts purely transcendental functions
    with exponentials and logarithms, though note that this can include
    nested exponentials and logarithms, as well as exponentials with bases
    other than E.

    >>> pprint(risch_integrate(exp(x)*exp(exp(x)), x), use_unicode=False)
     / x\
     \E /
    E
    >>> pprint(risch_integrate(exp(exp(x)), x), use_unicode=False)
      /
     |
     |  / x\
     |  \E /
     | E     dx
     |
    /

    >>> pprint(risch_integrate(x*x**x*log(x) + x**x + x*x**x, x), use_unicode=False)
       x
    x*x
    >>> pprint(risch_integrate(x**x, x), use_unicode=False)
      /
     |
     |  x
     | x  dx
     |
    /

    >>> pprint(risch_integrate(-1/(x*log(x)*log(log(x))**2), x), use_unicode=False)
         1
    -----------
    log(log(x))

    """
    f = sympify(f)

    DE = extension or DifferentialExtension(f, x, handle_first=handle_first,
                                            rewrite_complex=rewrite_complex)
    # (fa, fd) always holds the integrand that is *still to be integrated*;
    # it is rebound on every pass through the loop below.
    fa, fd = DE.fa, DE.fd

    result = Integer(0)
    # Walk the tower of extensions from the outermost level down to the base.
    for case in reversed(DE.cases):
        # If the remaining integrand does not involve the current extension
        # variable there is nothing to do at this level: drop one level and
        # re-express the fraction in the new DE.t.  The check must look at the
        # current (fa, fd), not at DE.fa, which is still the numerator of the
        # original integrand.
        if (not fa.as_expr().has(DE.t) and not fd.as_expr().has(DE.t) and
                case != 'base'):
            DE.decrement_level()
            fa, fd = frac_in((fa, fd), DE.t)
            continue

        fa, fd = fa.cancel(fd, include=True)
        if case == 'exp':
            ans, i, b = integrate_hyperexponential(fa, fd, DE, conds=conds)
        elif case == 'primitive':
            ans, i, b = integrate_primitive(fa, fd, DE)
        elif case == 'base':
            # XXX: We can't call ratint() directly here because it doesn't
            # handle polynomials correctly.
            ans = integrate(fa.as_expr()/fd.as_expr(), DE.x, risch=False)
            b = False
            i = Integer(0)
        else:
            raise NotImplementedError('Only exponential and logarithmic '
                                      'extensions are currently supported.')

        result += ans
        if b:
            # The leftover integrand i still has to be integrated one level
            # further down the tower.
            DE.decrement_level()
            fa, fd = frac_in(i, DE.t)
            continue

        # Nothing more can be done below this level: substitute the original
        # functions back in and return whatever has been accumulated.
        result = result.subs(DE.backsubs)
        if not i.is_zero:
            i = NonElementaryIntegral(i.function.subs(DE.backsubs), i.limits)

        if not separate_integral:
            result += i
            return result

        if isinstance(i, NonElementaryIntegral):
            return result, i
        return result, 0
1713