from collections import defaultdict

from sympy import SYMPY_DEBUG

from sympy.core import expand_power_base, sympify, Add, S, Mul, Derivative, Pow, symbols, expand_mul
from sympy.core.add import _unevaluated_Add
from sympy.core.compatibility import iterable, ordered, default_sort_key
from sympy.core.parameters import global_parameters
from sympy.core.exprtools import Factors, gcd_terms
from sympy.core.function import _mexpand
from sympy.core.mul import _keep_coeff, _unevaluated_Mul
from sympy.core.numbers import Rational, zoo, nan
from sympy.functions import exp, sqrt, log
from sympy.functions.elementary.complexes import Abs
from sympy.polys import gcd
from sympy.simplify.sqrtdenest import sqrtdenest


def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True):
    """
    Collect additive terms of an expression.

    Explanation
    ===========

    This function collects additive terms of an expression with respect
    to a list of expressions up to powers with rational exponents. Here the
    term "symbol" means an arbitrary expression, which can contain
    powers, products, sums etc. In other words, a symbol is a pattern that
    will be searched for in the expression's terms.

    The input expression is not expanded by :func:`collect`, so the user is
    expected to provide an expression in an appropriate form. This makes
    :func:`collect` more predictable as there is no magic happening behind the
    scenes. However, it is important to note that powers of products are
    converted to products of powers using the :func:`~.expand_power_base`
    function.

    There are two possible types of output. First, if ``evaluate`` flag is
    set, this function will return an expression with collected terms or
    else it will return a dictionary with expressions up to rational powers
    as keys and collected coefficients as values. Terms that match none of
    the patterns are stored under the key ``S.One``.

    Parameters
    ==========

    expr : Expr
        The expression whose additive terms are collected.
    syms : Expr or iterable of Expr
        The pattern(s) to collect with respect to. For each term, the
        patterns are tried in the given order and the first match wins.
    func : callable, optional
        If given, it is applied to every collected coefficient.
    evaluate : bool, optional
        If True, return an expression; if False, return a dictionary
        mapping each collected pattern power to its coefficient. Defaults
        to ``global_parameters.evaluate``.
    exact : bool, optional
        If True, a pattern only matches a factor whose rational exponent
        and derivative order are exactly those of the pattern.
    distribute_order_term : bool, optional
        If True and ``expr`` has an order term that contains none of
        ``syms``, the order term is removed before collecting and then
        added to every collected coefficient.

    Examples
    ========

    >>> from sympy import S, collect, expand, factor, Wild
    >>> from sympy.abc import a, b, c, x, y

    This function can collect symbolic coefficients in polynomials or
    rational expressions. It will manage to find all integer or rational
    powers of collection variable::

        >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x)
        c + x**2*(a + b) + x*(a - b)

    The same result can be achieved in dictionary form::

        >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False)
        >>> d[x**2]
        a + b
        >>> d[x]
        a - b
        >>> d[S.One]
        c

    You can also work with multivariate polynomials. However, remember that
    this function is greedy so it will care only about a single symbol at a
    time, in specification order::

        >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y])
        x**2*(y + 1) + x*y + y*(a + 1)

    Also more complicated expressions can be used as patterns::

        >>> from sympy import sin, log
        >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x))
        (a + b)*sin(2*x)

        >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x))
        x*(a + b)*log(x)

    You can use wildcards in the pattern::

        >>> w = Wild('w1')
        >>> collect(a*x**y - b*x**y, w**y)
        x**y*(a - b)

    It is also possible to work with symbolic powers, although it has more
    complicated behavior, because in this case power's base and symbolic part
    of the exponent are treated as a single symbol::

        >>> collect(a*x**c + b*x**c, x)
        a*x**c + b*x**c
        >>> collect(a*x**c + b*x**c, x**c)
        x**c*(a + b)

    However if you incorporate rationals to the exponents, then you will get
    well known behavior::

        >>> collect(a*x**(2*c) + b*x**(2*c), x**c)
        x**(2*c)*(a + b)

    Note also that all previously stated facts about :func:`collect` function
    apply to the exponential function, so you can get::

        >>> from sympy import exp
        >>> collect(a*exp(2*x) + b*exp(2*x), exp(x))
        (a + b)*exp(2*x)

    If you are interested only in collecting specific powers of some symbols
    then set ``exact`` flag in arguments::

        >>> collect(a*x**7 + b*x**7, x, exact=True)
        a*x**7 + b*x**7
        >>> collect(a*x**7 + b*x**7, x**7, exact=True)
        x**7*(a + b)

    You can also apply this function to differential equations, where
    derivatives of arbitrary order can be collected. Note that if you
    collect with respect to a function or a derivative of a function, all
    derivatives of that function will also be collected. Use
    ``exact=True`` to prevent this from happening::

        >>> from sympy import Derivative as D, collect, Function
        >>> f = Function('f') (x)

        >>> collect(a*D(f,x) + b*D(f,x), D(f,x))
        (a + b)*Derivative(f(x), x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f)
        (a + b)*Derivative(f(x), (x, 2))

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True)
        a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2))

        >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f)
        (a + b)*f(x) + (a + b)*Derivative(f(x), x)

    Or you can even match both derivative order and exponent at the same time::

        >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x))
        (a + b)*Derivative(f(x), (x, 2))**2

    Finally, you can apply a function to each of the collected coefficients.
    For example you can factorize symbolic coefficients of polynomial::

        >>> f = expand((x + a + 1)**3)

        >>> collect(f, x, factor)
        x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3

    .. note:: Arguments are expected to be in expanded form, so you might have
              to call :func:`~.expand` prior to calling this function.

    See Also
    ========

    collect_const, collect_sqrt, rcollect
    """
    from sympy.core.assumptions import assumptions
    from sympy.utilities.iterables import sift
    from sympy.core.symbol import Dummy, Wild
    expr = sympify(expr)
    syms = [sympify(i) for i in (syms if iterable(syms) else [syms])]
    # Patterns that are a Symbol, a negated Symbol, or contain a Wild are
    # used as-is. Any other pattern is replaced by a Dummy carrying the same
    # assumptions, the substituted expression is collected recursively, and
    # the replacement is undone on the result (keys and values alike).
    cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool(
        x.atoms(Wild))
    _, nonsyms = sift(syms, cond, binary=True)
    if nonsyms:
        reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms]))
        syms = [reps.get(s, s) for s in syms]
        rv = collect(expr.subs(reps), syms,
            func=func, evaluate=evaluate, exact=exact,
            distribute_order_term=distribute_order_term)
        urep = {v: k for k, v in reps.items()}
        if not isinstance(rv, dict):
            return rv.xreplace(urep)
        else:
            return {urep.get(k, k).xreplace(urep): v.xreplace(urep)
                for k, v in rv.items()}

    if evaluate is None:
        evaluate = global_parameters.evaluate

    def make_expression(terms):
        # Rebuild a product from parse_term-style tuples
        # (base, rational exponent, symbolic exponent, derivative info),
        # re-applying any derivative the requested number of times.
        product = []

        for term, rat, sym, deriv in terms:
            if deriv is not None:
                var, order = deriv

                while order > 0:
                    term, order = Derivative(term, var), order - 1

            if sym is None:
                if rat is S.One:
                    product.append(term)
                else:
                    product.append(Pow(term, rat))
            else:
                product.append(Pow(term, rat*sym))

        return Mul(*product)

    def parse_derivative(deriv):
        # scan derivatives tower in the input expression and return
        # underlying function and maximal differentiation order;
        # only derivatives with respect to a single variable are supported
        expr, sym, order = deriv.expr, deriv.variables[0], 1

        for s in deriv.variables[1:]:
            if s == sym:
                order += 1
            else:
                raise NotImplementedError(
                    'Improve MV Derivative support in collect')

        while isinstance(expr, Derivative):
            s0 = expr.variables[0]

            for s in expr.variables:
                if s != s0:
                    raise NotImplementedError(
                        'Improve MV Derivative support in collect')

            if s0 == sym:
                # the right-hand side is evaluated before assignment, so
                # len(expr.variables) refers to the current (outer) node
                expr, order = expr.expr, order + len(expr.variables)
            else:
                break

        return expr, (sym, Rational(order))

    def parse_term(expr):
        """Parses expression expr and outputs tuple (sexpr, rat_expo,
        sym_expo, deriv)
        where:
         - sexpr is the base expression
         - rat_expo is the rational exponent that sexpr is raised to
         - sym_expo is the symbolic exponent that sexpr is raised to
         - deriv contains the derivatives of the expression

         For example, the output of x would be (x, 1, None, None)
         the output of 2**x would be (2, 1, x, None).
        """
        rat_expo, sym_expo = S.One, None
        sexpr, deriv = expr, None

        if expr.is_Pow:
            if isinstance(expr.base, Derivative):
                sexpr, deriv = parse_derivative(expr.base)
            else:
                sexpr = expr.base

            if expr.base == S.Exp1:
                # E**(q*t) is parsed as base exp(t) with rational exponent q
                arg = expr.exp
                if arg.is_Rational:
                    sexpr, rat_expo = S.Exp1, arg
                elif arg.is_Mul:
                    coeff, tail = arg.as_coeff_Mul(rational=True)
                    sexpr, rat_expo = exp(tail), coeff

            elif expr.exp.is_Number:
                rat_expo = expr.exp
            else:
                coeff, tail = expr.exp.as_coeff_Mul()

                if coeff.is_Number:
                    rat_expo, sym_expo = coeff, tail
                else:
                    sym_expo = expr.exp
        elif isinstance(expr, exp):
            # exp(q*t) is parsed as base exp(t) with rational exponent q
            arg = expr.exp
            if arg.is_Rational:
                sexpr, rat_expo = S.Exp1, arg
            elif arg.is_Mul:
                coeff, tail = arg.as_coeff_Mul(rational=True)
                sexpr, rat_expo = exp(tail), coeff
        elif isinstance(expr, Derivative):
            sexpr, deriv = parse_derivative(expr)

        return sexpr, rat_expo, sym_expo, deriv

    def parse_expression(terms, pattern):
        """Parse terms searching for a pattern.
        terms is a list of tuples as returned by parse_term;
        pattern is an expression treated as a product of factors.

        Returns None when some factor of the pattern is not found; otherwise
        returns (remaining terms, matched terms, common exponent,
        has_deriv flag).
        """
        pattern = Mul.make_args(pattern)

        if len(terms) < len(pattern):
            # pattern is longer than matched product
            # so no chance for positive parsing result
            return None
        else:
            pattern = [parse_term(elem) for elem in pattern]

            terms = terms[:]  # need a copy
            elems, common_expo, has_deriv = [], None, False

            for elem, e_rat, e_sym, e_ord in pattern:

                if elem.is_Number and e_rat == 1 and e_sym is None:
                    # a constant is a match for everything
                    continue

                for j in range(len(terms)):
                    if terms[j] is None:
                        continue

                    term, t_rat, t_sym, t_ord = terms[j]

                    # keeping track of whether one of the terms had
                    # a derivative or not as this will require rebuilding
                    # the expression later
                    if t_ord is not None:
                        has_deriv = True

                    if (term.match(elem) is not None and
                            (t_sym == e_sym or t_sym is not None and
                            e_sym is not None and
                            t_sym.match(e_sym) is not None)):
                        if exact is False:
                            # we don't have to be exact so find common exponent
                            # for both expression's term and pattern's element
                            expo = t_rat / e_rat

                            if common_expo is None:
                                # first time
                                common_expo = expo
                            else:
                                # common exponent was negotiated before so
                                # there is no chance for a pattern match unless
                                # common and current exponents are equal
                                if common_expo != expo:
                                    common_expo = 1
                        else:
                            # we ought to be exact so all fields of
                            # interest must match in every detail
                            if e_rat != t_rat or e_ord != t_ord:
                                continue

                        # found common term so remove it from the expression
                        # and try to match next element in the pattern
                        elems.append(terms[j])
                        terms[j] = None

                        break

                else:
                    # pattern element not found
                    return None

            return [_f for _f in terms if _f], elems, common_expo, has_deriv

    if evaluate:
        # collect inside the arguments of Add/Mul/Pow first; an order term
        # of an Add is kept aside and re-added unchanged
        if expr.is_Add:
            o = expr.getO() or 0
            expr = expr.func(*[
                    collect(a, syms, func, True, exact, distribute_order_term)
                    for a in expr.args if a != o]) + o
        elif expr.is_Mul:
            return expr.func(*[
                collect(term, syms, func, True, exact, distribute_order_term)
                for term in expr.args])
        elif expr.is_Pow:
            b = collect(
                expr.base, syms, func, True, exact, distribute_order_term)
            return Pow(b, expr.exp)

    syms = [expand_power_base(i, deep=False) for i in syms]

    order_term = None

    if distribute_order_term:
        order_term = expr.getO()

        if order_term is not None:
            # only an order term free of the patterns is distributed over
            # the coefficients; otherwise it is collected like any term
            if order_term.has(*syms):
                order_term = None
            else:
                expr = expr.removeO()

    summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)]

    collected, disliked = defaultdict(list), S.Zero
    for product in summa:
        c, nc = product.args_cnc(split_1=False)
        args = list(ordered(c)) + nc
        terms = [parse_term(i) for i in args]
        small_first = True

        for symbol in syms:
            if SYMPY_DEBUG:
                print("DEBUG: parsing of expression %s with symbol %s " % (
                    str(terms), str(symbol))
                )

            # NOTE(review): the factor order is reversed (once) as soon as a
            # Derivative pattern is tried, presumably so later-ordered
            # factors are matched first -- confirm the intent.
            if isinstance(symbol, Derivative) and small_first:
                terms = list(reversed(terms))
                small_first = not small_first
            result = parse_expression(terms, symbol)

            if SYMPY_DEBUG:
                print("DEBUG: returned %s" % str(result))

            if result is not None:
                if not symbol.is_commutative:
                    raise AttributeError("Can not collect noncommutative symbol")

                terms, elems, common_expo, has_deriv = result

                # when there was derivative in current pattern we
                # will need to rebuild its expression from scratch
                if not has_deriv:
                    margs = []
                    for elem in elems:
                        if elem[2] is None:
                            e = elem[1]
                        else:
                            e = elem[1]*elem[2]
                        margs.append(Pow(elem[0], e))
                    index = Mul(*margs)
                else:
                    index = make_expression(elems)
                terms = expand_power_base(make_expression(terms), deep=False)
                index = expand_power_base(index, deep=False)
                collected[index].append(terms)
                # a term is collected under the first matching pattern only
                break
        else:
            # none of the patterns matched
            disliked += product
    # add terms now for each key
    collected = {k: Add(*v) for k, v in collected.items()}

    if disliked is not S.Zero:
        collected[S.One] = disliked

    if order_term is not None:
        for key, val in collected.items():
            collected[key] = val + order_term

    if func is not None:
        collected = {
            key: func(val) for key, val in collected.items()}

    if evaluate:
        return Add(*[key*val for key, val in collected.items()])
    else:
        return collected


def rcollect(expr, *vars):
    """
    Recursively collect sums in an expression.

    The arguments of ``expr`` are processed bottom-up and every Add that
    contains any of ``vars`` is passed to :func:`collect`.

    Examples
    ========

    >>> from sympy.simplify import rcollect
    >>> from sympy.abc import x, y

    >>> expr = (x**2*y + x*y + x + y)/(x + y)

    >>> rcollect(expr, y)
    (x + y*(x**2 + x + 1))/(x + y)

    See Also
    ========

    collect, collect_const, collect_sqrt
    """
    if expr.is_Atom or not expr.has(*vars):
        return expr
    else:
        expr = expr.__class__(*[rcollect(arg, *vars) for arg in expr.args])

        if expr.is_Add:
            return collect(expr, vars)
        else:
            return expr


def collect_sqrt(expr, evaluate=None):
    """Return expr with terms having common square roots collected together.

    If ``evaluate`` is False, a 2-tuple ``(terms, count)`` is returned, where
    ``count`` is the number of sqrt-containing terms. If the count is
    non-zero (or some collection took place) ``terms`` holds the terms of the
    Add; otherwise the expression itself is returned as a single term.
    If ``evaluate`` is True, the expression with any collected terms will be
    returned. ``evaluate`` defaults to ``global_parameters.evaluate``.

    Note: since I = sqrt(-1), it is collected, too.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b

    >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]]
    >>> collect_sqrt(a*r2 + b*r2)
    sqrt(2)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3)
    sqrt(2)*(a + b) + sqrt(3)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5)
    sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b)

    If evaluate is False then the arguments will be sorted and
    returned as a tuple and a count of the number of sqrt-containing
    terms will be returned:

    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False)
    ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3)
    >>> collect_sqrt(a*sqrt(2) + b, evaluate=False)
    ((b, sqrt(2)*a), 1)
    >>> collect_sqrt(a + b, evaluate=False)
    ((a + b,), 0)

    See Also
    ========

    collect, collect_const, rcollect
    """
    if evaluate is None:
        evaluate = global_parameters.evaluate
    # this step will help to standardize any complex arguments
    # of sqrts
    coeff, expr = expr.as_content_primitive()
    # gather the numeric radicals (Pow with exponent denominator 2) and I
    # appearing as commutative factors of the terms
    vars = set()
    for a in Add.make_args(expr):
        for m in a.args_cnc()[0]:
            if m.is_number and (
                    m.is_Pow and m.exp.is_Rational and m.exp.q == 2 or
                    m is S.ImaginaryUnit):
                vars.add(m)

    # we only want radicals, so exclude Number handling; in this case
    # d will be evaluated
    d = collect_const(expr, *vars, Numbers=False)
    hit = expr != d

    if not evaluate:
        nrad = 0
        # make the evaluated args canonical
        args = list(ordered(Add.make_args(d)))
        for i, m in enumerate(args):
            c, nc = m.args_cnc()
            for ci in c:
                # XXX should this be restricted to ci.is_number as above?
                if ci.is_Pow and ci.exp.is_Rational and ci.exp.q == 2 or \
                        ci is S.ImaginaryUnit:
                    nrad += 1
                    break
            # restore the content removed by as_content_primitive
            args[i] *= coeff
        if not (hit or nrad):
            args = [Add(*args)]
        return tuple(args), nrad

    return coeff*d


def collect_abs(expr):
    """Return ``expr`` with arguments of multiple Abs in a term collected
    under a single instance.

    Examples
    ========

    >>> from sympy.simplify.radsimp import collect_abs
    >>> from sympy.abc import x
    >>> collect_abs(abs(x + 1)/abs(x**2 - 1))
    Abs((x + 1)/(x**2 - 1))
    >>> collect_abs(abs(1/x))
    Abs(1/x)
    """
    def _abs(mul):
        # Merge the Abs factors (and real powers of Abs) of one Mul/Pow
        # into a single Abs; non-commutative factors are kept last.
        from sympy.core.mul import _mulsort
        c, nc = mul.args_cnc()
        a = []
        o = []
        for i in c:
            if isinstance(i, Abs):
                a.append(i.args[0])
            elif isinstance(i, Pow) and isinstance(i.base, Abs) and i.exp.is_real:
                a.append(i.base.args[0]**i.exp)
            else:
                o.append(i)
        # nothing to merge unless there are 2+ Abs factors or a single one
        # raised to a negative power
        if len(a) < 2 and not any(i.exp.is_negative for i in a if isinstance(i, Pow)):
            return mul
        absarg = Mul(*a)
        A = Abs(absarg)
        args = [A]
        args.extend(o)
        if not A.has(Abs):
            args.extend(nc)
            return Mul(*args)
        if not isinstance(A, Abs):
            # reevaluate and make it unevaluated
            A = Abs(absarg, evaluate=False)
        args[0] = A
        _mulsort(args)
        args.extend(nc)  # nc always go last
        return Mul._from_args(args, is_commutative=not nc)

    return expr.replace(
        lambda x: isinstance(x, Mul),
        lambda x: _abs(x)).replace(
            lambda x: isinstance(x, Pow),
            lambda x: _abs(x))


def collect_const(expr, *vars, Numbers=True):
    """A non-greedy collection of terms with similar number coefficients in
    an Add expr. If ``vars`` is given then only those constants will be
    targeted. Although any Number can also be targeted, if this is not
    desired set ``Numbers=False`` and no Float or Rational will be collected.

    Parameters
    ==========

    expr : sympy expression
        This parameter defines the expression from which
        terms with similar coefficients are to be collected. A non-Add
        expression is returned as it is.

    vars : variable length collection of constants, optional
        Specifies the constants to target for collection. Can be multiple in
        number. If none are given, every numeric factor found in the terms
        of ``expr`` is targeted, and sums produced by a collection are
        themselves targeted afterwards.

    Numbers : bool
        Specifies to target all instances of
        :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then
        no Float or Rational will be collected.

    Returns
    =======

    expr : Expr
        Returns an expression with similar coefficient terms collected.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import s, x, y, z
    >>> from sympy.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    The collection is sign-sensitive, giving higher precedence to the
    unsigned values:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    recurse = False

    if not vars:
        recurse = True
        vars = set()
        for a in expr.args:
            for m in Mul.make_args(a):
                if m.is_number:
                    vars.add(m)
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    vars = list(ordered(vars))
    # NOTE: vars may grow while it is iterated (see vars.append below), so
    # newly collected sums are themselves tried as targets
    for v in vars:
        # terms[v] holds the cofactors of v; terms[S.One] the untouched terms
        terms = defaultdict(list)
        Fv = Factors(v)
        for m in Add.make_args(expr):
            f = Factors(m)
            q, r = f.div(Fv)
            if r.is_one:
                # only accept this as a true factor if
                # it didn't change an exponent from an Integer
                # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2)
                # -- we aren't looking for this sort of change
                fwas = f.factors.copy()
                fnow = q.factors
                if not any(k in fwas and fwas[k].is_Integer and not
                        fnow[k].is_Integer for k in fnow):
                    terms[v].append(q.as_expr())
                    continue
            terms[S.One].append(m)

        args = []
        hit = False
        uneval = False
        for k in ordered(terms):
            v = terms[k]
            if k is S.One:
                args.extend(v)
                continue

            if len(v) > 1:
                v = Add(*v)
                hit = True
                if recurse and v != expr:
                    vars.append(v)
            else:
                v = v[0]

            # be careful not to let uneval become True unless
            # it must be because it's going to be more expensive
            # to rebuild the expression as an unevaluated one
            if Numbers and k.is_Number and v.is_Add:
                args.append(_keep_coeff(k, v, sign=True))
                uneval = True
            else:
                args.append(k*v)

        if hit:
            if uneval:
                expr = _unevaluated_Add(*args)
            else:
                expr = Add(*args)
            if not expr.is_Add:
                break

    return expr


def radsimp(expr, symbolic=True, max_terms=4):
    r"""
    Rationalize the denominator by removing square roots.

    Explanation
    ===========

    The expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumption is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Examples
    ========

    >>> from sympy import radsimp, sqrt, Symbol, pprint
    >>> from sympy import factor_terms, fraction, signsimp
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b, c

    >>> radsimp(1/(2 + sqrt(2)))
    (2 - sqrt(2))/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True))
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If ``symbolic=False``, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from sympy.simplify.simplify import signsimp

    syms = symbols("a:d A:D")
    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]; e.g. for 2 terms this is the
        # conjugate sqrt(A)*a - sqrt(B)*b
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i])))
            return (
            sqrt(A)*a - sqrt(B)*b).xreplace(reps)
        if len(rterms) == 3:
            reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i])))
            return (
            (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 -
            B*b**2 + C*c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i])))
            return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b
                - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 +
                D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 -
                2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 -
                2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 +
                D**2*d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        # True if d is a Pow whose exponent has denominator 2 (with
        # ``symbolic``, a symbolic exponent whose denom() is 2 also counts);
        # with log2=True, also when the exponent's denominator is a power
        # of 2 other than 1 (tested via log(q, 2).is_Integer).
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                d = denom(e)
                if d.is_Integer:
                    q = d
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # Handle first reduces to the case
        # expr = 1/d, where d is an add, or d is base**p/2.
        # We do this by recursively calling handle on each piece.
        from sympy.simplify.simplify import nsimplify

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1/d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1/d))
        elif d.is_Mul:
            return _unevaluated_Mul(*[handle(1/d) for d in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**numer(d.exp)
            if d2 != d:
                return handle(1/d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1/d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1/d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1/d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1/_d

        # Repeatedly multiply numerator and denominator by the _num()
        # multiplier until no radical groups remain in d (or give up).
        while True:
            # collect similar terms: group the terms of d by the product of
            # their square-root-like radicands (I counts as radicand -1)
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            # rterms: [(radicand, coefficient), ...]
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            # number of radical groups; a radical-free group (radicand 1)
            # is only discounted when it sorts first
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point
                # so don't keep changes, e.g. this expression is troublesome
                # in collecting terms so as not to raise the issue of 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all([x.is_Integer and (y**2).is_Rational for x, y in rterms]):
                    nd, d = rad_rationalize(S.One, Add._from_args(
                        [sqrt(x)*y for x, y in rterms]))
                    n *= nd
                else:
                    # is there anything else that might be attempted?
                    keep = False
                break
            from sympy.simplify.powsimp import powsimp, powdenest

            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            # a vanishing/undefined denominator means the transformation
            # is invalid here, so fall back to the unchanged input
            if d.has(S.Zero, nan, zoo):
                return expr
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1/d)

    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1/d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            # if the sign/factor normalization undid the work, keep the
            # result of handle() instead
            if old == (n, d):
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            # accept the gcd-removed form only if its denominator is not
            # more complicated (by count_ops) than the current one
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1/d)


def rad_rationalize(num, den):
    """
    Rationalize ``num/den`` by removing square roots in the denominator;
    num and den are sum of terms whose squares are positive rationals.

    The denominator is split as ``sqrt(g)*a + b`` (see :func:`split_surds`)
    and both parts are multiplied by the conjugate ``sqrt(g)*a - b``; this is
    repeated until the denominator is no longer an Add. Returns the pair
    ``(num, den)``.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import rad_rationalize
    >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3)
    (-sqrt(3) + sqrt(6)/3, -7/9)
    """
    if not den.is_Add:
        return num, den
    g, a, b = split_surds(den)
    a = a*sqrt(g)
    num = _mexpand((a - b)*num)
    den = _mexpand(a**2 - b**2)
    return rad_rationalize(num, den)


def fraction(expr, exact=False):
    """Returns a pair with expression's numerator and denominator.
    If the given expression is not a fraction then this function
    will return the tuple (expr, 1).

    This function will not make any attempt to simplify nested
    fractions or to do any term rewriting at all.

    If only one of the numerator/denominator pair is needed then
    use numer(expr) or denom(expr) functions respectively.

    >>> from sympy import fraction, Rational, Symbol
    >>> from sympy.abc import x, y

    >>> fraction(x/y)
    (x, y)
    >>> fraction(x)
    (x, 1)

    >>> fraction(1/y**2)
    (1, y**2)

    >>> fraction(x*y/2)
    (x*y, 2)
    >>> fraction(Rational(1, 2))
    (1, 2)

    This function will also work fine with assumptions:

    >>> k = Symbol('k', negative=True)
    >>> fraction(x * y**k)
    (x, y**(-k))

    If we know nothing about the sign of some exponent and the ``exact``
    flag is unset, then the exponent's structure will
    be analyzed and a pretty fraction will be returned:

    >>> from sympy import exp, Mul
    >>> fraction(2*x**(-y))
    (2, x**y)

    >>> fraction(exp(-x))
    (1, exp(x))

    >>> fraction(exp(-x), exact=True)
    (exp(-x), 1)

    The ``exact`` flag will also keep any unevaluated Muls from
    being evaluated:

    >>> u = Mul(2, x + 1, evaluate=False)
    >>> fraction(u)
    (2*x + 2, 1)
    >>> fraction(u, exact=True)
    (2*(x + 1), 1)
    """
    expr = sympify(expr)

    # NOTE: these local lists shadow the module-level numer()/denom()
    # functions inside this function only
    numer, denom = [], []

    for term in Mul.make_args(expr):
        if term.is_commutative and (term.is_Pow or isinstance(term, exp)):
            b, ex = term.as_base_exp()
            if ex.is_negative:
                if ex is S.NegativeOne:
                    denom.append(b)
                elif exact:
                    if ex.is_constant():
                        denom.append(Pow(b, -ex))
                    else:
                        numer.append(term)
                else:
                    denom.append(Pow(b, -ex))
            elif ex.is_positive:
                numer.append(term)
            elif not exact and ex.is_Mul:
                # sign of the exponent unknown: let as_numer_denom decide
                # from the exponent's structure (e.g. x**(-y) -> 1/x**y)
                n, d = term.as_numer_denom()
                if n != 1:
                    numer.append(n)
                denom.append(d)
            else:
                numer.append(term)
        elif term.is_Rational and not term.is_Integer:
            if term.p != 1:
                numer.append(term.p)
            denom.append(term.q)
        else:
            numer.append(term)
    return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact)


def numer(expr):
    """Return the numerator of ``expr``, i.e. ``fraction(expr)[0]``."""
    return fraction(expr)[0]


def denom(expr):
    """Return the denominator of ``expr``, i.e. ``fraction(expr)[1]``."""
    return fraction(expr)[1]


def fraction_expand(expr, **hints):
    """Return ``expr.expand(frac=True, **hints)``."""
    return expr.expand(frac=True, **hints)


def numer_expand(expr, **hints):
    """Expand only the numerator of ``expr``.

    ``expr`` is split with :func:`fraction`; the numerator is expanded with
    ``numer=True`` plus any extra ``hints`` and divided by the unchanged
    denominator.
    """
    a, b = fraction(expr)
    return a.expand(numer=True, **hints) / b


def denom_expand(expr, **hints):
    """Expand only the denominator of ``expr``.

    ``expr`` is split with :func:`fraction`; the denominator is expanded with
    ``denom=True`` plus any extra ``hints`` and the unchanged numerator is
    divided by it.
    """
    a, b = fraction(expr)
    return a / b.expand(denom=True, **hints)


# aliases
expand_numer = numer_expand
expand_denom = denom_expand
expand_fraction = fraction_expand


def split_surds(expr):
    """
    Split an expression with terms whose squares are positive rationals
    into a sum of terms whose surds squared have gcd equal to g
    and a sum of terms with surds squared coprime with g.

    Returns ``(g, a, b)`` such that ``expr == sqrt(g)*a + b``: each term
    ``c*sqrt(s)`` with ``s`` in the gcd group contributes ``c*sqrt(s/g)`` to
    ``a``; every other term goes unchanged into ``b``.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import split_surds
    >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15))
    (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10))
    """
    args = sorted(expr.args, key=default_sort_key)
    coeff_muls = [x.as_coeff_Mul() for x in args]
    surds = [x[1]**2 for x in coeff_muls if x[1].is_Pow]
    surds.sort(key=default_sort_key)
    # NOTE(review): _split_gcd needs at least one argument, so an expr
    # with no Pow terms (or b1n filtered to empty below) raises IndexError
    g, b1, b2 = _split_gcd(*surds)
    g2 = g
    if not b2 and len(b1) >= 2:
        b1n = [x/g for x in b1]
        b1n = [x for x in b1n if x != 1]
        # only a common factor has been factored; split again
        g1, b1n, b2 = _split_gcd(*b1n)
        g2 = g*g1
    a1v, a2v = [], []
    for c, s in coeff_muls:
        if s.is_Pow and s.exp == S.Half:
            s1 = s.base
            if s1 in b1:
                a1v.append(c*sqrt(s1/g2))
            else:
                a2v.append(c*s)
        else:
            a2v.append(c*s)
    a = Add(*a1v)
    b = Add(*a2v)
    return g2, a, b


def _split_gcd(*a):
    """
    Split the integers ``a`` into a list ``a1`` sharing the common divisor
    ``g`` and a list ``a2`` of elements coprime with ``g``.
    Returns ``g, a1, a2``.

    The split is greedy: ``g`` starts as ``a[0]``; each following element
    whose gcd with the current ``g`` is not 1 is added to ``a1`` and ``g``
    becomes that gcd, otherwise the element is added to ``a2``. At least one
    argument is required (``a[0]`` is read unconditionally).

    Examples
    ========

    >>> from sympy.simplify.radsimp import _split_gcd
    >>> _split_gcd(55, 35, 22, 14, 77, 10)
    (5, [55, 35, 10], [22, 14, 77])
    """
    g = a[0]
    b1 = [g]
    b2 = []
    for x in a[1:]:
        g1 = gcd(g, x)
        if g1 == 1:
            b2.append(x)
        else:
            g = g1
            b1.append(x)
    return g, b1, b2