1"""
2The Risch Algorithm for transcendental function integration.
3
4The core algorithms for the Risch algorithm are here.  The subproblem
5algorithms are in the rde.py and prde.py files for the Risch
6Differential Equation solver and the parametric problems solvers,
7respectively.  All important information concerning the differential extension
8for an integrand is stored in a DifferentialExtension object, which in the code
is usually called DE.  Throughout the code and inside the DifferentialExtension
object, the conventions/attribute names are that the base domain is QQ and each
differential extension is x, t0, t1, ..., tn-1 = DE.t. DE.x is the variable of
integration (Dx == 1), DE.D is a list of the derivatives of
x, t0, t1, ..., tn-1 = t, DE.T is the list [x, t0, t1, ..., tn-1], DE.t is the
14outer-most variable of the differential extension at the given level (the level
15can be adjusted using DE.increment_level() and DE.decrement_level()),
16k is the field C(x, t0, ..., tn-2), where C is the constant field.  The
17numerator of a fraction is denoted by a and the denominator by
18d.  If the fraction is named f, fa == numer(f) and fd == denom(f).
19Fractions are returned as tuples (fa, fd).  DE.d and DE.t are used to
20represent the topmost derivation and extension variable, respectively.
21The docstring of a function signifies whether an argument is in k[t], in
22which case it will just return a Poly in t, or in k(t), in which case it
23will return the fraction (fa, fd). Other variable names probably come
24from the names used in Bronstein's book.
25"""
26
27from sympy import real_roots, default_sort_key
28from sympy.abc import z
29from sympy.core.function import Lambda
30from sympy.core.numbers import ilcm, oo, I
31from sympy.core.mul import Mul
32from sympy.core.power import Pow
33from sympy.core.relational import Ne
34from sympy.core.singleton import S
35from sympy.core.symbol import Symbol, Dummy
36from sympy.core.compatibility import ordered
37from sympy.integrals.heurisch import _symbols
38
39from sympy.functions import (acos, acot, asin, atan, cos, cot, exp, log,
40    Piecewise, sin, tan)
41
42from sympy.functions import sinh, cosh, tanh, coth
43from sympy.integrals import Integral, integrate
44
45from sympy.polys import gcd, cancel, PolynomialError, Poly, reduced, RootSum, DomainError
46
47from sympy.utilities.iterables import numbered_symbols
48
49from types import GeneratorType
50from functools import reduce
51
52
def integer_powers(exprs):
    """
    Rewrites a list of expressions as integer multiples of each other.

    Explanation
    ===========

    For example, if you have [x, x/2, x**2 + 1, 2*x/3], then you can rewrite
    this as [(x/6) * 6, (x/6) * 3, (x**2 + 1) * 1, (x/6) * 4]. This is useful
    in the Risch integration algorithm, where we must write exp(x) + exp(x/2)
    as (exp(x/2))**2 + exp(x/2), but not as exp(x) + sqrt(exp(x)) (this is
    because only the transcendental case is implemented and we therefore cannot
    integrate algebraic extensions). The integer multiples returned by this
    function for each term are the smallest possible (their content equals 1).

    Returns a list of tuples where the first element is the base term and the
    second element is a list of `(item, factor)` terms, where `factor` is the
    integer multiplicative factor that must multiply the base term to obtain
    the original item.  The list is sorted by the base term's sort_key().

    The easiest way to understand this is to look at an example:

    >>> from sympy.abc import x
    >>> from sympy.integrals.risch import integer_powers
    >>> integer_powers([x, x/2, x**2 + 1, 2*x/3])
    [(x/6, [(x, 6), (x/2, 3), (2*x/3, 4)]), (x**2 + 1, [(x**2 + 1, 1)])]

    We can see how this relates to the example at the beginning of the
    docstring.  It chose x/6 as the first base term.  Then, x can be written as
    (x/6) * 6, so we get (x, 6); likewise x/2 == (x/6) * 3 gives (x/2, 3), and
    2*x/3 == (x/6) * 4 gives (2*x/3, 4). Now only element (x**2 + 1)
    remains, and there are no other terms that can be written as a rational
    multiple of that, so we get that it can be written as (x**2 + 1) * 1.

    """
    # Here is the strategy:

    # First, go through each term and determine if it can be rewritten as a
    # rational multiple of any of the terms gathered so far.
    # cancel(a/b).is_Rational is sufficient for this.  If it is a multiple, we
    # add its multiple to the dictionary.  Each key is the first term seen of
    # its group; each value is a list of (term, rational ratio term/key).

    terms = {}
    for term in exprs:
        for j in terms:
            a = cancel(term/j)
            if a.is_Rational:
                terms[j].append((term, a))
                break
        else:
            # Not a rational multiple of any existing group: start a new one.
            terms[term] = [(term, S.One)]

    # After we have done this, we have all the like terms together, so we just
    # need to find a common denominator so that we can get the base term and
    # integer multiples such that each term can be written as an integer
    # multiple of the base term, and the content of the integers is 1.

    newterms = {}
    for term, multiples in terms.items():
        # Dividing the group key by the lcm of the ratio denominators makes
        # every ratio (scaled by that lcm) an integer.
        common_denom = reduce(ilcm, [i.as_numer_denom()[1] for _, i in
            multiples])
        newterm = term/common_denom
        newterms[newterm] = [(i, j*common_denom) for i, j in multiples]

    return sorted(newterms.items(), key=lambda item: item[0].sort_key())
118
119
class DifferentialExtension:
    """
    A container for all the information relating to a differential extension.

    Explanation
    ===========

    The attributes of this object are (see also the docstring of __init__):

    - f: The original (Expr) integrand.
    - x: The variable of integration.
    - T: List of variables in the extension.
    - D: List of derivations in the extension; corresponds to the elements of T.
    - fa: Poly of the numerator of the integrand.
    - fd: Poly of the denominator of the integrand.
    - Tfuncs: Lambda() representations of each element of T (except for x).
      For back-substitution after integration.
    - backsubs: A (possibly empty) list of further substitutions to be made on
      the final integral to make it look more like the integrand.
    - exts: List giving the kind of each element of T: None for x, then
      'exp' or 'log' for each extension, in the order they were added.
    - extargs: List of the arguments of the extensions, parallel to exts
      (None for x), so that the i-th extension variable stands for
      exp(extargs[i]) or log(extargs[i]).
    - cases: List of string representations of the cases of T.
    - t: The top level extension variable, as defined by the current level
      (see level below).
    - d: The top level extension derivation, as defined by the current
      derivation (see level below).
    - case: The string representation of the case of self.d.
    (Note that self.T and self.D will always contain the complete extension,
    regardless of the level.  Therefore, you should ALWAYS use DE.t and DE.d
    instead of DE.T[-1] and DE.D[-1].  If you want to have a list of the
    derivations or variables only up to the current level, use
    DE.D[:len(DE.D) + DE.level + 1] and DE.T[:len(DE.T) + DE.level + 1].  Note
    that, in particular, the derivation() function does this.)

    The following are also attributes, but will probably not be useful other
    than in internal use:
    - newf: Expr form of fa/fd.
    - level: The number (between -1 and -len(self.T)) such that
      self.T[self.level] == self.t and self.D[self.level] == self.d.
      Use the methods self.increment_level() and self.decrement_level() to change
      the current level.
    - ts: Generator of fresh symbols (t0, t1, ...) used as new extension
      variables.
    - dummy: If True, the symbols drawn from ts and the Lambda variables in
      Tfuncs are Dummy objects rather than plain Symbols.
    """
    # __slots__ is defined mainly so we can iterate over all the attributes
    # of the class easily (the memory use doesn't matter too much, since we
    # only create one DifferentialExtension per integration).  Also, it's nice
    # to have a safeguard when debugging.
    __slots__ = ('f', 'x', 'T', 'D', 'fa', 'fd', 'Tfuncs', 'backsubs',
        'exts', 'extargs', 'cases', 'case', 't', 'd', 'newf', 'level',
        'ts', 'dummy')

    def __init__(self, f=None, x=None, handle_first='log', dummy=False, extension=None, rewrite_complex=None):
        """
        Tries to build a transcendental extension tower from ``f`` with respect to ``x``.

        Explanation
        ===========

        If it is successful, creates a DifferentialExtension object with, among
        others, the attributes fa, fd, D, T, Tfuncs, and backsubs such that
        fa and fd are Polys in T[-1] with rational coefficients in T[:-1],
        fa/fd == f, and D[i] is a Poly in T[i] with rational coefficients in
        T[:i] representing the derivative of T[i] for each i from 1 to
        len(T) - 1 (D[0] is Poly(1, x), since Dx == 1).
        Tfuncs is a list of Lambda objects for back replacing the functions
        after integrating.  Lambda() is only used (instead of lambda) to make
        them easier to test and debug. Note that Tfuncs corresponds to the
        elements of T, except for T[0] == x, but they should be back-substituted
        in reverse order.  backsubs is a (possibly empty) back-substitution list
        that should be applied on the completed integral to make it look more
        like the original integrand.

        If it is unsuccessful, it raises NotImplementedError.

        You can also create an object by manually setting the attributes as a
        dictionary to the extension keyword argument.  You must include at least
        D.  Warning, any attribute that is not given will be set to None. The
        attributes T, t, d, cases, case, x, and level are set automatically and
        do not need to be given.  The functions in the Risch Algorithm will NOT
        check to see if an attribute is None before using it.  This also does not
        check to see if the extension is valid (non-algebraic) or even if it is
        self-consistent.  Therefore, this should only be used for
        testing/debugging purposes.

        Parameters
        ==========

        f : Expr
            The integrand.  Required unless ``extension`` is given.
        x : Symbol
            The variable of integration.  Required unless ``extension`` is
            given.
        handle_first : 'log' or 'exp'
            The preferred kind of extension.  It is attempted on every pass
            of the construction loop; the other kind is attempted only when
            the last attempt at the preferred kind produced no new extension.
        dummy : bool
            If True, use Dummy symbols for the extension variables and for
            the Lambda variables in Tfuncs.
        extension : dict, optional
            Manually specified attributes (see above); must contain 'D'.
        rewrite_complex : bool or None
            If true, sin, cos, cot, tan and their hyperbolic counterparts are
            rewritten in terms of exp, and asin, acos, acot, atan in terms of
            log, before the tower is built.  If None, this is done exactly
            when ``I`` is among the atoms of ``f``.

        Raises
        ======

        ValueError
            If neither ``extension`` nor both ``f`` and ``x`` are given, if
            ``extension`` lacks 'D', or if ``handle_first`` is invalid.
        NotImplementedError
            If no elementary transcendental tower could be built (for example
            because the integrand needs trigonometric or algebraic extensions).
        """
        # XXX: If you need to debug this function, set the break point here

        if extension:
            if 'D' not in extension:
                raise ValueError("At least the key D must be included with "
                    "the extension flag to DifferentialExtension.")
            for attr in extension:
                setattr(self, attr, extension[attr])

            self._auto_attrs()

            return
        elif f is None or x is None:
            raise ValueError("Either both f and x or a manual extension must "
            "be given.")

        if handle_first not in ['log', 'exp']:
            raise ValueError("handle_first must be 'log' or 'exp', not %s." %
                str(handle_first))

        # f will be the original function, self.f might change if we reset
        # (e.g., we pull out a constant from an exponential)
        self.f = f
        self.x = x
        # setting the default value 'dummy'
        self.dummy = dummy
        self.reset()
        exp_new_extension, log_new_extension = True, True

        # case of 'automatic' choosing
        if rewrite_complex is None:
            rewrite_complex = I in self.f.atoms()

        if rewrite_complex:
            rewritables = {
                (sin, cos, cot, tan, sinh, cosh, coth, tanh): exp,
                (asin, acos, acot, atan): log,
            }
            # rewrite the trigonometric components
            for candidates, rule in rewritables.items():
                self.newf = self.newf.rewrite(candidates, rule)
            self.newf = cancel(self.newf)
        else:
            # Only these six functions are rejected up front; anything else
            # that cannot be expressed in the tower is caught by the
            # "Couldn't find an elementary transcendental extension" check in
            # the loop below.
            if any(i.has(x) for i in self.f.atoms(sin, cos, tan, atan, asin, acos)):
                raise NotImplementedError("Trigonometric extensions are not "
                "supported (yet!)")

        exps = set()
        pows = set()
        numpows = set()
        sympows = set()
        logs = set()
        symlogs = set()

        while True:
            # Done once the integrand is a rational function of the tower.
            if self.newf.is_rational_function(*self.T):
                break

            if not exp_new_extension and not log_new_extension:
                # We couldn't find a new extension on the last pass, so I guess
                # we can't do it.
                raise NotImplementedError("Couldn't find an elementary "
                    "transcendental extension for %s.  Try using a " % str(f) +
                    "manual extension with the extension flag.")

            exps, pows, numpows, sympows, log_new_extension = \
                    self._rewrite_exps_pows(exps, pows, numpows, sympows, log_new_extension)

            logs, symlogs = self._rewrite_logs(logs, symlogs)

            if handle_first == 'exp' or not log_new_extension:
                exp_new_extension = self._exp_part(exps)
                if exp_new_extension is None:
                    # reset and restart
                    self.f = self.newf
                    self.reset()
                    exp_new_extension = True
                    continue

            if handle_first == 'log' or not exp_new_extension:
                log_new_extension = self._log_part(logs)

        self.fa, self.fd = frac_in(self.newf, self.t)
        self._auto_attrs()

        return

    def __getattr__(self, attr):
        # Avoid AttributeErrors when debugging: an unset slot reads as None,
        # while a name that is not a slot is still an error.
        if attr not in self.__slots__:
            raise AttributeError("%s has no attribute %s" % (repr(self), repr(attr)))
        return None

    def _rewrite_exps_pows(self, exps, pows, numpows,
            sympows, log_new_extension):
        """
        Rewrite exps/pows for better processing.

        Nested powers are flattened, then every non-base-E power whose
        exponent depends on the tower is rewritten as an exp (recording the
        reverse substitution in self.backsubs), adding a log extension for
        the base when needed.  Returns the updated (exps, pows, numpows,
        sympows, log_new_extension).
        """
        # Pre-preparsing.
        #################
        # Get all exp arguments, so we can avoid ahead of time doing
        # something like t1 = exp(x), t2 = exp(x/2) == sqrt(t1).

        # Things like sqrt(exp(x)) do not automatically simplify to
        # exp(x/2), so they will be viewed as algebraic.  The easiest way
        # to handle this is to convert all instances of (a**b)**Rational
        # to a**(Rational*b) before doing anything else.  Note that the
        # _exp_part code can generate terms of this form, so we do need to
        # do this at each pass (or else modify it to not do that).

        from sympy.integrals.prde import is_deriv_k

        # NOTE(review): since ``and`` binds tighter than ``or``, the
        # ``i.exp.is_Rational`` test below only applies to the
        # ``isinstance(i.base, exp)`` alternative; any power whose base is
        # itself a Pow is flattened regardless of its exponent.  The comment
        # above says only Rational exponents are intended -- confirm before
        # changing, as later steps may rely on the broader rewrite.
        ratpows = [i for i in self.newf.atoms(Pow).union(self.newf.atoms(exp))
            if (i.base.is_Pow or isinstance(i.base, exp) and i.exp.is_Rational)]

        ratpows_repl = [
            (i, i.base.base**(i.exp*i.base.exp)) for i in ratpows]
        self.backsubs += [(j, i) for i, j in ratpows_repl]
        self.newf = self.newf.xreplace(dict(ratpows_repl))

        # To make the process deterministic, the args are sorted
        # so that functions with smaller op-counts are processed first.
        # Ties are broken with the default_sort_key.

        # XXX Although the method is deterministic no additional work
        # has been done to guarantee that the simplest solution is
        # returned and that it would be affected be using different
        # variables. Though it is possible that this is the case
        # one should know that it has not been done intentionally, so
        # further improvements may be possible.

        # TODO: This probably doesn't need to be completely recomputed at
        # each pass.
        exps = update_sets(exps, self.newf.atoms(exp),
            lambda i: i.exp.is_rational_function(*self.T) and
            i.exp.has(*self.T))
        pows = update_sets(pows, self.newf.atoms(Pow),
            lambda i: i.exp.is_rational_function(*self.T) and
            i.exp.has(*self.T))
        # numpows: powers whose base is free of the tower (e.g. 2**x).
        numpows = update_sets(numpows, set(pows),
            lambda i: not i.base.has(*self.T))
        # sympows: powers with a rational-function base and non-integer
        # exponent (e.g. x**x).
        sympows = update_sets(sympows, set(pows) - set(numpows),
            lambda i: i.base.is_rational_function(*self.T) and
            not i.exp.is_Integer)

        # The easiest way to deal with non-base E powers is to convert them
        # into base E, integrate, and then convert back.
        for i in ordered(pows):
            old = i
            new = exp(i.exp*log(i.base))
            # If exp is ever changed to automatically reduce exp(x*log(2))
            # to 2**x, then this will break.  The solution is to not change
            # exp to do that :)
            if i in sympows:
                if i.exp.is_Rational:
                    raise NotImplementedError("Algebraic extensions are "
                        "not supported (%s)." % str(i))
                # We can add a**b only if log(a) in the extension, because
                # a**b == exp(b*log(a)).
                basea, based = frac_in(i.base, self.t)
                A = is_deriv_k(basea, based, self)
                if A is None:
                    # Nonelementary monomial (so far)

                    # TODO: Would there ever be any benefit from just
                    # adding log(base) as a new monomial?
                    # ANSWER: Yes, otherwise we can't integrate x**x (or
                    # rather prove that it has no elementary integral)
                    # without first manually rewriting it as exp(x*log(x))
                    self.newf = self.newf.xreplace({old: new})
                    self.backsubs += [(new, old)]
                    log_new_extension = self._log_part([log(i.base)])
                    exps = update_sets(exps, self.newf.atoms(exp), lambda i:
                        i.exp.is_rational_function(*self.T) and i.exp.has(*self.T))
                    continue
                ans, u, const = A
                newterm = exp(i.exp*(log(const) + u))
                # Under the current implementation, exp kills terms
                # only if they are of the form a*log(x), where a is a
                # Number.  This case should have already been killed by the
                # above tests.  Again, if this changes to kill more than
                # that, this will break, which maybe is a sign that you
                # shouldn't be changing that.  Actually, if anything, this
                # auto-simplification should be removed.  See
                # http://groups.google.com/group/sympy/browse_thread/thread/a61d48235f16867f

                self.newf = self.newf.xreplace({i: newterm})

            elif i not in numpows:
                continue
            else:
                # i in numpows
                newterm = new
            # TODO: Just put it in self.Tfuncs
            self.backsubs.append((new, old))
            self.newf = self.newf.xreplace({old: newterm})
            exps.append(newterm)

        return exps, pows, numpows, sympows, log_new_extension

    def _rewrite_logs(self, logs, symlogs):
        """
        Rewrite logs for better processing.

        Every log(a**b) with a non-Integer exponent is rewritten as
        b*log(a) (recording the reverse substitution in self.backsubs).
        Returns the updated (logs, symlogs); logs is deduplicated and
        sorted with default_sort_key.
        """
        atoms = self.newf.atoms(log)
        logs = update_sets(logs, atoms,
            lambda i: i.args[0].is_rational_function(*self.T) and
            i.args[0].has(*self.T))
        symlogs = update_sets(symlogs, atoms,
            lambda i: i.has(*self.T) and i.args[0].is_Pow and
            i.args[0].base.is_rational_function(*self.T) and
            not i.args[0].exp.is_Integer)

        # We can handle things like log(x**y) by converting it to y*log(x)
        # This will fix not only symbolic exponents of the argument, but any
        # non-Integer exponent, like log(sqrt(x)).  The exponent can also
        # depend on x, like log(x**x).
        for i in ordered(symlogs):
            # Unlike in the exponential case above, we do not ever
            # potentially add new monomials (above we had to add log(a)).
            # Therefore, there is no need to run any is_deriv functions
            # here.  Just convert log(a**b) to b*log(a) and let
            # log_new_extension() handle it from there.
            lbase = log(i.args[0].base)
            logs.append(lbase)
            new = i.args[0].exp*lbase
            self.newf = self.newf.xreplace({i: new})
            self.backsubs.append((new, i))

        # remove any duplicates
        logs = sorted(set(logs), key=default_sort_key)

        return logs, symlogs

    def _auto_attrs(self):
        """
        Set attributes that are generated automatically.

        Fills in T (from the generators of D) and x (as T[0]) when they are
        unset, computes cases, and resets the level to the top (-1).
        """
        if not self.T:
            # i.e., when using the extension flag and T isn't given
            self.T = [i.gen for i in self.D]
        if not self.x:
            self.x = self.T[0]
        self.cases = [get_case(d, t) for d, t in zip(self.D, self.T)]
        self.level = -1
        self.t = self.T[self.level]
        self.d = self.D[self.level]
        self.case = self.cases[self.level]

    def _exp_part(self, exps):
        """
        Try to build an exponential extension.

        Returns
        =======

        Returns True if there was a new extension, False if there was no new
        extension but it was able to rewrite the given exponentials in terms
        of the existing extension, and None if the entire extension building
        process should be restarted.  If the process fails because there is no
        way around an algebraic extension (e.g., exp(log(x)/2)), it will raise
        NotImplementedError.
        """
        from sympy.integrals.prde import is_log_deriv_k_t_radical

        new_extension = False
        restart = False
        expargs = [i.exp for i in exps]
        # Group the exponents into integer multiples of common base terms, so
        # that e.g. exp(x) and exp(x/2) share the single monomial exp(x/2).
        ip = integer_powers(expargs)
        for arg, others in ip:
            # Minimize potential problems with algebraic substitution
            others.sort(key=lambda i: i[1])

            arga, argd = frac_in(arg, self.t)
            A = is_log_deriv_k_t_radical(arga, argd, self)

            if A is not None:
                ans, u, n, const = A
                # if n is 1 or -1, it's algebraic, but we can handle it
                if n == -1:
                    # This probably will never happen, because
                    # Rational.as_numer_denom() returns the negative term in
                    # the numerator.  But in case that changes, reduce it to
                    # n == 1.
                    n = 1
                    u **= -1
                    const *= -1
                    ans = [(i, -j) for i, j in ans]

                if n == 1:
                    # Example: exp(x + x**2) over QQ(x, exp(x), exp(x**2))
                    self.newf = self.newf.xreplace({exp(arg): exp(const)*Mul(*[
                        u**power for u, power in ans])})
                    # NOTE(review): ``others`` holds (original_arg, factor)
                    # pairs from integer_powers(), with original_arg ==
                    # factor*arg, so ``exp(p*exparg)`` below is
                    # exp(p**2*arg) rather than exp(exparg) (which is what the
                    # new-extension branch keys on).  Both agree for p == 1;
                    # confirm the intent for p != 1.
                    self.newf = self.newf.xreplace({exp(p*exparg):
                        exp(const*p) * Mul(*[u**power for u, power in ans])
                        for exparg, p in others})
                    # TODO: Add something to backsubs to put exp(const*p)
                    # back together.

                    continue

                else:
                    # Bad news: we have an algebraic radical.  But maybe we
                    # could still avoid it by choosing a different extension.
                    # For example, integer_powers() won't handle exp(x/2 + 1)
                    # over QQ(x, exp(x)), but if we pull out the exp(1), it
                    # will.  Or maybe we have exp(x + x**2/2), over
                    # QQ(x, exp(x), exp(x**2)), which is exp(x)*sqrt(exp(x**2)),
                    # but if we use QQ(x, exp(x), exp(x**2/2)), then they will
                    # all work.
                    #
                    # So here is what we do: If there is a non-zero const, pull
                    # it out and retry.  Also, if len(ans) > 1, then rewrite
                    # exp(arg) as the product of exponentials from ans, and
                    # retry that.  If const == 0 and len(ans) == 1, then we
                    # assume that it would have been handled by either
                    # integer_powers() or n == 1 above if it could be handled,
                    # so we give up at that point.  For example, you can never
                    # handle exp(log(x)/2) because it equals sqrt(x).

                    if const or len(ans) > 1:
                        rad = Mul(*[term**(power/n) for term, power in ans])
                        # NOTE(review): same ``exp(p*exparg)`` keying as in
                        # the n == 1 branch above; confirm intent for p != 1.
                        self.newf = self.newf.xreplace({exp(p*exparg):
                            exp(const*p)*rad for exparg, p in others})
                        # Substitute the original functions back for the
                        # tower variables (innermost last) before restarting.
                        self.newf = self.newf.xreplace(dict(list(zip(reversed(self.T),
                            reversed([f(self.x) for f in self.Tfuncs])))))
                        restart = True
                        break
                    else:
                        # TODO: give algebraic dependence in error string
                        raise NotImplementedError("Cannot integrate over "
                            "algebraic extensions.")

            else:
                # exp(arg) is a new monomial: t = exp(arg), Dt = D(arg)*t.
                arga, argd = frac_in(arg, self.t)
                darga = (argd*derivation(Poly(arga, self.t), self) -
                    arga*derivation(Poly(argd, self.t), self))
                dargd = argd**2
                darga, dargd = darga.cancel(dargd, include=True)
                darg = darga.as_expr()/dargd.as_expr()
                self.t = next(self.ts)
                self.T.append(self.t)
                self.extargs.append(arg)
                self.exts.append('exp')
                self.D.append(darg.as_poly(self.t, expand=False)*Poly(self.t,
                    self.t, expand=False))
                if self.dummy:
                    i = Dummy("i")
                else:
                    i = Symbol('i')
                self.Tfuncs += [Lambda(i, exp(arg.subs(self.x, i)))]
                self.newf = self.newf.xreplace(
                        {exp(exparg): self.t**p for exparg, p in others})
                new_extension = True

        if restart:
            return None
        return new_extension

    def _log_part(self, logs):
        """
        Try to build a logarithmic extension.

        Returns
        =======

        Returns True if there was a new extension and False if there was no new
        extension but it was able to rewrite the given logarithms in terms
        of the existing extension.  Unlike with exponential extensions, a
        logarithm is always either transcendental over the current extension
        or expressible in terms of it without radicals, so this function
        never returns None or raises NotImplementedError.
        """
        from sympy.integrals.prde import is_deriv_k

        new_extension = False
        logargs = [i.args[0] for i in logs]
        for arg in ordered(logargs):
            # The log case is easier, because whenever a logarithm is algebraic
            # over the base field, it is of the form a1*t1 + ... an*tn + c,
            # which is a polynomial, so we can just replace it with that.
            # In other words, we don't have to worry about radicals.
            arga, argd = frac_in(arg, self.t)
            A = is_deriv_k(arga, argd, self)
            if A is not None:
                ans, u, const = A
                newterm = log(const) + u
                self.newf = self.newf.xreplace({log(arg): newterm})
                continue

            else:
                # log(arg) is a new monomial: t = log(arg), Dt = D(arg)/arg.
                arga, argd = frac_in(arg, self.t)
                darga = (argd*derivation(Poly(arga, self.t), self) -
                    arga*derivation(Poly(argd, self.t), self))
                dargd = argd**2
                darg = darga.as_expr()/dargd.as_expr()
                self.t = next(self.ts)
                self.T.append(self.t)
                self.extargs.append(arg)
                self.exts.append('log')
                self.D.append(cancel(darg.as_expr()/arg).as_poly(self.t,
                    expand=False))
                if self.dummy:
                    i = Dummy("i")
                else:
                    i = Symbol('i')
                self.Tfuncs += [Lambda(i, log(arg.subs(self.x, i)))]
                self.newf = self.newf.xreplace({log(arg): self.t})
                new_extension = True

        return new_extension

    @property
    def _important_attrs(self):
        """
        Returns some of the more important attributes of self.

        Explanation
        ===========

        Used for testing and debugging purposes.

        The attributes are (fa, fd, D, T, Tfuncs, backsubs,
        exts, extargs).
        """
        return (self.fa, self.fd, self.D, self.T, self.Tfuncs,
            self.backsubs, self.exts, self.extargs)

    # NOTE: this printing doesn't follow the Python's standard
    # eval(repr(DE)) == DE, where DE is the DifferentialExtension object
    # , also this printing is supposed to contain all the important
    # attributes of a DifferentialExtension object
    def __repr__(self):
        # no need to have GeneratorType object printed in it
        r = [(attr, getattr(self, attr)) for attr in self.__slots__
                if not isinstance(getattr(self, attr), GeneratorType)]
        return self.__class__.__name__ + '(dict(%r))' % (r)

    # fancy printing of DifferentialExtension object
    def __str__(self):
        return (self.__class__.__name__ + '({fa=%s, fd=%s, D=%s})' %
                (self.fa, self.fd, self.D))

    # should only be used for debugging purposes, internally
    # f1 = f2 = log(x) at different places in code execution
    # may return D1 != D2 as True, since 'level' or other attribute
    # may differ
    def __eq__(self, other):
        # Comparing against an unrelated object must not raise
        # AttributeError from getattr(other, attr); let Python fall back to
        # its default (identity) comparison instead.
        if not isinstance(other, DifferentialExtension):
            return NotImplemented
        for attr in self.__class__.__slots__:
            d1, d2 = getattr(self, attr), getattr(other, attr)
            # Generators (the ts symbol source) are not comparable; skip them.
            if not (isinstance(d1, GeneratorType) or d1 == d2):
                return False
        return True

    def reset(self):
        """
        Reset self to an initial state.  Used by __init__.

        The tower is reduced to just x (with Dx == 1), the level is set to
        the top, the symbol generator and the back-substitution lists are
        recreated, and newf is reset to self.f.
        """
        self.t = self.x
        self.T = [self.x]
        self.D = [Poly(1, self.x)]
        self.level = -1
        self.exts = [None]
        self.extargs = [None]
        if self.dummy:
            self.ts = numbered_symbols('t', cls=Dummy)
        else:
            # For testing
            self.ts = numbered_symbols('t')
        # For various things that we change to make things work that we need to
        # change back when we are done.
        self.backsubs = []
        self.Tfuncs = []
        self.newf = self.f

    def indices(self, extension):
        """
        Parameters
        ==========

        extension : str
            Represents a valid extension type.

        Returns
        =======

        list: A list of indices of 'exts' where extension of
            type 'extension' is present.

        Examples
        ========

        >>> from sympy.integrals.risch import DifferentialExtension
        >>> from sympy import log, exp
        >>> from sympy.abc import x
        >>> DE = DifferentialExtension(log(x) + exp(x), x, handle_first='exp')
        >>> DE.indices('log')
        [2]
        >>> DE.indices('exp')
        [1]

        """
        return [i for i, ext in enumerate(self.exts) if ext == extension]

    def increment_level(self):
        """
        Increment the level of self.

        Explanation
        ===========

        This makes the working differential extension larger.  self.level is
        given relative to the end of the list (-1, -2, etc.), so we don't need
        to worry about it when building the extension.

        Raises ValueError if the level is already at the top (-1).
        """
        if self.level >= -1:
            raise ValueError("The level of the differential extension cannot "
                "be incremented any further.")

        self.level += 1
        self.t = self.T[self.level]
        self.d = self.D[self.level]
        self.case = self.cases[self.level]
        return None

    def decrement_level(self):
        """
        Decrease the level of self.

        Explanation
        ===========

        This makes the working differential extension smaller.  self.level is
        given relative to the end of the list (-1, -2, etc.), so we don't need
        to worry about it when building the extension.

        Raises ValueError if the level is already at the bottom
        (-len(self.T), i.e. the level of x).
        """
        if self.level <= -len(self.T):
            raise ValueError("The level of the differential extension cannot "
                "be decremented any further.")

        self.level -= 1
        self.t = self.T[self.level]
        self.d = self.D[self.level]
        self.case = self.cases[self.level]
        return None
748
749
def update_sets(seq, atoms, func):
    """
    Return, as a list, the elements of the set ``atoms`` that either already
    occur in ``seq`` or satisfy the predicate ``func``.

    ``func`` is evaluated only on atoms that are not already in ``seq``.  The
    order of the returned list is that of an unordered set.
    """
    already_known = atoms.intersection(set(seq))
    candidates = atoms - already_known
    accepted = {atom for atom in candidates if func(atom)}
    return list(already_known | accepted)
756
757
class DecrementLevel:
    """
    Context manager that lowers the level of a DifferentialExtension on entry
    and raises it back on exit.

    The level is restored even if the body of the ``with`` statement raises;
    such an exception is not suppressed.
    """
    __slots__ = ('DE',)

    def __init__(self, DE):
        # The DifferentialExtension whose level is adjusted.
        self.DE = DE

    def __enter__(self):
        # Nothing is bound by ``with DecrementLevel(DE) as x`` (x is None).
        self.DE.decrement_level()

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None (falsy) lets any in-flight exception propagate.
        self.DE.increment_level()
773
774
class NonElementaryIntegralException(Exception):
    """
    Raised by the subroutines of the Risch algorithm to tell their callers
    that the function being integrated has no elementary integral in the
    given differential field.
    """
    # TODO: Rewrite algorithms below to use this (?)

    # TODO: Pass through information about why the integral was nonelementary,
    # and store that in the resulting NonElementaryIntegral somehow.
786
787
def gcdex_diophantine(a, b, c):
    """
    Extended Euclidean Algorithm, Diophantine version.

    Explanation
    ===========

    Given ``a``, ``b`` in K[x] and ``c`` in (a, b), the ideal generated by
    ``a`` and ``b``, return (s, t) such that s*a + t*b == c and either
    s == 0 or s.degree() < b.degree().

    If ``c`` is not in (a, b), the exact division by gcd(a, b) fails.
    """
    # Extended Euclidean Algorithm (Diophantine Version) pg. 13
    # TODO: This should go in densetools.py.
    # XXX: Better name?
    s_coeff, g = a.half_gcdex(b)
    # s_coeff*a == g (mod b); rescale so that it produces c instead of g.
    # Inexact division here means c is not in (a, b).
    s_coeff *= c.exquo(g)
    # Reduce modulo b to satisfy the degree bound on s.
    if s_coeff and s_coeff.degree() >= b.degree():
        _, s_coeff = s_coeff.div(b)
    t_coeff = (c - s_coeff*a).exquo(b)
    return (s_coeff, t_coeff)
809
810
def frac_in(f, t, *, cancel=False, **kwargs):
    """
    Returns the tuple (fa, fd), where fa and fd are Polys in t.

    Explanation
    ===========

    This is a common idiom in the Risch Algorithm functions, so we abstract
    it out here. ``f`` should be a basic expression, a Poly, or a tuple (fa, fd),
    where fa and fd are either basic expressions or Polys, and f == fa/fd.
    **kwargs are applied to Poly.

    If ``cancel`` is True, common factors of the numerator and denominator
    are cancelled before returning.

    Raises ValueError if the numerator or the denominator cannot be
    converted to a Poly in ``t`` (checked before any cancellation).
    """
    if isinstance(f, tuple):
        fa, fd = f
        f = fa.as_expr()/fd.as_expr()
    fa, fd = f.as_expr().as_numer_denom()
    fa, fd = fa.as_poly(t, **kwargs), fd.as_poly(t, **kwargs)
    # Validate before cancelling: otherwise cancel() would be invoked on None
    # and fail with an unhelpful AttributeError instead of this ValueError.
    if fa is None or fd is None:
        raise ValueError("Could not turn %s into a fraction in %s." % (f, t))
    if cancel:
        fa, fd = fa.cancel(fd, include=True)
    return (fa, fd)
833
834
def as_poly_1t(p, t, z):
    """
    (Hackish) way to convert an element ``p`` of K[t, 1/t] to K[t, z].

    In other words, ``z == 1/t`` will be a dummy variable that Poly can handle
    better.

    See issue 5131.

    Raises PolynomialError if the denominator of ``p`` (as a fraction in
    ``t``) is not a monomial, i.e. ``p`` is not in K[t, 1/t].  A DomainError
    raised while dividing the positive-degree part is re-raised as
    NotImplementedError (issue 4950).

    Examples
    ========

    >>> from sympy import random_poly
    >>> from sympy.integrals.risch import as_poly_1t
    >>> from sympy.abc import x, z

    >>> p1 = random_poly(x, 10, -10, 10)
    >>> p2 = random_poly(x, 10, -10, 10)
    >>> p = p1 + p2.subs(x, 1/x)
    >>> as_poly_1t(p, x, z).as_expr().subs(z, 1/x) == p
    True
    """
    # TODO: Use this on the final result.  That way, we can avoid answers like
    # (...)*exp(-x).
    pa, pd = frac_in(p, t, cancel=True)
    if not pd.is_monomial:
        # XXX: Is there a better Poly exception that we could raise here?
        # Either way, if you see this (from the Risch Algorithm) it indicates
        # a bug.
        raise PolynomialError("%s is not an element of K[%s, 1/%s]." % (p, t, t))
    # pd == c*t**d, so the terms of pa of degree <= d give the nonpositive
    # powers of t and the remaining terms give the positive powers.
    d = pd.degree(t)
    one_t_part = pa.slice(0, d + 1)
    r = pd.degree() - pa.degree()
    t_part = pa - one_t_part
    try:
        t_part = t_part.to_field().exquo(pd)
    except DomainError as e:
        # issue 4950
        raise NotImplementedError(e)
    # Compute the negative degree parts.
    # Reversing the dense coefficient list maps t**k to t**(len - 1 - k); the
    # factor t**r compensates when pa has degree below d.
    # NOTE(review): when pa has terms above degree d this presumably relies on
    # slice() keeping the leading zero coefficients, so that the list has
    # length d + 1 -- confirm if the Poly representation changes.
    one_t_part = Poly.from_list(reversed(one_t_part.rep.rep), *one_t_part.gens,
        domain=one_t_part.domain)
    if 0 < r < oo:
        one_t_part *= Poly(t**r, t)

    one_t_part = one_t_part.replace(t, z)  # z will be 1/t
    # Divide by the coefficient c of pd == c*t**d.
    if pd.nth(d):
        one_t_part *= Poly(1/pd.nth(d), z, expand=False)
    ans = t_part.as_poly(t, z, expand=False) + one_t_part.as_poly(t, z,
        expand=False)

    return ans
887
888
def derivation(p, DE, coefficientD=False, basic=False):
    """
    Computes Dp.

    Explanation
    ===========

    Given the derivation D with D = d/dx and p is a polynomial in t over
    K(x), return Dp.  It is assembled by the chain rule as the sum of
    D(v)*(dp/dv) over the variables v of the tower up to the current level.

    If coefficientD is True, it computes the derivation kD
    (kappaD), which is defined as kD(sum(ai*X**i, (i, 0, n))) ==
    sum(Dai*X**i, (i, 0, n)) (Definition 3.2.2, page 80).  X in this case is
    the current DE.t, so coefficientD computes the derivative just with
    respect to the variables below it, with X treated as a constant.  The
    level of ``DE`` is lowered for this and is always restored afterwards,
    even if an exception is raised.

    If ``basic=True``, this returns a Basic expression.  Elements of D can
    still be instances of Poly.  Otherwise a Poly in DE.t is returned.
    """
    if basic:
        r = 0
    else:
        r = Poly(0, DE.t)

    t = DE.t
    if coefficientD:
        if DE.level <= -len(DE.T):
            # 'base' case, the answer is 0.
            return r
        # Drop the outermost variable so that it is treated as a constant.
        DE.decrement_level()

    try:
        # Only the derivations and variables up to the current level count.
        D = DE.D[:len(DE.D) + DE.level + 1]
        T = DE.T[:len(DE.T) + DE.level + 1]

        for d, v in zip(D, T):
            pv = p.as_poly(v)
            if pv is None or basic:
                pv = p.as_expr()

            if basic:
                r += d.as_expr()*pv.diff(v)
            else:
                r += (d.as_expr()*pv.diff(v).as_expr()).as_poly(t)

        if basic:
            r = cancel(r)
    finally:
        # Restore the level even when an exception escapes, so DE is never
        # left one level lower than the caller expects.
        if coefficientD:
            DE.increment_level()

    return r
939
940
def get_case(d, t):
    """
    Classify the derivation ``d`` of the monomial ``t``.

    The result is one of {'exp', 'tan', 'base', 'primitive', 'other_linear',
    'other_nonlinear'}:

    - 'base' if ``d`` is 1, 'primitive' if it otherwise does not involve t;
    - 'exp' if t divides ``d``;
    - 'tan' if t**2 + 1 divides ``d``;
    - otherwise 'other_nonlinear' when the degree of ``d`` in t exceeds 1,
      and 'other_linear' when it does not.
    """
    if d.expr.has(t):
        if d.rem(Poly(t, t)).is_zero:
            return 'exp'
        if d.rem(Poly(1 + t**2, t)).is_zero:
            return 'tan'
        return 'other_nonlinear' if d.degree(t) > 1 else 'other_linear'
    return 'base' if d.is_one else 'primitive'
959
960
def splitfactor(p, DE, coefficientD=False, z=None):
    """
    Splitting factorization.

    Explanation
    ===========

    Given a derivation D on k[t] and ``p`` in k[t], return (p_n, p_s) in
    k[t] x k[t] such that p = p_n*p_s, p_s is special, and each square
    factor of p_n is normal.

    ``coefficientD`` is forwarded to derivation().  ``z``, when given, is
    appended to the generators used for the gcd in the case where ``p`` does
    not involve ``DE.t``; note that the recursive call does not forward ``z``.

    Page. 100
    """
    # Inverses of the generators below the current level.
    kinv = [1/x for x in DE.T[:DE.level]]
    if z:
        kinv.append(z)

    One = Poly(1, DE.t, domain=p.get_domain())
    Dp = derivation(p, DE, coefficientD=coefficientD)
    # XXX: Is this right?
    if p.is_zero:
        return (p, One)

    if not p.expr.has(DE.t):
        # p does not involve t: its special part is gcd(p, Dp), computed with
        # respect to the generators in kinv.
        s = p.as_poly(*kinv).gcd(Dp.as_poly(*kinv)).as_poly(DE.t)
        n = p.exquo(s)
        return (n, s)

    if not Dp.is_zero:
        # Following Bronstein's SplitFactor: s = gcd(p, Dp)/gcd(p, dp/dt)
        # collects special factors of p; peel them off and recurse on p/s.
        h = p.gcd(Dp).to_field()
        g = p.gcd(p.diff(DE.t)).to_field()
        s = h.exquo(g)

        if s.degree(DE.t) == 0:
            return (p, One)

        q_split = splitfactor(p.exquo(s), DE, coefficientD=coefficientD)

        return (q_split[0], q_split[1]*s)
    else:
        return (p, One)
1002
1003
def splitfactor_sqf(p, DE, coefficientD=False, z=None, basic=False):
    """
    Splitting Square-free Factorization.

    Explanation
    ===========

    Given a derivation D on k[t] and ``p`` in k[t], returns (N1, ..., Nm)
    and (S1, ..., Sm) in k[t]^m such that p =
    (N1*N2**2*...*Nm**m)*(S1*S2**2*...*Sm**m) is a splitting
    factorization of ``p`` and the Ni and Si are square-free and coprime.

    The result is returned as two tuples of (factor, multiplicity) pairs,
    (normal parts, special parts); factors equal to 1 are omitted.  For
    ``p == 0`` the result is (((p, 1),), ()).  ``coefficientD`` and
    ``basic`` are forwarded to derivation().
    """
    # TODO: This algorithm appears to be faster in every case
    # TODO: Verify this and splitfactor() for multiple extensions
    # Generators below the current level together with their inverses.
    kkinv = [1/x for x in DE.T[:DE.level]] + DE.T[:DE.level]
    if z:
        # NOTE: z replaces (rather than extends) the generator list here,
        # unlike in splitfactor(), which appends it.
        kkinv = [z]

    S = []
    N = []
    p_sqf = p.sqf_list_include()
    if p.is_zero:
        return (((p, 1),), ())

    for pi, i in p_sqf:
        # Special part of the square-free factor pi: gcd(pi, D(pi)) computed
        # with respect to kkinv; the normal part is the cofactor pi/Si.
        Si = pi.as_poly(*kkinv).gcd(derivation(pi, DE,
            coefficientD=coefficientD,basic=basic).as_poly(*kkinv)).as_poly(DE.t)
        pi = Poly(pi, DE.t)
        Si = Poly(Si, DE.t)
        Ni = pi.exquo(Si)
        if not Si.is_one:
            S.append((Si, i))
        if not Ni.is_one:
            N.append((Ni, i))

    return (tuple(N), tuple(S))
1040
1041
def canonical_representation(a, d, DE):
    """
    Canonical Representation.

    Explanation
    ===========

    Given a derivation D on k[t] and f = a/d in k(t), split f as
    f = f_p + f_s + f_n, where f_p is in k[t], f_s is reduced (its
    denominator is special) and f_n is simple (its denominator is normal).

    The denominator is first made monic.  Returns (f_p, (b, ds), (c, dn)),
    i.e. f_s = b/ds and f_n = c/dn, where d == dn*ds is the splitting
    factorization of the (monic) denominator.
    """
    t = DE.t
    # Scale numerator and denominator so that the denominator is monic.
    monic_factor = Poly(1/d.LC(), t)
    a = a.mul(monic_factor)
    d = d.mul(monic_factor)

    poly_part, remainder = a.div(d)
    d_normal, d_special = splitfactor(d, DE)

    # remainder == b*d_normal + c*d_special, hence
    # remainder/d == b/d_special + c/d_normal.
    b, c = gcdex_diophantine(d_normal.as_poly(t), d_special.as_poly(t),
        remainder.as_poly(t))

    return (poly_part, (b.as_poly(t), d_special), (c.as_poly(t), d_normal))
1066
1067
def hermite_reduce(a, d, DE):
    """
    Hermite Reduction - Mack's Linear Version.

    Given a derivation D on k(t) and f = a/d in k(t), returns g, h, r in
    k(t) such that f = Dg + h + r, h is simple, and r is reduced.

    Each of g, h, r is returned as a (numerator, denominator) pair of Polys:
    ((ga, gd), (r, d), (rra, rrd)).  The third pair also carries the
    polynomial parts (the polynomial part of f and the quotient left over by
    the reduction) on top of the reduced part.
    """
    # Make d monic
    l = Poly(1/d.LC(), DE.t)
    a, d = a.mul(l), d.mul(l)

    # Only the simple part fn is reduced below; the polynomial part fp and the
    # reduced part fs are recombined at the end.
    fp, fs, fn = canonical_representation(a, d, DE)
    a, d = fn
    l = Poly(1/d.LC(), DE.t)
    a, d = a.mul(l), d.mul(l)

    # g = ga/gd accumulates the rational part of the integral.
    ga = Poly(0, DE.t)
    gd = Poly(1, DE.t)

    # dm = gcd(d, Dd) and ds = d/dm; the remainders of these divisions (r) are
    # discarded.
    dd = derivation(d, DE)
    dm = gcd(d, dd).as_poly(DE.t)
    ds, r = d.div(dm)

    # Mack's linear Hermite step: while the denominator still has repeated
    # factors (dm nonconstant), lower their multiplicity by one.
    while dm.degree(DE.t)>0:

        ddm = derivation(dm, DE)
        dm2 = gcd(dm, ddm)
        dms, r = dm.div(dm2)
        ds_ddm = ds.mul(ddm)
        ds_ddm_dm, r = ds_ddm.div(dm)

        # Solve b*(-ds*Ddm/dm) + c*dms == a.
        b, c = gcdex_diophantine(-ds_ddm_dm.as_poly(DE.t), dms.as_poly(DE.t), a.as_poly(DE.t))
        b, c = b.as_poly(DE.t), c.as_poly(DE.t)

        # New numerator: a = c - Db*(ds/dms).
        db = derivation(b, DE).as_poly(DE.t)
        ds_dms, r = ds.div(dms)
        a = c.as_poly(DE.t) - db.mul(ds_dms).as_poly(DE.t)

        # g += b/dm
        ga = ga*dm + b*gd
        gd = gd*dm
        ga, gd = ga.cancel(gd, include=True)
        dm = dm2

    # What is left is a/ds: quotient q plus the simple part r/d.
    d = ds
    q, r = a.div(d)
    ga, gd = ga.cancel(gd, include=True)

    r, d = r.cancel(d, include=True)
    # Combine q, fp and the reduced part fs over fs's denominator.
    rra = q*fs[1] + fp*fs[1] + fs[0]
    rrd = fs[1]
    rra, rrd = rra.cancel(rrd, include=True)

    return ((ga, gd), (r, d), (rra, rrd))
1122
1123
def polynomial_reduce(p, DE):
    """
    Polynomial Reduction.

    Explanation
    ===========

    Given a derivation D on k(t) and p in k[t] where t is a nonlinear
    monomial over k, return q, r in k[t] such that p = Dq + r, and
    deg(r) < deg_t(Dt).

    Each step cancels the leading term of p by subtracting the derivative
    of c*t**m, with m = deg(p) - deg(Dt) + 1 and c = lc(p)/(m*lc(Dt)).
    """
    t = DE.t
    deg_dt = DE.d.degree(t)
    q = Poly(0, t)
    while p.degree(t) >= deg_dt:
        m = p.degree(t) - deg_dt + 1
        coeff = p.as_poly(t).LC()/(m*DE.d.LC())
        step = Poly(t**m, t).mul(Poly(coeff, t))
        q += step
        p = p - derivation(step, DE)

    return (q, p)
1144
1145
def laurent_series(a, d, F, n, DE):
    """
    Contribution of ``F`` to the full partial fraction decomposition of A/D.

    Explanation
    ===========

    Given a field K of characteristic 0 and ``A``,``D``,``F`` in K[x] with D monic,
    nonzero, coprime with A, and ``F`` the factor of multiplicity n in the square-
    free factorization of D, return the principal parts of the Laurent series of
    A/D at all the zeros of ``F``.

    Returns (delta_a, delta_d, H_list): delta_a/delta_d accumulates the
    principal parts at the roots of ``F`` found by real_roots(), and H_list
    collects, for each order j, the polynomial Q*B**(1 + j)*C**(n + j)
    before it is reduced modulo F.  If ``F`` is constant, the integer 0 is
    returned instead of a tuple.
    """
    # NOTE(review): this early return is 0, not a 3-tuple, so a caller that
    # unpacks the result (as recognize_derivative does) would raise TypeError;
    # confirm that a constant F cannot reach this function.
    if F.degree()==0:
        return 0
    # Z = [z, z_1, ..., z_n]: z followed by symbols that stand for its
    # successive derivatives.
    Z = _symbols('z', n)
    Z.insert(0, z)
    delta_a = Poly(0, DE.t)
    delta_d = Poly(1, DE.t)

    # D == E*F**n, where E is coprime to F.
    E = d.quo(F**n)
    ha, hd = (a, E*Poly(z**n, DE.t))
    dF = derivation(F,DE)
    # B and C are the inverses of E and of DF modulo F (G is discarded).
    B, G = gcdex_diophantine(E, F, Poly(1,DE.t))
    C, G = gcdex_diophantine(dF, F, Poly(1,DE.t))

    # initialization
    F_store = F
    V, DE_D_list, H_list= [], [], []

    for j in range(0, n):
        # jth derivative of z would be substituted with dfnth/(j+1) where dfnth =(d^n)f/(dx)^n
        F_store = derivation(F_store, DE)
        v = (F_store.as_expr())/(j + 1)
        V.append(v)
        # D(Z[j]) == Z[j + 1] in the auxiliary extension built below.
        DE_D_list.append(Poly(Z[j + 1],Z[j]))

    DE_new = DifferentialExtension(extension = {'D': DE_D_list}) #a differential indeterminate
    for j in range(0, n):
        zEha = Poly(z**(n + j), DE.t)*E**(j + 1)*ha
        zEhd = hd
        # NOTE(review): cancel() of a (numerator, denominator) pair is indexed
        # at [1] and [2], presumably the cancelled numerator and denominator;
        # element [0] (presumably a common factor) is dropped, and the call
        # is made twice -- confirm.
        Pa, Pd = cancel((zEha, zEhd))[1], cancel((zEha, zEhd))[2]
        Q = Pa.quo(Pd)
        # Substitute z, z_1, ..., z_j by the values recorded in V.
        for i in range(0, j + 1):
            Q = Q.subs(Z[i], V[i])
        # NOTE(review): this is hd*D(ha) + ha*D(hd), whereas the quotient rule
        # for D(ha/hd) has a minus sign -- confirm whether '+' is intended.
        Dha = (hd*derivation(ha, DE, basic=True).as_poly(DE.t)
             + ha*derivation(hd, DE, basic=True).as_poly(DE.t)
             + hd*derivation(ha, DE_new, basic=True).as_poly(DE.t)
             + ha*derivation(hd, DE_new, basic=True).as_poly(DE.t))
        Dhd = Poly(j + 1, DE.t)*hd**2
        ha, hd = Dha, Dhd

        # F* = F/gcd(F, Q); only proceed while it has positive degree.
        Ff, Fr = F.div(gcd(F, Q))
        F_stara, F_stard = frac_in(Ff, DE.t)
        if F_stara.degree(DE.t) - F_stard.degree(DE.t) > 0:
            QBC = Poly(Q, DE.t)*B**(1 + j)*C**(n + j)
            H = QBC
            H_list.append(H)
            H = (QBC*F_stard).rem(F_stara)
            # Only the roots returned by real_roots() are used; complex zeros
            # of F* contribute nothing here.
            alphas = real_roots(F_stara)
            for alpha in list(alphas):
                # delta += H(alpha)/(t - alpha)**(n - j)
                delta_a = delta_a*Poly((DE.t - alpha)**(n - j), DE.t) + Poly(H.eval(alpha), DE.t)
                delta_d = delta_d*Poly((DE.t - alpha)**(n - j), DE.t)
    return (delta_a, delta_d, H_list)
1209
1210
def recognize_derivative(a, d, DE, z=None):
    """
    Compute the squarefree factorization of the denominator of f
    and for each Di the polynomial H in K[x] (see Theorem 2.7.1), using the
    LaurentSeries algorithm. Write Di = GiEi where Gj = gcd(Hn, Di) and
    gcd(Ei,Hn) = 1. Since the residues of f at the roots of Gj are all 0, and
    the residue of f at a root alpha of Ei is Hi(a) != 0, f is the derivative of a
    rational function if and only if Ei = 1 for each i, which is equivalent to
    Di | H[-1] for each i.

    Returns True if f = a/d is recognized as such a derivative, False
    otherwise.  Only the special factors returned by splitfactor_sqf() are
    examined; if there are none, True is returned.
    """
    flag =True
    a, d = a.cancel(d, include=True)
    q, r = a.div(d)
    Np, Sp = splitfactor_sqf(d, DE, coefficientD=True, z=z)

    # NOTE(review): j counts loop iterations (1, 2, ...) and is passed as the
    # multiplicity; the actual multiplicity i from Sp is unused -- confirm.
    j = 1
    for (s, i) in Sp:
       delta_a, delta_d, H = laurent_series(r, d, s, j, DE)
       g = gcd(d, H[-1]).as_poly()
       # NOTE(review): 'is not' compares object identity; a freshly computed
       # gcd is presumably never the very object d, so any iteration sets
       # flag to False.  An equality/divisibility test was probably intended.
       if g is not d:
             flag = False
             break
       j = j + 1
    return flag
1235
def recognize_log_derivative(a, d, DE, z=None):
    """
    Decide whether f = a/d is a logarithmic derivative, i.e. whether there
    is a v in K(x)* with f == Dv/v.

    Such a v exists iff f can be written as A/D with D squarefree,
    deg(A) < deg(D), gcd(A, D) = 1, and all the roots of the
    Rothstein-Trager resultant are integers; in that case the
    Rothstein-Trager, Lazard-Rioboo-Trager or Czichowski algorithm produces
    u in K(x) with du/dx = uf.

    Here f is cancelled, its polynomial part is dropped, and the special
    square-free factors of res_t(d, a - z*Dd) are examined; False is
    returned as soon as one of them has a real root that is not an integer.
    """
    z = z or Dummy('z')
    a, d = a.cancel(d, include=True)
    _, a = a.div(d)

    q = a - Poly(z, DE.t)*derivation(d, DE)
    res, _ = d.resultant(q, includePRS=True)
    res = Poly(res, z)
    _, special = splitfactor_sqf(res, DE, coefficientD=True, z=z)

    for factor, _ in special:
        # TODO also consider the complex roots
        # incase we have complex roots it should turn the flag false
        roots = real_roots(factor.as_poly(z))
        if not all(root.is_Integer for root in roots):
            return False
    return True
1265
def residue_reduce(a, d, DE, z=None, invert=True):
    """
    Lazard-Rioboo-Rothstein-Trager resultant reduction.

    Explanation
    ===========

    Given a derivation ``D`` on k(t) and f in k(t) simple, return g
    elementary over k(t) and a Boolean b in {True, False} such that f -
    Dg in k[t] if b == True or f + h and f + h - Dg do not have an
    elementary integral over k(t) for any h in k<t> (reduced) if b ==
    False.

    Returns (G, b), where G is a list of tuples of the form (s_i, S_i),
    such that g = Add(*[RootSum(s_i, lambda z: z*log(S_i(z, t))) for
    S_i, s_i in G]). f - Dg is the remaining integral, which is elementary
    only if b == True, and hence the integral of f is elementary only if
    b == True.

    f - Dg is not calculated in this function because that would require
    explicitly calculating the RootSum.  Use residue_reduce_derivation().

    If ``z`` is not given, a fresh Dummy is used.  When ``invert`` is True,
    each S_i (other than the d itself) is normalized modulo s_i so that its
    leading coefficient is 1.
    """
    # TODO: Use log_to_atan() from rationaltools.py
    # If r = residue_reduce(...), then the logarithmic part is given by:
    # sum([RootSum(a[0].as_poly(z), lambda i: i*log(a[1].as_expr()).subs(z,
    # i)).subs(t, log(x)) for a in r[0]])

    z = z or Dummy('z')
    # Remove common factors, then make d monic over a field.
    a, d = a.cancel(d, include=True)
    a, d = a.to_field().mul_ground(1/d.LC()), d.to_field().mul_ground(1/d.LC())
    # Generators below the current level together with their inverses.
    kkinv = [1/x for x in DE.T[:DE.level]] + DE.T[:DE.level]

    if a.is_zero:
        return ([], True)
    # Keep only the proper part of a/d; the polynomial part p is not used.
    p, a = a.div(d)

    pz = Poly(z, DE.t)

    Dd = derivation(d, DE)
    q = a - pz*Dd

    # Resultant in t of d and q = a - z*Dd, together with the subresultant
    # PRS R; the argument of larger degree is placed first.
    if Dd.degree(DE.t) <= d.degree(DE.t):
        r, R = d.resultant(q, includePRS=True)
    else:
        r, R = q.resultant(d, includePRS=True)

    # Index the subresultants by their degree.
    R_map, H = {}, []
    for i in R:
        R_map[i.degree()] = i

    r = Poly(r, z)
    # Split the resultant (a polynomial in z) into normal (Np) and special
    # (Sp) square-free factors, each paired with its multiplicity.
    Np, Sp = splitfactor_sqf(r, DE, coefficientD=True, z=z)

    for s, i in Sp:
        if i == d.degree(DE.t):
            # Multiplicity deg(d): d itself is the log argument for s.
            s = Poly(s, z).monic()
            H.append((s, d))
        else:
            # Otherwise use the subresultant of degree i, if there is one.
            h = R_map.get(i)
            if h is None:
                continue
            h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True)

            h_lc_sqf = h_lc.sqf_list_include(all=True)

            # Divide out of h the gcd of each square-free factor of its
            # leading coefficient with the matching power of s.
            # (This loop rebinds ``a``.)
            for a, j in h_lc_sqf:
                h = Poly(h, DE.t, field=True).exquo(Poly(gcd(a, s**j, *kkinv),
                    DE.t))

            s = Poly(s, z).monic()

            if invert:
                # Multiply the non-leading coefficients of h by the inverse of
                # its leading coefficient modulo s, and set the leading
                # coefficient to 1.
                h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True, expand=False)
                inv, coeffs = h_lc.as_poly(z, field=True).invert(s), [S.One]

                for coeff in h.coeffs()[1:]:
                    L = reduced(inv*coeff.as_poly(inv.gens), [s])[1]
                    coeffs.append(L.as_expr())

                h = Poly(dict(list(zip(h.monoms(), coeffs))), DE.t)

            H.append((s, h))

    # b is True only when no normal factor of the resultant involves t or z.
    b = all([not cancel(i.as_expr()).has(DE.t, z) for i, _ in Np])

    return (H, b)
1352
1353
def residue_reduce_to_basic(H, DE, z):
    """
    Converts the tuple returned by residue_reduce() into a Basic expression.

    Every pair (s, S) in ``H`` becomes RootSum(s, Lambda(i, i*log(S(i)))),
    with the tower variables replaced by the functions of DE.x they stand
    for; the resulting terms are added up (0 for an empty ``H``).
    """
    # TODO: check what Lambda does with RootOf
    i = Dummy('i')
    back_subs = list(zip(reversed(DE.T),
                         reversed([f(DE.x) for f in DE.Tfuncs])))

    total = 0
    for pair in H:
        poly = pair[0].as_poly(z)
        term = i*log(pair[1].as_expr()).subs({z: i}).subs(back_subs)
        total += RootSum(poly, Lambda(i, term))
    return total
1364
1365
def residue_reduce_derivation(H, DE, z):
    """
    Computes the derivation of an expression returned by residue_reduce().

    For every pair (s, S) in ``H`` this forms
    RootSum(s, Lambda(i, i*DS(i)/S(i))), the derivative of
    RootSum(s, Lambda(i, i*log(S(i)))); the sympified sum is returned.

    In general, this is a rational function in t, so this returns an
    as_expr() result.
    """
    # TODO: verify that this is correct for multiple extensions
    i = Dummy('i')
    total = 0
    for pair in H:
        poly = pair[0].as_poly(z)
        log_arg = pair[1]
        numer = derivation(log_arg, DE).as_expr().subs(z, i)
        denom = log_arg.as_expr().subs(z, i)
        total += RootSum(poly, Lambda(i, i*numer/denom))
    return S(total)
1377
1378
def integrate_primitive_polynomial(p, DE):
    """
    Integration of primitive polynomials.

    Explanation
    ===========

    Given a primitive monomial t over k, and ``p`` in k[t], return q in k[t],
    r in k, and a bool b in {True, False} such that r = p - Dq is in k if b is
    True, or r = p - Dq does not have an elementary integral over k(t) if b is
    False.

    The result is the tuple (q, r, b).
    """
    from sympy.integrals.prde import limited_integrate

    Zero = Poly(0, DE.t)
    q = Poly(0, DE.t)

    if not p.expr.has(DE.t):
        # p already lies in k: nothing to integrate with respect to t.
        return (Zero, p, True)

    # Cancel the leading term of p one degree at a time until p lies in k.
    while True:
        if not p.expr.has(DE.t):
            return (q, p, True)

        # Dt as a fraction in the variable one level below t.
        Dta, Dtb = frac_in(DE.d, DE.T[DE.level - 1])

        with DecrementLevel(DE):  # We had better be integrating the lowest extension (x)
                                  # with ratint().
            a = p.LC()
            aa, ad = frac_in(a, DE.t)

            # limited_integrate looks for b and constants c with
            # a == Db + c[0]*Dt; if there is none, p - Dq is nonelementary.
            try:
                rv = limited_integrate(aa, ad, [(Dta, Dtb)], DE)
                if rv is None:
                    raise NonElementaryIntegralException
                (ba, bd), c = rv
            except NonElementaryIntegralException:
                return (q, p, False)

        # The coefficient of t**m in D(c*t**(m + 1)/(m + 1) + b*t**m) is
        # c*Dt + Db == a, so subtracting it lowers the degree of p.
        m = p.degree(DE.t)
        q0 = c[0].as_poly(DE.t)*Poly(DE.t**(m + 1)/(m + 1), DE.t) + \
            (ba.as_expr()/bd.as_expr()).as_poly(DE.t)*Poly(DE.t**m, DE.t)

        p = p - derivation(q0, DE)
        q = q + q0
1424
1425
def integrate_primitive(a, d, DE, z=None):
    """
    Integration of primitive functions.

    Explanation
    ===========

    Given a primitive monomial t over k and f in k(t), return g elementary over
    k(t), i in k(t), and b in {True, False} such that i = f - Dg is in k if b
    is True or i = f - Dg does not have an elementary integral over k(t) if b
    is False.

    This function returns a Basic expression for the first argument.  If b is
    True, the second argument is Basic expression in k to recursively integrate.
    If b is False, the second argument is an unevaluated Integral, which has
    been proven to be nonelementary.
    """
    # XXX: a and d must be canceled, or this might return incorrect results
    z = z or Dummy("z")
    # Pairs (tower variable, function of x it stands for) used to rewrite the
    # results in terms of x; both lists are reversed so they line up from the
    # outermost variable.
    s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs])))

    # f == Dg1 + h + r with h simple and r reduced.
    g1, h, r = hermite_reduce(a, d, DE)
    # g2: logarithmic part coming from the simple part h.
    g2, b = residue_reduce(h[0], h[1], DE, z=z)
    if not b:
        # i = f - D(g1) - D(g2), with D(g1) computed by the quotient rule.
        i = cancel(a.as_expr()/d.as_expr() - (g1[1]*derivation(g1[0], DE) -
            g1[0]*derivation(g1[1], DE)).as_expr()/(g1[1]**2).as_expr() -
            residue_reduce_derivation(g2, DE, z))
        i = NonElementaryIntegral(cancel(i).subs(s), DE.x)
        return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) +
            residue_reduce_to_basic(g2, DE, z), i, b)

    # h - Dg2 + r
    p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2,
        DE, z) + r[0].as_expr()/r[1].as_expr())
    p = p.as_poly(DE.t)

    # Integrate the remaining polynomial part in t.
    q, i, b = integrate_primitive_polynomial(p, DE)

    ret = ((g1[0].as_expr()/g1[1].as_expr() + q.as_expr()).subs(s) +
        residue_reduce_to_basic(g2, DE, z))
    if not b:
        # TODO: This does not do the right thing when b is False
        i = NonElementaryIntegral(cancel(i.as_expr()).subs(s), DE.x)
    else:
        i = cancel(i.as_expr())

    return (ret, i, b)
1473
1474
def integrate_hyperexponential_polynomial(p, DE, z):
    """
    Integration of hyperexponential polynomials.

    Explanation
    ===========

    Given a hyperexponential monomial t over k and ``p`` in k[t, 1/t], return q in
    k[t, 1/t] and a bool b in {True, False} such that p - Dq in k if b is True,
    or p - Dq does not have an elementary integral over k(t) if b is False.

    ``p`` is expected as a Poly in t and z, with z standing for 1/t (see
    as_poly_1t()).  q is returned as a fraction, i.e. the tuple (qa, qd, b).
    """
    from sympy.integrals.rde import rischDE

    t1 = DE.t
    # Dt/t, which lies in k for a hyperexponential monomial.
    dtt = DE.d.exquo(Poly(DE.t, DE.t))
    qa = Poly(0, DE.t)
    qd = Poly(1, DE.t)
    b = True

    if p.is_zero:
        return(qa, qd, b)

    with DecrementLevel(DE):
        # Treat each power t**i (i from -deg_z(p) to deg_t(p)) separately; the
        # constant term (i == 0) stays in k and is not handled here.
        for i in range(-p.degree(z), p.degree(t1) + 1):
            if not i:
                continue
            elif i < 0:
                # If you get AttributeError: 'NoneType' object has no attribute 'nth'
                # then this should really not have expand=False
                # But it shouldn't happen because p is already a Poly in t and z
                a = p.as_poly(z, expand=False).nth(-i)
            else:
                # If you get AttributeError: 'NoneType' object has no attribute 'nth'
                # then this should really not have expand=False
                a = p.as_poly(t1, expand=False).nth(i)

            aa, ad = frac_in(a, DE.t, field=True)
            aa, ad = aa.cancel(ad, include=True)
            # Solve the Risch DE Dv + i*(Dt/t)*v == a, so that
            # D(v*t**i) == a*t**i.
            iDt = Poly(i, t1)*dtt
            iDta, iDtd = frac_in(iDt, DE.t, field=True)
            try:
                va, vd = rischDE(iDta, iDtd, Poly(aa, DE.t), Poly(ad, DE.t), DE)
                va, vd = frac_in((va, vd), t1, cancel=True)
            except NonElementaryIntegralException:
                # Record the failure but keep integrating the other powers.
                b = False
            else:
                # q += v*t**i
                # NOTE(review): Poly(t1**i) is built without an explicit
                # generator, which matters for negative i -- confirm.
                qa = qa*vd + va*Poly(t1**i)*qd
                qd *= vd

    return (qa, qd, b)
1525
1526
def integrate_hyperexponential(a, d, DE, z=None, conds='piecewise'):
    """
    Integration of hyperexponential functions.

    Explanation
    ===========

    Given a hyperexponential monomial t over k and f in k(t), return g
    elementary over k(t), i in k(t), and a bool b in {True, False} such that
    i = f - Dg is in k if b is True or i = f - Dg does not have an elementary
    integral over k(t) if b is False.

    This function returns a Basic expression for the first argument.  If b is
    True, the second argument is Basic expression in k to recursively integrate.
    If b is False, the second argument is an unevaluated Integral, which has
    been proven to be nonelementary.

    With ``conds='piecewise'``, if the denominator of the polynomial part
    does not depend on x, that part is wrapped in a Piecewise covering the
    case where this denominator is 0.
    """
    # XXX: a and d must be canceled, or this might return incorrect results
    z = z or Dummy("z")
    # Pairs (tower variable, function of x it stands for) used to rewrite the
    # results in terms of x.
    s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs])))

    # f == Dg1 + h + r with h simple and r reduced.
    g1, h, r = hermite_reduce(a, d, DE)
    # g2: logarithmic part coming from the simple part h.
    g2, b = residue_reduce(h[0], h[1], DE, z=z)
    if not b:
        # i = f - D(g1) - D(g2), with D(g1) computed by the quotient rule.
        i = cancel(a.as_expr()/d.as_expr() - (g1[1]*derivation(g1[0], DE) -
            g1[0]*derivation(g1[1], DE)).as_expr()/(g1[1]**2).as_expr() -
            residue_reduce_derivation(g2, DE, z))
        i = NonElementaryIntegral(cancel(i.subs(s)), DE.x)
        return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) +
            residue_reduce_to_basic(g2, DE, z), i, b)

    # p should be a polynomial in t and 1/t, because Sirr == k[t, 1/t]
    # h - Dg2 + r
    p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2,
        DE, z) + r[0].as_expr()/r[1].as_expr())
    # Rewrite p as a Poly in t and z == 1/t.
    pp = as_poly_1t(p, DE.t, z)

    qa, qd, b = integrate_hyperexponential_polynomial(pp, DE, z)

    # The constant term (in both t and 1/t) is left in k for recursive
    # integration.
    i = pp.nth(0, 0)

    ret = ((g1[0].as_expr()/g1[1].as_expr()).subs(s) \
        + residue_reduce_to_basic(g2, DE, z))

    qas = qa.as_expr().subs(s)
    qds = qd.as_expr().subs(s)
    if conds == 'piecewise' and DE.x not in qds.free_symbols:
        # We have to be careful if the exponent is S.Zero!

        # XXX: Does qd = 0 always necessarily correspond to the exponential
        # equaling 1?
        ret += Piecewise(
                (qas/qds, Ne(qds, 0)),
                (integrate((p - i).subs(DE.t, 1).subs(s), DE.x), True)
            )
    else:
        ret += qas/qds

    if not b:
        # i = p - D(qa/qd): the part that has been proven nonelementary.
        i = p - (qd*derivation(qa, DE) - qa*derivation(qd, DE)).as_expr()/\
            (qd**2).as_expr()
        i = NonElementaryIntegral(cancel(i).subs(s), DE.x)
    return (ret, i, b)
1590
1591
def integrate_hypertangent_polynomial(p, DE):
    """
    Integration of hypertangent polynomials.

    Explanation
    ===========

    Given a differential field k such that sqrt(-1) is not in k, a
    hypertangent monomial t over k, and p in k[t], return q in k[t] and
    c in k such that p - Dq - c*D(t**2 + 1)/(t**2 + 1) is in k and p -
    Dq does not have an elementary integral over k(t) if Dc != 0.
    """
    # XXX: Make sure that sqrt(-1) is not in k.
    # Reduce p to r with deg(r) < deg(Dt) == 2, so r is at most linear in t.
    q, r = polynomial_reduce(p, DE)
    # Dt == a*(t**2 + 1), so D(t**2 + 1)/(t**2 + 1) == 2*a*t; choose c so that
    # c*2*a*t matches the linear term of r.
    a = DE.d.exquo(Poly(DE.t**2 + 1, DE.t))
    c = Poly(r.nth(1)/(2*a.as_expr()), DE.t)
    return (q, c)
1609
1610
def integrate_nonlinear_no_specials(a, d, DE, z=None):
    """
    Integration of nonlinear monomials with no specials.

    Explanation
    ===========

    Let t be a nonlinear monomial over k such that Sirr ({p in k[t] | p is
    special, monic, and irreducible}) is empty, and let f be in k(t).
    Returns (g, b), with g a Basic expression elementary over k(t) and b a
    Boolean: when b is True, f - Dg lies in k; when b is False, f - Dg has
    no elementary integral over k(t).

    This function can be applied to any nonlinear extension, but a False
    result only proves that f - Dg is nonelementary when Sirr is empty.
    """
    # TODO: Integral from k?
    # TODO: split out nonelementary integral
    # XXX: a and d must be canceled, or this might not return correct results
    z = z or Dummy("z")
    back_subs = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs])))

    g1, h, r = hermite_reduce(a, d, DE)
    g2, elementary = residue_reduce(h[0], h[1], DE, z=z)
    rational_g1 = g1[0].as_expr()/g1[1].as_expr()

    if not elementary:
        return (rational_g1.subs(back_subs) +
            residue_reduce_to_basic(g2, DE, z), elementary)

    # With no specials, h - Dg2 + r ought to be a polynomial in t; anything
    # else points to a bug.
    remainder = cancel(h[0].as_expr()/h[1].as_expr()
        - residue_reduce_derivation(g2, DE, z).as_expr()
        + r[0].as_expr()/r[1].as_expr()).as_poly(DE.t)
    q1, q2 = polynomial_reduce(remainder, DE)

    # What is left lies in k exactly when q2 no longer involves t.
    elementary = not q2.expr.has(DE.t)

    result = (cancel(rational_g1 + q1.as_expr()).subs(back_subs) +
        residue_reduce_to_basic(g2, DE, z))
    return (result, elementary)
1656
1657
class NonElementaryIntegral(Integral):
    """
    Represents a nonelementary Integral.

    Explanation
    ===========

    If the result of integrate() is an instance of this class, it is
    guaranteed to be nonelementary.  Note that integrate() by default will try
    to find any closed-form solution, even in terms of special functions which
    may themselves not be elementary.  To make integrate() return only
    elementary solutions or, in the cases where it can prove the integral to
    be nonelementary, instances of this class, use integrate(risch=True).
    In this case, integrate() may raise NotImplementedError if it cannot make
    such a determination.

    integrate() uses the deterministic Risch algorithm to integrate elementary
    functions or prove that they have no elementary integral.  In some cases,
    this algorithm can split an integral into an elementary and nonelementary
    part, so that the result of integrate will be the sum of an elementary
    expression and a NonElementaryIntegral.

    Examples
    ========

    >>> from sympy import integrate, exp, log, Integral
    >>> from sympy.abc import x

    >>> a = integrate(exp(-x**2), x, risch=True)
    >>> print(a)
    Integral(exp(-x**2), x)
    >>> type(a)
    <class 'sympy.integrals.risch.NonElementaryIntegral'>

    >>> expr = (2*log(x)**2 - log(x) - x**2)/(log(x)**3 - x**2*log(x))
    >>> b = integrate(expr, x, risch=True)
    >>> print(b)
    -log(-x + log(x))/2 + log(x + log(x))/2 + Integral(1/log(x), x)
    >>> type(b.atoms(Integral).pop())
    <class 'sympy.integrals.risch.NonElementaryIntegral'>

    """
    # TODO: This is useful in and of itself, because isinstance(result,
    # NonElementaryIntegral) will tell if the integral has been proven to be
    # nonelementary. But should we do more?  Perhaps a no-op .doit() if
    # elementary=True?  Or maybe some information on why the integral is
    # nonelementary.
    pass
1706
1707
def risch_integrate(f, x, extension=None, handle_first='log',
                    separate_integral=False, rewrite_complex=None,
                    conds='piecewise'):
    r"""
    The Risch Integration Algorithm.

    Explanation
    ===========

    Only transcendental functions are supported.  Currently, only exponentials
    and logarithms are supported, but support for trigonometric functions is
    forthcoming.

    If this function returns an unevaluated Integral in the result, it means
    that it has proven that integral to be nonelementary.  Any errors will
    result in raising NotImplementedError.  The unevaluated Integral will be
    an instance of NonElementaryIntegral, a subclass of Integral.

    handle_first may be either 'exp' or 'log'.  This changes the order in
    which the extension is built, and may result in a different (but
    equivalent) solution (for an example of this, see issue 5109).  It is also
    possible that the integral may be computed with one but not the other,
    because not all cases have been implemented yet.  It defaults to 'log' so
    that the outer extension is exponential when possible, because more of the
    exponential case has been implemented.

    If ``separate_integral`` is ``True``, the result is returned as a tuple (ans, i),
    where the integral is ans + i, ans is elementary, and i is either a
    NonElementaryIntegral or 0.  This is useful if you want to try further
    integrating the NonElementaryIntegral part using other algorithms to
    possibly get a solution in terms of special functions.  It is False by
    default.

    If ``extension`` is given, it is used as the DifferentialExtension instead
    of building a new one.  In that case the integrand is taken from the
    extension itself (its ``fa``/``fd``), and ``handle_first`` and
    ``rewrite_complex`` are not used.  Otherwise ``rewrite_complex`` is passed
    on to the DifferentialExtension constructor.

    ``conds`` is passed on to integrate_hyperexponential() for exponential
    extensions.  With the default 'piecewise', the elementary part may
    contain a Piecewise that guards against a possibly-zero exponent.

    Examples
    ========

    >>> from sympy.integrals.risch import risch_integrate
    >>> from sympy import exp, log, pprint
    >>> from sympy.abc import x

    First, we try integrating exp(-x**2). Except for a constant factor of
    2/sqrt(pi), this is the famous error function.

    >>> pprint(risch_integrate(exp(-x**2), x))
      /
     |
     |    2
     |  -x
     | e    dx
     |
    /

    The unevaluated Integral in the result means that risch_integrate() has
    proven that exp(-x**2) does not have an elementary anti-derivative.

    In many cases, risch_integrate() can split out the elementary
    anti-derivative part from the nonelementary anti-derivative part.
    For example,

    >>> pprint(risch_integrate((2*log(x)**2 - log(x) - x**2)/(log(x)**3 -
    ... x**2*log(x)), x))
                                             /
                                            |
      log(-x + log(x))   log(x + log(x))    |   1
    - ---------------- + --------------- +  | ------ dx
             2                  2           | log(x)
                                            |
                                           /

    This means that it has proven that the integral of 1/log(x) is
    nonelementary.  This function is also known as the logarithmic integral,
    and is often denoted as Li(x).

    risch_integrate() currently only accepts purely transcendental functions
    with exponentials and logarithms, though note that this can include
    nested exponentials and logarithms, as well as exponentials with bases
    other than E.

    >>> pprint(risch_integrate(exp(x)*exp(exp(x)), x))
     / x\
     \e /
    e
    >>> pprint(risch_integrate(exp(exp(x)), x))
      /
     |
     |  / x\
     |  \e /
     | e     dx
     |
    /

    >>> pprint(risch_integrate(x*x**x*log(x) + x**x + x*x**x, x))
       x
    x*x
    >>> pprint(risch_integrate(x**x, x))
      /
     |
     |  x
     | x  dx
     |
    /

    >>> pprint(risch_integrate(-1/(x*log(x)*log(log(x))**2), x))
         1
    -----------
    log(log(x))

    """
    f = S(f)

    # Use an explicit None check rather than truthiness of the extension object.
    if extension is not None:
        DE = extension
    else:
        DE = DifferentialExtension(f, x, handle_first=handle_first,
                dummy=True, rewrite_complex=rewrite_complex)
    fa, fd = DE.fa, DE.fd

    result = S.Zero
    # Work from the outermost extension level down to the base field.
    for case in reversed(DE.cases):
        # The integrand does not involve this level's t: step down without
        # integrating.  The 'base' level is never skipped.
        if not fa.expr.has(DE.t) and not fd.expr.has(DE.t) and case != 'base':
            DE.decrement_level()
            fa, fd = frac_in((fa, fd), DE.t)
            continue

        fa, fd = fa.cancel(fd, include=True)
        if case == 'exp':
            ans, i, b = integrate_hyperexponential(fa, fd, DE, conds=conds)
        elif case == 'primitive':
            ans, i, b = integrate_primitive(fa, fd, DE)
        elif case == 'base':
            # XXX: We can't call ratint() directly here because it doesn't
            # handle polynomials correctly.
            ans = integrate(fa.as_expr()/fd.as_expr(), DE.x, risch=False)
            b = False
            i = S.Zero
        else:
            raise NotImplementedError("Only exponential and logarithmic "
            "extensions are currently supported.")

        result += ans
        if b:
            # The remaining part i lives in the next lower field; continue
            # integrating it there.
            DE.decrement_level()
            fa, fd = frac_in(i, DE.t)
            continue

        # b is False (always the case at the 'base' level, so the loop
        # returns here): rewrite in terms of the original functions and
        # return.
        result = result.subs(DE.backsubs)
        if not i.is_zero:
            i = NonElementaryIntegral(i.function.subs(DE.backsubs), i.limits)
        if not separate_integral:
            return result + i
        if isinstance(i, NonElementaryIntegral):
            return (result, i)
        return (result, 0)
1861