1from __future__ import annotations 2 3import functools 4from itertools import permutations 5 6from ..core import Add, Basic, Dummy, E, Eq, Integer, Mul, Wild, pi, sympify 7from ..functions import (Ei, LambertW, Piecewise, acosh, asin, asinh, atan, 8 binomial, cos, cosh, cot, coth, erf, erfi, exp, li, 9 log, root, sin, sinh, sqrt, tan, tanh) 10from ..logic import And 11from ..polys import PolynomialError, cancel, factor, gcd, lcm, quo 12from ..polys.constructor import construct_domain 13from ..polys.monomials import itermonomials 14from ..polys.polyroots import root_factors 15from ..polys.solvers import solve_lin_sys 16from ..utilities import ordered 17from ..utilities.iterables import uniq 18 19 20def components(f, x): 21 """ 22 Returns a set of all functional components of the given expression 23 which includes symbols, function applications and compositions and 24 non-integer powers. Fractional powers are collected with with 25 minimal, positive exponents. 26 27 >>> components(sin(x)*cos(x)**2, x) 28 {x, sin(x), cos(x)} 29 30 See Also 31 ======== 32 33 heurisch 34 35 """ 36 result = set() 37 38 if x in f.free_symbols: 39 if f.is_Symbol: 40 result.add(f) 41 elif f.is_Function or f.is_Derivative: 42 for g in f.args: 43 result |= components(g, x) 44 45 result.add(f) 46 elif f.is_Pow: 47 result |= components(f.base, x) 48 49 if not f.exp.is_Integer: 50 if f.exp.is_Rational: 51 result.add(root(f.base, f.exp.denominator)) 52 else: 53 result |= components(f.exp, x) | {f} 54 else: 55 for g in f.args: 56 result |= components(g, x) 57 58 return result 59 60 61# name -> [] of symbols 62_symbols_cache: dict[str, list[Dummy]] = {} 63 64 65# NB @cacheit is not convenient here 66def _symbols(name, n): 67 """Get vector of symbols local to this module.""" 68 try: 69 lsyms = _symbols_cache[name] 70 except KeyError: 71 lsyms = [] 72 _symbols_cache[name] = lsyms 73 74 while len(lsyms) < n: 75 lsyms.append(Dummy(f'{name}{len(lsyms):d}')) 76 77 return lsyms[:n] 78 79 80def heurisch_wrapper(f, x, rewrite=False, hints=None, mappings=None, retries=3, 81 degree_offset=0, unnecessary_permutations=None): 82 """ 83 A wrapper around the heurisch integration algorithm. 84 85 This method takes the result from heurisch and checks for poles in the 86 denominator. For each of these poles, the integral is reevaluated, and 87 the final integration result is given in terms of a Piecewise. 88 89 Examples 90 ======== 91 92 >>> heurisch(cos(n*x), x) 93 sin(n*x)/n 94 >>> heurisch_wrapper(cos(n*x), x) 95 Piecewise((x, Eq(n, 0)), (sin(n*x)/n, true)) 96 97 See Also 98 ======== 99 100 heurisch 101 102 """ 103 from ..solvers.solvers import denoms, solve 104 f = sympify(f) 105 if x not in f.free_symbols: 106 return f*x 107 108 res = heurisch(f, x, rewrite, hints, mappings, retries, degree_offset, 109 unnecessary_permutations) 110 if not isinstance(res, Basic): 111 return res 112 # We consider each denominator in the expression, and try to find 113 # cases where one or more symbolic denominator might be zero. The 114 # conditions for these cases are stored in the list slns. 115 slns = [] 116 for d in denoms(res): 117 ds = list(ordered(d.free_symbols - {x})) 118 if ds: 119 slns += solve(d, *ds) 120 if not slns: 121 return res 122 slns = list(uniq(slns)) 123 # Remove the solutions corresponding to poles in the original expression. 124 slns0 = [] 125 for d in denoms(f): 126 ds = list(ordered(d.free_symbols - {x})) 127 if ds: 128 slns0 += solve(d, *ds) 129 slns = [s for s in slns if s not in slns0] 130 if not slns: 131 return res 132 if len(slns) > 1: 133 eqs = [] 134 for sub_dict in slns: 135 eqs.extend([Eq(key, value) for key, value in sub_dict.items()]) 136 slns = solve(eqs, *ordered(set().union(*[e.free_symbols 137 for e in eqs]) - {x})) + slns 138 # For each case listed in the list slns, we reevaluate the integral. 139 pairs = [] 140 for sub_dict in slns: 141 expr = heurisch(f.subs(sub_dict), x, rewrite, hints, mappings, retries, 142 degree_offset, unnecessary_permutations) 143 cond = And(*[Eq(key, value) for key, value in sub_dict.items()]) 144 pairs.append((expr, cond)) 145 pairs.append((heurisch(f, x, rewrite, hints, mappings, retries, 146 degree_offset, unnecessary_permutations), True)) 147 return Piecewise(*pairs) 148 149 150def heurisch(f, x, rewrite=False, hints=None, mappings=None, retries=3, 151 degree_offset=0, unnecessary_permutations=None): 152 """ 153 Compute indefinite integral using heuristic Risch algorithm. 154 155 This is a heuristic approach to indefinite integration in finite 156 terms using the extended heuristic (parallel) Risch algorithm, based 157 on Manuel Bronstein's "Poor Man's Integrator". 158 159 The algorithm supports various classes of functions including 160 transcendental elementary or special functions like Airy, 161 Bessel, Whittaker and Lambert. 162 163 Note that this algorithm is not a decision procedure. If it isn't 164 able to compute the antiderivative for a given function, then this is 165 not a proof that such a functions does not exist. One should use 166 recursive Risch algorithm in such case. It's an open question if 167 this algorithm can be made a full decision procedure. 168 169 This is an internal integrator procedure. You should use toplevel 170 'integrate' function in most cases, as this procedure needs some 171 preprocessing steps and otherwise may fail. 172 173 Parameters 174 ========== 175 176 f : Expr 177 expression 178 x : Symbol 179 variable 180 181 rewrite : Boolean, optional 182 force rewrite 'f' in terms of 'tan' and 'tanh', default False. 183 hints : None or list 184 a list of functions that may appear in anti-derivate. If 185 None (default) - no suggestions at all, if empty list - try 186 to figure out. 187 188 Examples 189 ======== 190 191 >>> heurisch(y*tan(x), x) 192 y*log(tan(x)**2 + 1)/2 193 194 References 195 ========== 196 197 * :cite:`Bronstein2005pmint` 198 199 See Also 200 ======== 201 202 diofant.integrals.integrals.Integral.doit 203 diofant.integrals.integrals.Integral 204 components 205 206 """ 207 f = sympify(f) 208 if x not in f.free_symbols: 209 return f*x 210 211 if not f.is_Add: 212 indep, f = f.as_independent(x) 213 else: 214 indep = Integer(1) 215 216 rewritables = { 217 (sin, cos, cot): tan, 218 (sinh, cosh, coth): tanh, 219 } 220 221 if rewrite: 222 for candidates, rule in rewritables.items(): 223 f = f.rewrite(candidates, rule) 224 else: 225 for candidates in rewritables: 226 if f.has(*candidates): 227 break 228 else: 229 rewrite = True 230 231 terms = components(f, x) 232 233 if hints is not None: 234 if not hints: 235 a = Wild('a', exclude=[x]) 236 b = Wild('b', exclude=[x]) 237 c = Wild('c', exclude=[x]) 238 239 for g in set(terms): # using copy of terms 240 if g.is_Function: 241 if isinstance(g, li): 242 M = g.args[0].match(a*x**b) 243 244 if M is not None: 245 terms.add(x*(li(M[a]*x**M[b]) - (M[a]*x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b]))) 246 247 elif g.is_Pow: 248 if g.base is E: 249 M = g.exp.match(a*x**2) 250 251 if M is not None: 252 if M[a].is_positive: 253 terms.add(erfi(sqrt(M[a])*x)) 254 else: # M[a].is_negative or unknown 255 terms.add(erf(sqrt(-M[a])*x)) 256 257 M = g.exp.match(a*x**2 + b*x + c) 258 259 if M is not None: 260 if M[a].is_positive: 261 terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a])) * 262 erfi(sqrt(M[a])*x + M[b]/(2*sqrt(M[a])))) 263 elif M[a].is_negative: 264 terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a])) * 265 erf(sqrt(-M[a])*x - M[b]/(2*sqrt(-M[a])))) 266 267 M = g.exp.match(a*log(x)**2) 268 269 if M is not None: 270 if M[a].is_positive: 271 terms.add(erfi(sqrt(M[a])*log(x) + 1/(2*sqrt(M[a])))) 272 if M[a].is_negative: 273 terms.add(erf(sqrt(-M[a])*log(x) - 1/(2*sqrt(-M[a])))) 274 275 elif g.exp.is_Rational and g.exp.denominator == 2: 276 M = g.base.match(a*x**2 + b) 277 278 if M is not None and M[b].is_positive: 279 if M[a].is_positive: 280 terms.add(asinh(sqrt(M[a]/M[b])*x)) 281 elif M[a].is_negative: 282 terms.add(asin(sqrt(-M[a]/M[b])*x)) 283 284 M = g.base.match(a*x**2 - b) 285 286 if M is not None and M[b].is_positive: 287 if M[a].is_positive: 288 terms.add(acosh(sqrt(M[a]/M[b])*x)) 289 elif M[a].is_negative: 290 terms.add((-M[b]/2*sqrt(-M[a]) * 291 atan(sqrt(-M[a])*x/sqrt(M[a]*x**2 - M[b])))) 292 293 else: 294 terms |= set(hints) 295 296 for g in set(terms): # using copy of terms 297 terms |= components(cancel(g.diff(x)), x) 298 299 # TODO: caching is significant factor for why permutations work at all. Change this. 300 V = _symbols('x', len(terms)) 301 302 # sort mapping expressions from largest to smallest (last is always x). 303 mapping = list(reversed(list(zip(*ordered( 304 [(a[0].as_independent(x)[1], a) for a in zip(terms, V)])))[1])) 305 rev_mapping = {v: k for k, v in mapping} 306 if mappings is None: 307 # optimizing the number of permutations of mapping 308 assert mapping[-1][0] == x # if not, find it and correct this comment 309 unnecessary_permutations = [mapping.pop(-1)] 310 mappings = permutations(mapping) 311 else: 312 unnecessary_permutations = unnecessary_permutations or [] 313 314 def _substitute(expr): 315 return expr.subs(mapping) 316 317 for mapping in mappings: 318 mapping = list(mapping) 319 mapping = mapping + unnecessary_permutations 320 diffs = [_substitute(cancel(g.diff(x))) for g in terms] 321 denoms = [g.as_numer_denom()[1] for g in diffs] 322 if all(h.is_polynomial(*V) for h in denoms) and _substitute(f).is_rational_function(*V): 323 denom = functools.reduce(lambda p, q: lcm(p, q, *V), denoms) 324 break 325 else: 326 if not rewrite: 327 result = heurisch(f, x, rewrite=True, hints=hints, 328 unnecessary_permutations=unnecessary_permutations) 329 330 if result is not None: 331 return indep*result 332 return 333 334 numers = [cancel(denom*g) for g in diffs] 335 336 def _derivation(h): 337 return Add(*[d * h.diff(v) for d, v in zip(numers, V)]) 338 339 def _deflation(p): 340 for y in V: 341 if not p.has(y): 342 continue 343 344 if _derivation(p) != 0: 345 c, q = p.as_poly(y).primitive() 346 return _deflation(c)*gcd(q, q.diff(y)).as_expr() 347 348 return p 349 350 def _splitter(p): 351 for y in V: 352 if not p.has(y): 353 continue 354 355 if _derivation(y) != 0: 356 c, q = p.as_poly(y).primitive() 357 358 q = q.as_expr() 359 360 h = gcd(q, _derivation(q), y) 361 s = quo(h, gcd(q, q.diff(y), y), y) 362 363 c_split = _splitter(c) 364 365 if s.as_poly(y).degree() == 0: 366 return c_split[0], q*c_split[1] 367 368 q_split = _splitter(cancel(q / s)) 369 370 return c_split[0]*q_split[0]*s, c_split[1]*q_split[1] 371 372 return Integer(1), p 373 374 special = {} 375 376 for term in terms: 377 if term.is_Function: 378 if isinstance(term, tan): 379 special[1 + _substitute(term)**2] = False 380 elif isinstance(term, tanh): 381 special[1 + _substitute(term)] = False 382 special[1 - _substitute(term)] = False 383 elif isinstance(term, LambertW): 384 special[_substitute(term)] = True 385 386 F = _substitute(f) 387 388 P, Q = F.as_numer_denom() 389 390 u_split = _splitter(denom) 391 v_split = _splitter(Q) 392 393 polys = set(list(v_split) + [u_split[0]] + list(special)) 394 395 s = u_split[0] * Mul(*[k for k, v in special.items() if v]) 396 polified = [p.as_poly(*V) for p in [s, P, Q]] 397 398 if None in polified: 399 return 400 401 # --- definitions for _integrate --- 402 a, b, c = [p.total_degree() for p in polified] 403 404 poly_denom = (s * v_split[0] * _deflation(v_split[1])).as_expr() 405 406 def _exponent(g): 407 if g.is_Pow: 408 if g.exp.is_Rational and g.exp.denominator != 1: 409 if g.exp.numerator > 0: 410 return g.exp.numerator + g.exp.denominator - 1 411 else: 412 return abs(g.exp.numerator + g.exp.denominator) 413 else: 414 return 1 415 elif not g.is_Atom and g.args: 416 return max(_exponent(h) for h in g.args) 417 else: 418 return 1 419 420 A, B = _exponent(f), a + max(b, c) 421 422 degree = A + B + degree_offset 423 if A > 1 and B > 1: 424 degree -= 1 425 426 monoms = itermonomials(V, degree) 427 poly_coeffs = _symbols('A', binomial(len(V) + degree, degree)) 428 poly_part = Add(*[poly_coeffs[i]*monomial 429 for i, monomial in enumerate(ordered(monoms))]) 430 431 reducibles = set() 432 433 for poly in polys: 434 if poly.has(*V): 435 try: 436 factorization = factor(poly, greedy=True) 437 except PolynomialError: 438 factorization = poly 439 factorization = poly 440 441 if factorization.is_Mul: 442 reducibles |= set(factorization.args) 443 else: 444 reducibles.add(factorization) 445 446 def _integrate(field=None): 447 irreducibles = set() 448 449 for poly in reducibles: 450 for z in poly.free_symbols: 451 if z in V: 452 break # should this be: `irreducibles |= \ 453 else: # set(root_factors(poly, z, filter=field))` 454 continue # and the line below deleted? 455 # | 456 # V 457 irreducibles |= set(root_factors(poly, z, filter=field)) 458 459 log_part = [] 460 B = _symbols('B', len(irreducibles)) 461 462 # Note: the ordering matters here 463 for poly, b in reversed(list(ordered(zip(irreducibles, B)))): 464 if poly.has(*V): 465 poly_coeffs.append(b) 466 log_part.append(b * log(poly)) 467 468 # TODO: Currently it's better to use symbolic expressions here instead 469 # of rational functions, because it's simpler and FracElement doesn't 470 # give big speed improvement yet. This is because cancellation is slow 471 # due to slow polynomial GCD algorithms. If this gets improved then 472 # revise this code. 473 candidate = poly_part/poly_denom + Add(*log_part) 474 h = F - _derivation(candidate) / denom 475 raw_numer = h.as_numer_denom()[0] 476 477 # Rewrite raw_numer as a polynomial in K[coeffs][V] where K is a field 478 # that we have to determine. We can't use simply atoms() because log(3), 479 # sqrt(y) and similar expressions can appear, leading to non-trivial 480 # domains. 481 syms = set(poly_coeffs) | set(V) 482 non_syms = set() 483 484 def find_non_syms(expr): 485 if expr.is_Integer or expr.is_Rational: 486 pass # ignore trivial numbers 487 elif expr in syms: 488 pass # ignore variables 489 elif not expr.has(*syms): 490 non_syms.add(expr) 491 elif expr.is_Add or expr.is_Mul or expr.is_Pow: 492 list(map(find_non_syms, expr.args)) 493 else: 494 # TODO: Non-polynomial expression. This should have been 495 # filtered out at an earlier stage. 496 raise PolynomialError 497 498 try: 499 find_non_syms(raw_numer) 500 except PolynomialError: 501 return 502 else: 503 ground, _ = construct_domain(non_syms, field=True) 504 505 coeff_ring = ground.poly_ring(*poly_coeffs) 506 ring = coeff_ring.poly_ring(*V) 507 508 try: 509 numer = ring.from_expr(raw_numer) 510 except ValueError: 511 raise PolynomialError 512 513 solution = solve_lin_sys(numer.values(), coeff_ring) 514 515 if solution is not None: 516 solution = [(coeff_ring.symbols[coeff_ring.index(k)], 517 coeff_ring.to_expr(v)) for k, v in solution.items()] 518 return candidate.subs(solution).subs( 519 list(zip(poly_coeffs, [Integer(0)]*len(poly_coeffs)))) 520 521 if not F.free_symbols - set(V): 522 solution = _integrate('Q') 523 524 if solution is None: 525 solution = _integrate() 526 else: 527 solution = _integrate() 528 529 if solution is not None: 530 antideriv = solution.subs(rev_mapping) 531 antideriv = cancel(antideriv).expand() 532 533 if antideriv.is_Add: 534 antideriv = antideriv.as_independent(x)[1] 535 536 return indep*antideriv 537 else: 538 if retries >= 0: 539 result = heurisch(f, x, mappings=mappings, rewrite=rewrite, hints=hints, retries=retries - 1, unnecessary_permutations=unnecessary_permutations) 540 541 if result is not None: 542 return indep*result 543