1from __future__ import annotations
2
3import functools
4from itertools import permutations
5
6from ..core import Add, Basic, Dummy, E, Eq, Integer, Mul, Wild, pi, sympify
7from ..functions import (Ei, LambertW, Piecewise, acosh, asin, asinh, atan,
8                         binomial, cos, cosh, cot, coth, erf, erfi, exp, li,
9                         log, root, sin, sinh, sqrt, tan, tanh)
10from ..logic import And
11from ..polys import PolynomialError, cancel, factor, gcd, lcm, quo
12from ..polys.constructor import construct_domain
13from ..polys.monomials import itermonomials
14from ..polys.polyroots import root_factors
15from ..polys.solvers import solve_lin_sys
16from ..utilities import ordered
17from ..utilities.iterables import uniq
18
19
def components(f, x):
    """
    Collect the functional building blocks of ``f`` that depend on ``x``.

    The returned set contains symbols, function applications (together
    with the components of their arguments) and non-integer powers.
    Fractional powers are collected with minimal, positive exponents,
    i.e. a rational power contributes the corresponding root of its base.

    >>> components(sin(x)*cos(x)**2, x)
    {x, sin(x), cos(x)}

    See Also
    ========

    heurisch

    """
    # Anything that does not involve x contributes nothing.
    if x not in f.free_symbols:
        return set()

    if f.is_Symbol:
        return {f}

    found = set()

    if f.is_Function or f.is_Derivative:
        for arg in f.args:
            found.update(components(arg, x))
        found.add(f)
    elif f.is_Pow:
        found.update(components(f.base, x))
        exponent = f.exp

        if not exponent.is_Integer:
            if exponent.is_Rational:
                # p/q power: record only the q-th root of the base.
                found.add(root(f.base, exponent.denominator))
            else:
                # Symbolic exponent: the power itself is a component.
                found.update(components(exponent, x))
                found.add(f)
    else:
        # Generic container (Add, Mul, ...): recurse into the arguments.
        for arg in f.args:
            found.update(components(arg, x))

    return found
59
60
# Module-local pool of dummy symbols, keyed by name prefix.
_symbols_cache: dict[str, list[Dummy]] = {}


# NB @cacheit is not convenient here
def _symbols(name, n):
    """
    Return the first ``n`` module-local dummy symbols for prefix ``name``.

    Symbols are created lazily, named ``f'{name}{index}'``, and kept in
    ``_symbols_cache`` so that repeated calls hand out the same objects.
    The returned list is a fresh slice, so callers may mutate it without
    affecting the cache.
    """
    pool = _symbols_cache.setdefault(name, [])

    # Grow the pool only as far as needed for this request.
    for index in range(len(pool), n):
        pool.append(Dummy(f'{name}{index:d}'))

    return pool[:n]
78
79
def heurisch_wrapper(f, x, rewrite=False, hints=None, mappings=None, retries=3,
                     degree_offset=0, unnecessary_permutations=None):
    """
    Run heurisch and guard the result against symbolic poles.

    The antiderivative returned by heurisch is inspected for denominators
    that may vanish for particular values of the parameters.  The integral
    is recomputed for each such case and the overall answer is returned as
    a Piecewise whose last branch is the generic result.

    All arguments are forwarded to heurisch unchanged.

    Examples
    ========

    >>> heurisch(cos(n*x), x)
    sin(n*x)/n
    >>> heurisch_wrapper(cos(n*x), x)
    Piecewise((x, Eq(n, 0)), (sin(n*x)/n, true))

    See Also
    ========

    heurisch

    """
    from ..solvers.solvers import denoms, solve

    f = sympify(f)
    if x not in f.free_symbols:
        return f*x

    # Options passed through verbatim to every heurisch call below.
    options = (rewrite, hints, mappings, retries, degree_offset,
               unnecessary_permutations)

    def _vanishing_cases(expr):
        # Solve every denominator of expr for its parameters (all free
        # symbols other than x) and gather the solutions.
        found = []
        for den in denoms(expr):
            params = list(ordered(den.free_symbols - {x}))
            if params:
                found += solve(den, *params)
        return found

    res = heurisch(f, x, *options)
    if not isinstance(res, Basic):
        return res

    # Parameter values for which some denominator of the result vanishes.
    slns = _vanishing_cases(res)
    if not slns:
        return res
    slns = list(uniq(slns))

    # Remove the solutions corresponding to poles in the original expression.
    integrand_poles = _vanishing_cases(f)
    slns = [case for case in slns if case not in integrand_poles]
    if not slns:
        return res

    if len(slns) > 1:
        # Also try to satisfy all the conditions simultaneously; such
        # combined solutions are placed ahead of the individual ones.
        eqs = [Eq(key, value)
               for case in slns
               for key, value in case.items()]
        unknowns = set().union(*[e.free_symbols for e in eqs]) - {x}
        slns = solve(eqs, *ordered(unknowns)) + slns

    # For each case listed in slns, we reevaluate the integral.
    pairs = []
    for case in slns:
        special_result = heurisch(f.subs(case), x, *options)
        condition = And(*[Eq(key, value) for key, value in case.items()])
        pairs.append((special_result, condition))

    # Generic branch, used when none of the special conditions holds.
    pairs.append((heurisch(f, x, *options), True))
    return Piecewise(*pairs)
148
149
def heurisch(f, x, rewrite=False, hints=None, mappings=None, retries=3,
             degree_offset=0, unnecessary_permutations=None):
    """
    Compute indefinite integral using heuristic Risch algorithm.

    This is a heuristic approach to indefinite integration in finite
    terms using the extended heuristic (parallel) Risch algorithm, based
    on Manuel Bronstein's "Poor Man's Integrator".

    The algorithm supports various classes of functions including
    transcendental elementary or special functions like Airy,
    Bessel, Whittaker and Lambert.

    Note that this algorithm is not a decision procedure. If it isn't
    able to compute the antiderivative for a given function, then this is
    not a proof that such a function does not exist.  One should use
    the recursive Risch algorithm in such a case.  It's an open question
    if this algorithm can be made a full decision procedure.

    This is an internal integrator procedure. You should use the toplevel
    'integrate' function in most cases, as this procedure needs some
    preprocessing steps and otherwise may fail.

    Parameters
    ==========

    f : Expr
        expression
    x : Symbol
        variable

    rewrite : Boolean, optional
        force rewrite 'f' in terms of 'tan' and 'tanh', default False.
    hints : None or list
        a list of functions that may appear in the antiderivative.  If
        None (default) - no suggestions at all, if empty list - try
        to figure out.
    mappings : None or iterable, optional
        candidate orderings of the (expression, dummy) substitution
        pairs.  If None (default), all permutations of the pairs are
        tried, with the pair for ``x`` always kept last.
    retries : int, optional
        when an attempt fails, the function calls itself again with the
        same ``mappings`` object and ``retries - 1`` as long as
        ``retries >= 0``.  Default is 3.
    degree_offset : int, optional
        added to the heuristic degree bound of the polynomial part of
        the candidate antiderivative.  Default is 0.  It is not
        forwarded to the recursive calls made by this function.
    unnecessary_permutations : None or list, optional
        substitution pairs appended after every candidate mapping.  Only
        used when ``mappings`` is given; otherwise it is replaced by the
        pair for ``x``.

    Returns
    =======

    Expr or None
        the antiderivative (without an integration constant), or None
        if the heuristic did not find one.

    Examples
    ========

    >>> heurisch(y*tan(x), x)
    y*log(tan(x)**2 + 1)/2

    References
    ==========

    * :cite:`Bronstein2005pmint`

    See Also
    ========

    diofant.integrals.integrals.Integral.doit
    diofant.integrals.integrals.Integral
    components

    """
    f = sympify(f)
    if x not in f.free_symbols:
        return f*x

    # Pull out the factor that does not depend on x; it is multiplied back
    # onto the result at every successful return below.
    if not f.is_Add:
        indep, f = f.as_independent(x)
    else:
        indep = Integer(1)

    rewritables = {
        (sin, cos, cot): tan,
        (sinh, cosh, coth): tanh,
    }

    if rewrite:
        for candidates, rule in rewritables.items():
            f = f.rewrite(candidates, rule)
    else:
        # If f contains none of the rewritable functions, flag rewrite as
        # done so that the "retry with rewrite=True" fallback is skipped.
        for candidates in rewritables:
            if f.has(*candidates):
                break
        else:
            rewrite = True

    terms = components(f, x)

    if hints is not None:
        if not hints:
            # Empty hint list: guess special functions (li/Ei, erf/erfi,
            # asinh/asin/acosh/atan) from patterns found among the terms.
            a = Wild('a', exclude=[x])
            b = Wild('b', exclude=[x])
            c = Wild('c', exclude=[x])

            for g in set(terms):  # using copy of terms
                if g.is_Function:
                    if isinstance(g, li):
                        M = g.args[0].match(a*x**b)

                        if M is not None:
                            terms.add(x*(li(M[a]*x**M[b]) - (M[a]*x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])))

                elif g.is_Pow:
                    if g.base is E:
                        M = g.exp.match(a*x**2)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(erfi(sqrt(M[a])*x))
                            else:  # M[a].is_negative or unknown
                                terms.add(erf(sqrt(-M[a])*x))

                        M = g.exp.match(a*x**2 + b*x + c)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a])) *
                                          erfi(sqrt(M[a])*x + M[b]/(2*sqrt(M[a]))))
                            elif M[a].is_negative:
                                terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a])) *
                                          erf(sqrt(-M[a])*x - M[b]/(2*sqrt(-M[a]))))

                        M = g.exp.match(a*log(x)**2)

                        if M is not None:
                            if M[a].is_positive:
                                terms.add(erfi(sqrt(M[a])*log(x) + 1/(2*sqrt(M[a]))))
                            if M[a].is_negative:
                                terms.add(erf(sqrt(-M[a])*log(x) - 1/(2*sqrt(-M[a]))))

                    elif g.exp.is_Rational and g.exp.denominator == 2:
                        M = g.base.match(a*x**2 + b)

                        if M is not None and M[b].is_positive:
                            if M[a].is_positive:
                                terms.add(asinh(sqrt(M[a]/M[b])*x))
                            elif M[a].is_negative:
                                terms.add(asin(sqrt(-M[a]/M[b])*x))

                        M = g.base.match(a*x**2 - b)

                        if M is not None and M[b].is_positive:
                            if M[a].is_positive:
                                terms.add(acosh(sqrt(M[a]/M[b])*x))
                            elif M[a].is_negative:
                                terms.add((-M[b]/2*sqrt(-M[a]) *
                                           atan(sqrt(-M[a])*x/sqrt(M[a]*x**2 - M[b]))))

        else:
            # Caller-supplied hints are added to the terms as they are.
            terms |= set(hints)

    # Extend the terms with the components of their derivatives (one pass).
    for g in set(terms):  # using copy of terms
        terms |= components(cancel(g.diff(x)), x)

    # One dummy per term.  _symbols hands out the same cached dummies on
    # every call.
    # TODO: caching is significant factor for why permutations work at all. Change this.
    V = _symbols('x', len(terms))

    # sort mapping expressions from largest to smallest (last is always x).
    mapping = list(reversed(list(zip(*ordered(
        [(a[0].as_independent(x)[1], a) for a in zip(terms, V)])))[1]))
    # dummy -> original expression, built before the pair for x is popped.
    rev_mapping = {v: k for k, v in mapping}
    if mappings is None:
        # optimizing the number of permutations of mapping
        assert mapping[-1][0] == x  # if not, find it and correct this comment
        unnecessary_permutations = [mapping.pop(-1)]
        mappings = permutations(mapping)
    else:
        unnecessary_permutations = unnecessary_permutations or []

    # NOTE: reads `mapping` at call time, so it always applies whichever
    # mapping the loop below bound last.
    def _substitute(expr):
        return expr.subs(mapping)

    # Take the first mapping under which all derivative denominators are
    # polynomial in V and f becomes a rational function of V.  `mappings`
    # may be a one-shot iterator, so each retry resumes where this stopped.
    for mapping in mappings:
        mapping = list(mapping)
        mapping = mapping + unnecessary_permutations
        diffs = [_substitute(cancel(g.diff(x))) for g in terms]
        denoms = [g.as_numer_denom()[1] for g in diffs]
        if all(h.is_polynomial(*V) for h in denoms) and _substitute(f).is_rational_function(*V):
            denom = functools.reduce(lambda p, q: lcm(p, q, *V), denoms)
            break
    else:
        # No mapping worked (or the iterator was already exhausted).
        # Retry once with the tan/tanh rewrite forced.  mappings is not
        # passed, so that call builds fresh permutations.
        # NOTE(review): degree_offset and retries are not forwarded here.
        if not rewrite:
            result = heurisch(f, x, rewrite=True, hints=hints,
                              unnecessary_permutations=unnecessary_permutations)

            if result is not None:
                return indep*result
        return

    # Derivatives of the terms, cleared of the common denominator `denom`.
    numers = [cancel(denom*g) for g in diffs]

    # Chain rule over the variables V using `numers`, i.e. the derivative
    # with respect to x scaled by `denom`.
    def _derivation(h):
        return Add(*[d * h.diff(v) for d, v in zip(numers, V)])

    # For the first variable y of V on which p depends (provided
    # _derivation(p) is non-zero), split p into content c and primitive
    # part q with respect to y and return _deflation(c)*gcd(q, dq/dy).
    # Otherwise p is returned unchanged.
    def _deflation(p):
        for y in V:
            if not p.has(y):
                continue

            if _derivation(p) != 0:
                c, q = p.as_poly(y).primitive()
                return _deflation(c)*gcd(q, q.diff(y)).as_expr()

        return p

    # Returns a pair of factors of p.  The base case is (1, p); by
    # construction the product of the pair reconstructs p.
    # NOTE(review): presumably the "splitting factorization" of the
    # referenced paper -- confirm.
    def _splitter(p):
        for y in V:
            if not p.has(y):
                continue

            if _derivation(y) != 0:
                c, q = p.as_poly(y).primitive()

                q = q.as_expr()

                h = gcd(q, _derivation(q), y)
                s = quo(h, gcd(q, q.diff(y), y), y)

                c_split = _splitter(c)

                if s.as_poly(y).degree() == 0:
                    return c_split[0], q*c_split[1]

                q_split = _splitter(cancel(q / s))

                return c_split[0]*q_split[0]*s, c_split[1]*q_split[1]

        return Integer(1), p

    # Extra polynomials contributed by tan, tanh and LambertW terms.  All
    # keys join `polys` below; keys flagged True are also multiplied into
    # the denominator factor `s`.
    special = {}

    for term in terms:
        if term.is_Function:
            if isinstance(term, tan):
                special[1 + _substitute(term)**2] = False
            elif isinstance(term, tanh):
                special[1 + _substitute(term)] = False
                special[1 - _substitute(term)] = False
            elif isinstance(term, LambertW):
                special[_substitute(term)] = True

    # The integrand expressed in the variables V.
    F = _substitute(f)

    P, Q = F.as_numer_denom()

    u_split = _splitter(denom)
    v_split = _splitter(Q)

    polys = set(list(v_split) + [u_split[0]] + list(special))

    s = u_split[0] * Mul(*[k for k, v in special.items() if v])
    polified = [p.as_poly(*V) for p in [s, P, Q]]

    # Give up if any of s, P, Q is not a polynomial in V.
    if None in polified:
        return

    # --- definitions for _integrate ---
    a, b, c = [p.total_degree() for p in polified]

    # Denominator of the rational part of the candidate antiderivative.
    poly_denom = (s * v_split[0] * _deflation(v_split[1])).as_expr()

    # Largest exponent measure over all powers in g: for a non-integer
    # rational power p/q it is p + q - 1 when p > 0, else |p + q|; every
    # other power and every atom counts as 1.
    def _exponent(g):
        if g.is_Pow:
            if g.exp.is_Rational and g.exp.denominator != 1:
                if g.exp.numerator > 0:
                    return g.exp.numerator + g.exp.denominator - 1
                else:
                    return abs(g.exp.numerator + g.exp.denominator)
            else:
                return 1
        elif not g.is_Atom and g.args:
            return max(_exponent(h) for h in g.args)
        else:
            return 1

    A, B = _exponent(f), a + max(b, c)

    # Heuristic bound on the total degree of the polynomial part.
    degree = A + B + degree_offset
    if A > 1 and B > 1:
        degree -= 1

    monoms = itermonomials(V, degree)
    # binomial(len(V) + degree, degree) counts the monomials in len(V)
    # variables of total degree <= degree.
    # NOTE(review): presumably matches the number itermonomials yields.
    poly_coeffs = _symbols('A', binomial(len(V) + degree, degree))
    poly_part = Add(*[poly_coeffs[i]*monomial
                      for i, monomial in enumerate(ordered(monoms))])

    reducibles = set()

    for poly in polys:
        if poly.has(*V):
            try:
                factorization = factor(poly, greedy=True)
            except PolynomialError:
                factorization = poly
            # NOTE(review): the next line discards the result of factor()
            # above, so only the unfactored poly is used -- confirm this
            # is intended.
            factorization = poly

            if factorization.is_Mul:
                reducibles |= set(factorization.args)
            else:
                reducibles.add(factorization)

    # Build the candidate (rational part plus a sum of logarithms with
    # undetermined coefficients), differentiate it, and solve the linear
    # system for the coefficients.  `field` is passed to root_factors as
    # its `filter` argument.  Returns the antiderivative in terms of V, or
    # None on failure.
    def _integrate(field=None):
        irreducibles = set()

        # The for/else picks z as the first free symbol of poly that is in
        # V; polys with no such symbol are skipped by the `continue`.
        for poly in reducibles:
            for z in poly.free_symbols:
                if z in V:
                    break  # should this be: `irreducibles |= \
            else:          # set(root_factors(poly, z, filter=field))`
                continue   # and the line below deleted?
                #                          |
                #                          V
            irreducibles |= set(root_factors(poly, z, filter=field))

        log_part = []
        B = _symbols('B', len(irreducibles))

        # NOTE: poly_coeffs belongs to the enclosing scope and is extended
        # in place, so a second call of _integrate also sees the
        # coefficients appended by the first one.
        # Note: the ordering matters here
        for poly, b in reversed(list(ordered(zip(irreducibles, B)))):
            if poly.has(*V):
                poly_coeffs.append(b)
                log_part.append(b * log(poly))

        # TODO: Currently it's better to use symbolic expressions here instead
        # of rational functions, because it's simpler and FracElement doesn't
        # give big speed improvement yet. This is because cancellation is slow
        # due to slow polynomial GCD algorithms. If this gets improved then
        # revise this code.
        candidate = poly_part/poly_denom + Add(*log_part)
        # Residual between the integrand and the candidate's derivative;
        # its numerator must vanish identically in V.
        h = F - _derivation(candidate) / denom
        raw_numer = h.as_numer_denom()[0]

        # Rewrite raw_numer as a polynomial in K[coeffs][V] where K is a field
        # that we have to determine. We can't use simply atoms() because log(3),
        # sqrt(y) and similar expressions can appear, leading to non-trivial
        # domains.
        syms = set(poly_coeffs) | set(V)
        non_syms = set()

        # Collect the maximal subexpressions free of `syms`; they are used
        # to construct the ground domain.
        def find_non_syms(expr):
            if expr.is_Integer or expr.is_Rational:
                pass  # ignore trivial numbers
            elif expr in syms:
                pass  # ignore variables
            elif not expr.has(*syms):
                non_syms.add(expr)
            elif expr.is_Add or expr.is_Mul or expr.is_Pow:
                list(map(find_non_syms, expr.args))
            else:
                # TODO: Non-polynomial expression. This should have been
                # filtered out at an earlier stage.
                raise PolynomialError

        try:
            find_non_syms(raw_numer)
        except PolynomialError:
            return
        else:
            ground, _ = construct_domain(non_syms, field=True)

        coeff_ring = ground.poly_ring(*poly_coeffs)
        ring = coeff_ring.poly_ring(*V)

        try:
            numer = ring.from_expr(raw_numer)
        except ValueError:
            # Unlike the PolynomialError above, this one is not caught
            # here and propagates out of _integrate.
            raise PolynomialError

        # Every coefficient of numer (a polynomial in V) must be zero,
        # which is a linear system in the undetermined coefficients.
        solution = solve_lin_sys(numer.values(), coeff_ring)

        if solution is not None:
            solution = [(coeff_ring.symbols[coeff_ring.index(k)],
                         coeff_ring.to_expr(v)) for k, v in solution.items()]
            # Insert the solved coefficients, then set the remaining free
            # ones to zero.
            return candidate.subs(solution).subs(
                list(zip(poly_coeffs, [Integer(0)]*len(poly_coeffs))))

    # If F has no free symbols besides V, try with the 'Q' filter first and
    # fall back to no filter.
    # NOTE(review): 'Q' presumably selects rational roots -- confirm against
    # root_factors.
    if not F.free_symbols - set(V):
        solution = _integrate('Q')

        if solution is None:
            solution = _integrate()
    else:
        solution = _integrate()

    if solution is not None:
        # Back-substitute the original expressions for the dummies.
        antideriv = solution.subs(rev_mapping)
        antideriv = cancel(antideriv).expand()

        # Drop additive terms that do not depend on x.
        if antideriv.is_Add:
            antideriv = antideriv.as_independent(x)[1]

        return indep*antideriv
    else:
        # Try again with the remaining mappings; falls through (returning
        # None) once retries are used up or nothing is found.
        # NOTE(review): degree_offset is not forwarded to the retry.
        if retries >= 0:
            result = heurisch(f, x, mappings=mappings, rewrite=rewrite, hints=hints, retries=retries - 1, unnecessary_permutations=unnecessary_permutations)

            if result is not None:
                return indep*result
543