1from collections import defaultdict
2
3from sympy import SYMPY_DEBUG
4
5from sympy.core import expand_power_base, sympify, Add, S, Mul, Derivative, Pow, symbols, expand_mul
6from sympy.core.add import _unevaluated_Add
7from sympy.core.compatibility import iterable, ordered, default_sort_key
8from sympy.core.parameters import global_parameters
9from sympy.core.exprtools import Factors, gcd_terms
10from sympy.core.function import _mexpand
11from sympy.core.mul import _keep_coeff, _unevaluated_Mul
12from sympy.core.numbers import Rational, zoo, nan
13from sympy.functions import exp, sqrt, log
14from sympy.functions.elementary.complexes import Abs
15from sympy.polys import gcd
16from sympy.simplify.sqrtdenest import sqrtdenest
17
18
19
20
def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True):
    """
    Collect additive terms of an expression.

    Explanation
    ===========

    This function collects additive terms of an expression with respect
    to a list of expressions up to powers with rational exponents. Here
    the term "symbol" means an arbitrary expression, which can contain
    powers, products, sums etc. In other words, a symbol is a pattern
    that will be searched for in the expression's terms.

    The input expression is not expanded by :func:`collect`, so the user is
    expected to provide an expression in an appropriate form. This makes
    :func:`collect` more predictable as there is no magic happening behind the
    scenes. However, it is important to note that powers of products are
    converted to products of powers using the :func:`~.expand_power_base`
    function.

    There are two possible types of output. If the ``evaluate`` flag is
    set, this function will return an expression with collected terms;
    otherwise it will return a dictionary with expressions up to rational
    powers as keys and collected coefficients as values. When ``evaluate``
    is not given, the global evaluation setting is used.

    If ``distribute_order_term`` is True (the default), an Order term of
    the expression that does not contain any of ``syms`` is removed before
    collection and then added to every collected coefficient.

    Examples
    ========

    >>> from sympy import S, collect, expand, factor, Wild
    >>> from sympy.abc import a, b, c, x, y

    This function can collect symbolic coefficients in polynomials or
    rational expressions. It will manage to find all integer or rational
    powers of collection variable::

        >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x)
        c + x**2*(a + b) + x*(a - b)

    The same result can be achieved in dictionary form::

        >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False)
        >>> d[x**2]
        a + b
        >>> d[x]
        a - b
        >>> d[S.One]
        c

    You can also work with multivariate polynomials. However, remember that
    this function is greedy so it will care only about a single symbol at a
    time, in specification order::

        >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y])
        x**2*(y + 1) + x*y + y*(a + 1)

    Also more complicated expressions can be used as patterns::

        >>> from sympy import sin, log
        >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x))
        (a + b)*sin(2*x)

        >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x))
        x*(a + b)*log(x)

    You can use wildcards in the pattern::

        >>> w = Wild('w1')
        >>> collect(a*x**y - b*x**y, w**y)
        x**y*(a - b)

    It is also possible to work with symbolic powers, although it has more
    complicated behavior, because in this case power's base and symbolic part
    of the exponent are treated as a single symbol::

        >>> collect(a*x**c + b*x**c, x)
        a*x**c + b*x**c
        >>> collect(a*x**c + b*x**c, x**c)
        x**c*(a + b)

    However if you incorporate rationals to the exponents, then you will get
    well known behavior::

        >>> collect(a*x**(2*c) + b*x**(2*c), x**c)
        x**(2*c)*(a + b)

    Note also that all previously stated facts about :func:`collect` function
    apply to the exponential function, so you can get::

        >>> from sympy import exp
        >>> collect(a*exp(2*x) + b*exp(2*x), exp(x))
        (a + b)*exp(2*x)

    If you are interested only in collecting specific powers of some symbols
    then set ``exact`` flag in arguments::

        >>> collect(a*x**7 + b*x**7, x, exact=True)
        a*x**7 + b*x**7
        >>> collect(a*x**7 + b*x**7, x**7, exact=True)
        x**7*(a + b)

    You can also apply this function to differential equations, where
    derivatives of arbitrary order can be collected. Note that if you
    collect with respect to a function or a derivative of a function, all
    derivatives of that function will also be collected. Use
    ``exact=True`` to prevent this from happening::

        >>> from sympy import Derivative as D, collect, Function
        >>> f = Function('f') (x)

        >>> collect(a*D(f,x) + b*D(f,x), D(f,x))
        (a + b)*Derivative(f(x), x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f)
        (a + b)*Derivative(f(x), (x, 2))

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True)
        a*Derivative(f(x), (x, 2)) + b*Derivative(f(x), (x, 2))

        >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f)
        (a + b)*f(x) + (a + b)*Derivative(f(x), x)

    Or you can even match both derivative order and exponent at the same time::

        >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x))
        (a + b)*Derivative(f(x), (x, 2))**2

    Finally, you can apply a function to each of the collected coefficients.
    For example you can factorize symbolic coefficients of polynomial::

        >>> f = expand((x + a + 1)**3)

        >>> collect(f, x, factor)
        x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3

    .. note:: Arguments are expected to be in expanded form, so you might have
              to call :func:`~.expand` prior to calling this function.

    Collecting with respect to a noncommutative pattern that matches a term
    raises ``AttributeError``; derivatives taken with respect to several
    different variables raise ``NotImplementedError``.

    See Also
    ========

    collect_const, collect_sqrt, rcollect
    """
    from sympy.core.assumptions import assumptions
    from sympy.utilities.iterables import sift
    from sympy.core.symbol import Dummy, Wild
    expr = sympify(expr)
    syms = [sympify(i) for i in (syms if iterable(syms) else [syms])]
    # Patterns that are not a plain symbol x, not -x and contain no Wild
    # are replaced by Dummy symbols carrying the same assumptions; the
    # collection is done with respect to the Dummies and the original
    # patterns are substituted back into the result (keys and values).
    cond = lambda x: x.is_Symbol or (-x).is_Symbol or bool(
        x.atoms(Wild))
    _, nonsyms = sift(syms, cond, binary=True)
    if nonsyms:
        reps = dict(zip(nonsyms, [Dummy(**assumptions(i)) for i in nonsyms]))
        syms = [reps.get(s, s) for s in syms]
        rv = collect(expr.subs(reps), syms,
            func=func, evaluate=evaluate, exact=exact,
            distribute_order_term=distribute_order_term)
        urep = {v: k for k, v in reps.items()}
        if not isinstance(rv, dict):
            return rv.xreplace(urep)
        else:
            return {urep.get(k, k).xreplace(urep): v.xreplace(urep)
                    for k, v in rv.items()}

    # fall back to the global evaluation flag
    if evaluate is None:
        evaluate = global_parameters.evaluate

    def make_expression(terms):
        """Rebuild a Mul from ``(term, rat, sym, deriv)`` tuples as produced
        by ``parse_term``: a derivative ``(var, order)`` is re-applied
        ``order`` times, then the term is raised to ``rat`` (or to
        ``rat*sym`` when a symbolic exponent is present).
        """
        product = []

        for term, rat, sym, deriv in terms:
            if deriv is not None:
                var, order = deriv

                while order > 0:
                    term, order = Derivative(term, var), order - 1

            if sym is None:
                if rat is S.One:
                    product.append(term)
                else:
                    product.append(Pow(term, rat))
            else:
                product.append(Pow(term, rat*sym))

        return Mul(*product)

    def parse_derivative(deriv):
        # scan derivatives tower in the input expression and return
        # underlying function and maximal differentiation order
        #
        # Returns ``(expr, (sym, Rational(order)))``: nested Derivatives
        # with respect to the same variable are unwrapped and their orders
        # summed; unwrapping stops at a Derivative w.r.t. another variable.
        # NotImplementedError is raised if a Derivative visited here has
        # more than one distinct variable.
        expr, sym, order = deriv.expr, deriv.variables[0], 1

        for s in deriv.variables[1:]:
            if s == sym:
                order += 1
            else:
                raise NotImplementedError(
                    'Improve MV Derivative support in collect')

        while isinstance(expr, Derivative):
            s0 = expr.variables[0]

            for s in expr.variables:
                if s != s0:
                    raise NotImplementedError(
                        'Improve MV Derivative support in collect')

            if s0 == sym:
                expr, order = expr.expr, order + len(expr.variables)
            else:
                break

        return expr, (sym, Rational(order))

    def parse_term(expr):
        """Parses expression expr and outputs tuple (sexpr, rat_expo,
        sym_expo, deriv)
        where:
         - sexpr is the base expression
         - rat_expo is the rational exponent that sexpr is raised to
         - sym_expo is the symbolic exponent that sexpr is raised to
         - deriv is None or the ``(variable, order)`` pair of the
           derivative applied to sexpr (see parse_derivative)

         For example, the output of x would be (x, 1, None, None)
         the output of 2**x would be (2, 1, x, None).
        """
        rat_expo, sym_expo = S.One, None
        sexpr, deriv = expr, None

        if expr.is_Pow:
            if isinstance(expr.base, Derivative):
                sexpr, deriv = parse_derivative(expr.base)
            else:
                sexpr = expr.base

            # E**arg is split like exp(arg): the rational coefficient of
            # the exponent becomes rat_expo
            if expr.base == S.Exp1:
                arg = expr.exp
                if arg.is_Rational:
                    sexpr, rat_expo = S.Exp1, arg
                elif arg.is_Mul:
                    coeff, tail = arg.as_coeff_Mul(rational=True)
                    sexpr, rat_expo = exp(tail), coeff

            elif expr.exp.is_Number:
                rat_expo = expr.exp
            else:
                coeff, tail = expr.exp.as_coeff_Mul()

                if coeff.is_Number:
                    rat_expo, sym_expo = coeff, tail
                else:
                    sym_expo = expr.exp
        elif isinstance(expr, exp):
            arg = expr.exp
            if arg.is_Rational:
                sexpr, rat_expo = S.Exp1, arg
            elif arg.is_Mul:
                coeff, tail = arg.as_coeff_Mul(rational=True)
                sexpr, rat_expo = exp(tail), coeff
        elif isinstance(expr, Derivative):
            sexpr, deriv = parse_derivative(expr)

        return sexpr, rat_expo, sym_expo, deriv

    def parse_expression(terms, pattern):
        """Parse terms searching for a pattern.
        Terms is a list of tuples as returned by parse_term;
        Pattern is an expression treated as a product of factors.

        Returns None when the pattern cannot be matched; otherwise a tuple
        (remaining terms, matched terms, common exponent, has_deriv) where
        has_deriv tells whether any inspected term carried a derivative.
        """
        pattern = Mul.make_args(pattern)

        if len(terms) < len(pattern):
            # pattern is longer than matched product
            # so no chance for positive parsing result
            return None
        else:
            pattern = [parse_term(elem) for elem in pattern]

            terms = terms[:]  # need a copy
            elems, common_expo, has_deriv = [], None, False

            for elem, e_rat, e_sym, e_ord in pattern:

                if elem.is_Number and e_rat == 1 and e_sym is None:
                    # a constant is a match for everything
                    continue

                for j in range(len(terms)):
                    if terms[j] is None:
                        continue

                    term, t_rat, t_sym, t_ord = terms[j]

                    # keeping track of whether one of the terms had
                    # a derivative or not as this will require rebuilding
                    # the expression later
                    if t_ord is not None:
                        has_deriv = True

                    if (term.match(elem) is not None and
                            (t_sym == e_sym or t_sym is not None and
                            e_sym is not None and
                            t_sym.match(e_sym) is not None)):
                        if exact is False:
                            # we don't have to be exact so find common exponent
                            # for both expression's term and pattern's element
                            expo = t_rat / e_rat

                            if common_expo is None:
                                # first time
                                common_expo = expo
                            else:
                                # common exponent was negotiated before so
                                # there is no chance for a pattern match unless
                                # common and current exponents are equal
                                if common_expo != expo:
                                    common_expo = 1
                        else:
                            # we ought to be exact so all fields of
                            # interest must match in every details
                            if e_rat != t_rat or e_ord != t_ord:
                                continue

                        # found common term so remove it from the expression
                        # and try to match next element in the pattern
                        elems.append(terms[j])
                        terms[j] = None

                        break

                else:
                    # pattern element not found
                    return None

            return [_f for _f in terms if _f], elems, common_expo, has_deriv

    # With evaluation on, collect recursively inside the arguments first:
    # an Add has each non-Order argument collected (its Order term is kept
    # aside and re-added) and then continues below; a Mul or Pow is
    # rebuilt from its collected parts and returned immediately.
    if evaluate:
        if expr.is_Add:
            o = expr.getO() or 0
            expr = expr.func(*[
                    collect(a, syms, func, True, exact, distribute_order_term)
                    for a in expr.args if a != o]) + o
        elif expr.is_Mul:
            return expr.func(*[
                collect(term, syms, func, True, exact, distribute_order_term)
                for term in expr.args])
        elif expr.is_Pow:
            b = collect(
                expr.base, syms, func, True, exact, distribute_order_term)
            return Pow(b, expr.exp)

    # normalize the patterns the same way the terms are normalized below
    syms = [expand_power_base(i, deep=False) for i in syms]

    order_term = None

    # split off an Order term that does not involve any pattern; it is
    # added back to every coefficient at the end
    if distribute_order_term:
        order_term = expr.getO()

        if order_term is not None:
            if order_term.has(*syms):
                order_term = None
            else:
                expr = expr.removeO()

    summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)]

    # collected maps a pattern match (the key) to the list of cofactors
    # found for it; disliked accumulates terms no pattern matched
    collected, disliked = defaultdict(list), S.Zero
    for product in summa:
        # commutative factors in canonical order, followed by the
        # noncommutative ones in their original order
        c, nc = product.args_cnc(split_1=False)
        args = list(ordered(c)) + nc
        terms = [parse_term(i) for i in args]
        small_first = True

        for symbol in syms:
            if SYMPY_DEBUG:
                print("DEBUG: parsing of expression %s with symbol %s " % (
                    str(terms), str(symbol))
                )

            # for Derivative patterns the factor list is reversed (at most
            # once per product) before matching
            if isinstance(symbol, Derivative) and small_first:
                terms = list(reversed(terms))
                small_first = not small_first
            result = parse_expression(terms, symbol)

            if SYMPY_DEBUG:
                print("DEBUG: returned %s" % str(result))

            if result is not None:
                if not symbol.is_commutative:
                    raise AttributeError("Can not collect noncommutative symbol")

                # NOTE(review): common_expo is unpacked here but not used
                # anywhere below
                terms, elems, common_expo, has_deriv = result

                # when there was derivative in current pattern we
                # will need to rebuild its expression from scratch
                if not has_deriv:
                    margs = []
                    for elem in elems:
                        if elem[2] is None:
                            e = elem[1]
                        else:
                            e = elem[1]*elem[2]
                        margs.append(Pow(elem[0], e))
                    index = Mul(*margs)
                else:
                    index = make_expression(elems)
                terms = expand_power_base(make_expression(terms), deep=False)
                index = expand_power_base(index, deep=False)
                collected[index].append(terms)
                # greedy: the first pattern (in syms order) that matches wins
                break
        else:
            # none of the patterns matched
            disliked += product
    # add terms now for each key
    collected = {k: Add(*v) for k, v in collected.items()}

    # unmatched terms are reported under the key 1
    if disliked is not S.Zero:
        collected[S.One] = disliked

    if order_term is not None:
        for key, val in collected.items():
            collected[key] = val + order_term

    if func is not None:
        collected = {
            key: func(val) for key, val in collected.items()}

    if evaluate:
        return Add(*[key*val for key, val in collected.items()])
    else:
        return collected
452
453
def rcollect(expr, *vars):
    """
    Recursively collect sums in an expression.

    Explanation
    ===========

    The arguments of ``expr`` are processed bottom-up. Every rebuilt
    subexpression that is an Add is handed to :func:`collect` with ``vars``
    as the collection patterns. Atoms, and subexpressions that contain none
    of ``vars``, are returned untouched.

    Examples
    ========

    >>> from sympy.simplify import rcollect
    >>> from sympy.abc import x, y

    >>> expr = (x**2*y + x*y + x + y)/(x + y)

    >>> rcollect(expr, y)
    (x + y*(x**2 + x + 1))/(x + y)

    See Also
    ========

    collect, collect_const, collect_sqrt
    """
    # leaves and subtrees free of the targets need no work
    if expr.is_Atom or not expr.has(*vars):
        return expr

    rebuilt_args = [rcollect(sub, *vars) for sub in expr.args]
    rebuilt = type(expr)(*rebuilt_args)

    # only sums are collected; anything else is returned as rebuilt
    if not rebuilt.is_Add:
        return rebuilt
    return collect(rebuilt, vars)
483
484
def collect_sqrt(expr, evaluate=None):
    """Return expr with terms having common square roots collected together.

    The radicals targeted are numeric powers whose exponent is a Rational
    with denominator 2, together with ``I`` (since I = sqrt(-1)).

    When ``evaluate`` is True (by default the global evaluation setting is
    used), the expression with any collected terms is returned. When it is
    False, a pair ``(terms, count)`` is returned: ``terms`` is a sorted tuple
    of the terms of the collected sum (or a 1-tuple holding the whole
    expression when nothing was collected and no term has a radical), and
    ``count`` is the number of terms that contain a radical factor.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b

    >>> r2, r3, r5 = [sqrt(i) for i in [2, 3, 5]]
    >>> collect_sqrt(a*r2 + b*r2)
    sqrt(2)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r3)
    sqrt(2)*(a + b) + sqrt(3)*(a + b)
    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5)
    sqrt(3)*a + sqrt(5)*b + sqrt(2)*(a + b)

    If evaluate is False then the arguments will be sorted and
    returned as a list and a count of the number of sqrt-containing
    terms will be returned:

    >>> collect_sqrt(a*r2 + b*r2 + a*r3 + b*r5, evaluate=False)
    ((sqrt(3)*a, sqrt(5)*b, sqrt(2)*(a + b)), 3)
    >>> collect_sqrt(a*sqrt(2) + b, evaluate=False)
    ((b, sqrt(2)*a), 1)
    >>> collect_sqrt(a + b, evaluate=False)
    ((a + b,), 0)

    See Also
    ========

    collect, collect_const, rcollect
    """
    if evaluate is None:
        evaluate = global_parameters.evaluate

    def _is_radical(t):
        # a square-root-like power (exponent with denominator 2) or I
        return (t.is_Pow and t.exp.is_Rational and t.exp.q == 2
                or t is S.ImaginaryUnit)

    # pulling out the content first standardizes complex sqrt arguments
    coeff, expr = expr.as_content_primitive()
    radicals = {
        factor
        for term in Add.make_args(expr)
        for factor in term.args_cnc()[0]
        if factor.is_number and _is_radical(factor)}

    # Numbers=False: only the radicals are targeted; the result is evaluated
    collected = collect_const(expr, *radicals, Numbers=False)
    changed = expr != collected

    if evaluate:
        return coeff*collected

    # canonical ordering of the evaluated terms; unlike the search above,
    # the count here does not require the factor to satisfy is_number
    terms = list(ordered(Add.make_args(collected)))
    nrad = 0
    for idx, term in enumerate(terms):
        if any(_is_radical(cfac) for cfac in term.args_cnc()[0]):
            nrad += 1
        terms[idx] = term*coeff

    if not (changed or nrad):
        terms = [Add(*terms)]
    return tuple(terms), nrad
562
563
def collect_abs(expr):
    """Return ``expr`` with arguments of multiple Abs in a term collected
    under a single instance.

    Every Mul, and afterwards every Pow, in ``expr`` is visited. The
    arguments of its commutative Abs factors, and of real powers of Abs,
    are multiplied together inside one Abs. A term with fewer than two such
    factors is returned unchanged, unless one of them contributes a power
    with a negative exponent.

    Examples
    ========

    >>> from sympy.simplify.radsimp import collect_abs
    >>> from sympy.abc import x
    >>> collect_abs(abs(x + 1)/abs(x**2 - 1))
    Abs((x + 1)/(x**2 - 1))
    >>> collect_abs(abs(1/x))
    Abs(1/x)
    """
    def _abs(mul):
        from sympy.core.mul import _mulsort
        commutative, noncommutative = mul.args_cnc()
        abs_args = []
        others = []
        for factor in commutative:
            if isinstance(factor, Abs):
                abs_args.append(factor.args[0])
            elif (isinstance(factor, Pow) and isinstance(factor.base, Abs)
                    and factor.exp.is_real):
                abs_args.append(factor.base.args[0]**factor.exp)
            else:
                others.append(factor)

        if len(abs_args) < 2 and not any(
                t.exp.is_negative for t in abs_args if isinstance(t, Pow)):
            return mul

        absarg = Mul(*abs_args)
        combined = Abs(absarg)
        if not combined.has(Abs):
            # Abs evaluated away completely: an ordinary product will do
            return Mul(combined, *others, *noncommutative)

        if not isinstance(combined, Abs):
            # reevaluate and make it unevaluated
            combined = Abs(absarg, evaluate=False)
        pieces = [combined] + others
        _mulsort(pieces)
        pieces.extend(noncommutative)  # nc always go last
        return Mul._from_args(pieces, is_commutative=not noncommutative)

    def _is_mul(node):
        return isinstance(node, Mul)

    def _is_pow(node):
        return isinstance(node, Pow)

    after_muls = expr.replace(_is_mul, lambda node: _abs(node))
    return after_muls.replace(_is_pow, lambda node: _abs(node))
612
613
def collect_const(expr, *vars, Numbers=True):
    """A non-greedy collection of terms with similar number coefficients in
    an Add expr. If ``vars`` is given then only those constants will be
    targeted. Although any Number can also be targeted, if this is not
    desired set ``Numbers=False`` and no Float or Rational will be collected.

    Parameters
    ==========

    expr : sympy expression
        This parameter defines the expression from which terms with
        similar coefficients are to be collected. A non-Add expression
        is returned as it is.

    vars : variable length collection of number-valued expressions, optional
        Specifies the constants (e.g. ``sqrt(3)`` or ``-2``) to target for
        collection. Can be multiple in number. If none are given, every
        factor of a term for which ``is_number`` is True is targeted, and
        sums produced by a collection are themselves tried as targets.

    Numbers : bool
        Specifies to target all instances of
        :class:`sympy.core.numbers.Number` class. If ``Numbers=False``, then
        no Float or Rational will be collected.

    Returns
    =======

    expr : Expr
        Returns an expression with similar coefficient terms collected.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.abc import s, x, y, z
    >>> from sympy.simplify.radsimp import collect_const
    >>> collect_const(sqrt(3) + sqrt(3)*(1 + sqrt(2)))
    sqrt(3)*(sqrt(2) + 2)
    >>> collect_const(sqrt(3)*s + sqrt(7)*s + sqrt(3) + sqrt(7))
    (sqrt(3) + sqrt(7))*(s + 1)
    >>> s = sqrt(2) + 2
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7))
    (sqrt(2) + 3)*(sqrt(3) + sqrt(7))
    >>> collect_const(sqrt(3)*s + sqrt(3) + sqrt(7)*s + sqrt(7), sqrt(3))
    sqrt(7) + sqrt(3)*(sqrt(2) + 3) + sqrt(7)*(sqrt(2) + 2)

    The collection is sign-sensitive, giving higher precedence to the
    unsigned values:

    >>> collect_const(x - y - z)
    x - (y + z)
    >>> collect_const(-y - z)
    -(y + z)
    >>> collect_const(2*x - 2*y - 2*z, 2)
    2*(x - y - z)
    >>> collect_const(2*x - 2*y - 2*z, -2)
    2*x - 2*(y + z)

    See Also
    ========

    collect, collect_sqrt, rcollect
    """
    if not expr.is_Add:
        return expr

    # newly collected sums are only re-used as targets when the caller
    # did not name the targets explicitly
    recurse = False

    if not vars:
        recurse = True
        vars = set()
        for a in expr.args:
            for m in Mul.make_args(a):
                if m.is_number:
                    vars.add(m)
    else:
        vars = sympify(vars)
    if not Numbers:
        vars = [v for v in vars if not v.is_Number]

    # NOTE: ``vars`` may be appended to inside this loop (see ``recurse``
    # below); iterating the list itself makes those additions get visited
    vars = list(ordered(vars))
    for v in vars:
        # terms[v] holds the quotients of the terms divisible by v;
        # terms[S.One] holds the remaining terms unchanged
        terms = defaultdict(list)
        Fv = Factors(v)
        for m in Add.make_args(expr):
            f = Factors(m)
            q, r = f.div(Fv)
            if r.is_one:
                # only accept this as a true factor if
                # it didn't change an exponent from an Integer
                # to a non-Integer, e.g. 2/sqrt(2) -> sqrt(2)
                # -- we aren't looking for this sort of change
                fwas = f.factors.copy()
                fnow = q.factors
                if not any(k in fwas and fwas[k].is_Integer and not
                        fnow[k].is_Integer for k in fnow):
                    terms[v].append(q.as_expr())
                    continue
            terms[S.One].append(m)

        args = []
        hit = False
        uneval = False
        for k in ordered(terms):
            # NOTE: this rebinds the outer loop variable ``v``; the outer
            # for-loop rebinds it again on its next iteration
            v = terms[k]
            if k is S.One:
                args.extend(v)
                continue

            if len(v) > 1:
                v = Add(*v)
                hit = True
                if recurse and v != expr:
                    vars.append(v)
            else:
                v = v[0]

            # be careful not to let uneval become True unless
            # it must be because it's going to be more expensive
            # to rebuild the expression as an unevaluated one
            if Numbers and k.is_Number and v.is_Add:
                args.append(_keep_coeff(k, v, sign=True))
                uneval = True
            else:
                args.append(k*v)

        if hit:
            if uneval:
                expr = _unevaluated_Add(*args)
            else:
                expr = Add(*args)
            # once the expression is no longer a sum there is nothing
            # left to collect
            if not expr.is_Add:
                break

    return expr
748
749
750def radsimp(expr, symbolic=True, max_terms=4):
751    r"""
752    Rationalize the denominator by removing square roots.
753
754    Explanation
755    ===========
756
757    The expression returned from radsimp must be used with caution
758    since if the denominator contains symbols, it will be possible to make
759    substitutions that violate the assumptions of the simplification process:
760    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
761    there are no symbols, this assumptions is made valid by collecting terms
762    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
763    you do not want the simplification to occur for symbolic denominators, set
764    ``symbolic`` to False.
765
766    If there are more than ``max_terms`` radical terms then the expression is
767    returned unchanged.
768
769    Examples
770    ========
771
772    >>> from sympy import radsimp, sqrt, Symbol, pprint
773    >>> from sympy import factor_terms, fraction, signsimp
774    >>> from sympy.simplify.radsimp import collect_sqrt
775    >>> from sympy.abc import a, b, c
776
777    >>> radsimp(1/(2 + sqrt(2)))
778    (2 - sqrt(2))/2
779    >>> x,y = map(Symbol, 'xy')
780    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
781    >>> radsimp(e)
782    sqrt(2)*(x + y)
783
784    No simplification beyond removal of the gcd is done. One might
785    want to polish the result a little, however, by collecting
786    square root terms:
787
788    >>> r2 = sqrt(2)
789    >>> r5 = sqrt(5)
790    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans)
791        ___       ___       ___       ___
792      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
793    ------------------------------------------
794       2               2      2              2
795    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y
796
797    >>> n, d = fraction(ans)
798    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True))
799            ___             ___
800          \/ 5 *(a + b) - \/ 2 *(x + y)
801    ------------------------------------------
802       2               2      2              2
803    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y
804
805    If radicals in the denominator cannot be removed or there is no denominator,
806    the original expression will be returned.
807
808    >>> radsimp(sqrt(2)*x + sqrt(2))
809    sqrt(2)*x + sqrt(2)
810
811    Results with symbols will not always be valid for all substitutions:
812
813    >>> eq = 1/(a + b*sqrt(c))
814    >>> eq.subs(a, b*sqrt(c))
815    1/(2*b*sqrt(c))
816    >>> radsimp(eq).subs(a, b*sqrt(c))
817    nan
818
819    If ``symbolic=False``, symbolic denominators will not be transformed (but
820    numeric denominators will still be processed):
821
822    >>> radsimp(eq, symbolic=False)
823    1/(a + b*sqrt(c))
824
825    """
826    from sympy.simplify.simplify import signsimp
827
828    syms = symbols("a:d A:D")
829    def _num(rterms):
830        # return the multiplier that will simplify the expression described
831        # by rterms [(sqrt arg, coeff), ... ]
832        a, b, c, d, A, B, C, D = syms
833        if len(rterms) == 2:
834            reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i])))
835            return (
836            sqrt(A)*a - sqrt(B)*b).xreplace(reps)
837        if len(rterms) == 3:
838            reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i])))
839            return (
840            (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 -
841            B*b**2 + C*c**2)).xreplace(reps)
842        elif len(rterms) == 4:
843            reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i])))
844            return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b
845                - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 +
846                D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 -
847                2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 -
848                2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 +
849                D**2*d**4)).xreplace(reps)
850        elif len(rterms) == 1:
851            return sqrt(rterms[0][0])
852        else:
853            raise NotImplementedError
854
855    def ispow2(d, log2=False):
856        if not d.is_Pow:
857            return False
858        e = d.exp
859        if e.is_Rational and e.q == 2 or symbolic and denom(e) == 2:
860            return True
861        if log2:
862            q = 1
863            if e.is_Rational:
864                q = e.q
865            elif symbolic:
866                d = denom(e)
867                if d.is_Integer:
868                    q = d
869            if q != 1 and log(q, 2).is_Integer:
870                return True
871        return False
872
873    def handle(expr):
874        # Handle first reduces to the case
875        # expr = 1/d, where d is an add, or d is base**p/2.
876        # We do this by recursively calling handle on each piece.
877        from sympy.simplify.simplify import nsimplify
878
879        n, d = fraction(expr)
880
881        if expr.is_Atom or (d.is_Atom and n.is_Atom):
882            return expr
883        elif not n.is_Atom:
884            n = n.func(*[handle(a) for a in n.args])
885            return _unevaluated_Mul(n, handle(1/d))
886        elif n is not S.One:
887            return _unevaluated_Mul(n, handle(1/d))
888        elif d.is_Mul:
889            return _unevaluated_Mul(*[handle(1/d) for d in d.args])
890
891        # By this step, expr is 1/d, and d is not a mul.
892        if not symbolic and d.free_symbols:
893            return expr
894
895        if ispow2(d):
896            d2 = sqrtdenest(sqrt(d.base))**numer(d.exp)
897            if d2 != d:
898                return handle(1/d2)
899        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
900            # (1/d**i) = (1/d)**i
901            return handle(1/d.base)**d.exp
902
903        if not (d.is_Add or ispow2(d)):
904            return 1/d.func(*[handle(a) for a in d.args])
905
906        # handle 1/d treating d as an Add (though it may not be)
907
908        keep = True  # keep changes that are made
909
910        # flatten it and collect radicals after checking for special
911        # conditions
912        d = _mexpand(d)
913
914        # did it change?
915        if d.is_Atom:
916            return 1/d
917
918        # is it a number that might be handled easily?
919        if d.is_number:
920            _d = nsimplify(d)
921            if _d.is_Number and _d.equals(d):
922                return 1/_d
923
924        while True:
925            # collect similar terms
926            collected = defaultdict(list)
927            for m in Add.make_args(d):  # d might have become non-Add
928                p2 = []
929                other = []
930                for i in Mul.make_args(m):
931                    if ispow2(i, log2=True):
932                        p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp))
933                    elif i is S.ImaginaryUnit:
934                        p2.append(S.NegativeOne)
935                    else:
936                        other.append(i)
937                collected[tuple(ordered(p2))].append(Mul(*other))
938            rterms = list(ordered(list(collected.items())))
939            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
940            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
941            if nrad < 1:
942                break
943            elif nrad > max_terms:
944                # there may have been invalid operations leading to this point
945                # so don't keep changes, e.g. this expression is troublesome
946                # in collecting terms so as not to raise the issue of 2834:
947                # r = sqrt(sqrt(5) + 5)
948                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
949                keep = False
950                break
951            if len(rterms) > 4:
952                # in general, only 4 terms can be removed with repeated squaring
953                # but other considerations can guide selection of radical terms
954                # so that radicals are removed
955                if all([x.is_Integer and (y**2).is_Rational for x, y in rterms]):
956                    nd, d = rad_rationalize(S.One, Add._from_args(
957                        [sqrt(x)*y for x, y in rterms]))
958                    n *= nd
959                else:
960                    # is there anything else that might be attempted?
961                    keep = False
962                break
963            from sympy.simplify.powsimp import powsimp, powdenest
964
965            num = powsimp(_num(rterms))
966            n *= num
967            d *= num
968            d = powdenest(_mexpand(d), force=symbolic)
969            if d.has(S.Zero, nan, zoo):
970                return expr
971            if d.is_Atom:
972                break
973
974        if not keep:
975            return expr
976        return _unevaluated_Mul(n, 1/d)
977
978    coeff, expr = expr.as_coeff_Add()
979    expr = expr.normal()
980    old = fraction(expr)
981    n, d = fraction(handle(expr))
982    if old != (n, d):
983        if not d.is_Atom:
984            was = (n, d)
985            n = signsimp(n, evaluate=False)
986            d = signsimp(d, evaluate=False)
987            u = Factors(_unevaluated_Mul(n, 1/d))
988            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
989            n, d = fraction(u)
990            if old == (n, d):
991                n, d = was
992        n = expand_mul(n)
993        if d.is_Number or d.is_Add:
994            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d)))
995            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
996                n, d = [signsimp(i) for i in (n2, d2)]
997                if n.is_Mul and n.args[0].is_Number:
998                    n = n.func(*n.args)
999
1000    return coeff + _unevaluated_Mul(n, 1/d)
1001
1002
def rad_rationalize(num, den):
    """
    Rationalize ``num/den`` by removing square roots in the denominator;
    num and den are sum of terms whose squares are positive rationals.

    While the denominator is a sum, it is split by :func:`split_surds` into
    ``a*sqrt(g) + b`` and both parts of the fraction are multiplied by the
    conjugate ``a*sqrt(g) - b``, turning the denominator into
    ``(a*sqrt(g))**2 - b**2``.  Once the denominator is no longer an Add the
    pair ``(num, den)`` is returned.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import rad_rationalize
    >>> rad_rationalize(sqrt(3), 1 + sqrt(2)/3)
    (-sqrt(3) + sqrt(6)/3, -7/9)
    """
    if den.is_Add:
        g, with_g, without_g = split_surds(den)
        root_part = with_g*sqrt(g)
        # Multiply through by the conjugate: (r + s)*(r - s) == r**2 - s**2.
        return rad_rationalize(
            _mexpand((root_part - without_g)*num),
            _mexpand(root_part**2 - without_g**2))
    return num, den
1023
1024
def fraction(expr, exact=False):
    """Returns a pair with expression's numerator and denominator.
       If the given expression is not a fraction then this function
       will return the tuple (expr, 1).

       This function will not make any attempt to simplify nested
       fractions or to do any term rewriting at all.

       If only one of the numerator/denominator pair is needed then
       use numer(expr) or denom(expr) functions respectively.

       >>> from sympy import fraction, Rational, Symbol
       >>> from sympy.abc import x, y

       >>> fraction(x/y)
       (x, y)
       >>> fraction(x)
       (x, 1)

       >>> fraction(1/y**2)
       (1, y**2)

       >>> fraction(x*y/2)
       (x*y, 2)
       >>> fraction(Rational(1, 2))
       (1, 2)

       This function will also work fine with assumptions:

       >>> k = Symbol('k', negative=True)
       >>> fraction(x * y**k)
       (x, y**(-k))

       If we know nothing about the sign of some exponent and the ``exact``
       flag is unset, then the exponent's structure will be analyzed and
       a pretty fraction will be returned:

       >>> from sympy import exp, Mul
       >>> fraction(2*x**(-y))
       (2, x**y)

       >>> fraction(exp(-x))
       (1, exp(x))

       >>> fraction(exp(-x), exact=True)
       (exp(-x), 1)

       The ``exact`` flag will also keep any unevaluated Muls from
       being evaluated:

       >>> u = Mul(2, x + 1, evaluate=False)
       >>> fraction(u)
       (2*x + 2, 1)
       >>> fraction(u, exact=True)
       (2*(x + 1), 1)
    """
    expr = sympify(expr)

    numer, denom = [], []

    # Classify each multiplicative factor of expr independently.
    for term in Mul.make_args(expr):
        # Only commutative powers (including exp) can be moved across the
        # fraction bar; non-commutative factors fall through to the numerator.
        if term.is_commutative and (term.is_Pow or isinstance(term, exp)):
            b, ex = term.as_base_exp()
            if ex.is_negative:
                if ex is S.NegativeOne:
                    # b**-1: the base itself is the denominator factor.
                    denom.append(b)
                elif exact:
                    # In exact mode only constant negative exponents are
                    # flipped; symbolic ones keep the term as written.
                    if ex.is_constant():
                        denom.append(Pow(b, -ex))
                    else:
                        numer.append(term)
                else:
                    denom.append(Pow(b, -ex))
            elif ex.is_positive:
                numer.append(term)
            elif not exact and ex.is_Mul:
                # Sign of the exponent is unknown but it is a product (e.g.
                # -y): let as_numer_denom decide; this may evaluate the term.
                n, d = term.as_numer_denom()
                if n != 1:
                    numer.append(n)
                denom.append(d)
            else:
                numer.append(term)
        elif term.is_Rational and not term.is_Integer:
            # Split a non-integer rational coefficient into p/q.
            if term.p != 1:
                numer.append(term.p)
            denom.append(term.q)
        else:
            numer.append(term)
    # With exact=True the products are built unevaluated.
    return Mul(*numer, evaluate=not exact), Mul(*denom, evaluate=not exact)
1114
1115
def numer(expr):
    """Return the numerator of ``expr`` as computed by :func:`fraction`."""
    top, _ = fraction(expr)
    return top
1118
1119
def denom(expr):
    """Return the denominator of ``expr`` as computed by :func:`fraction`."""
    _, bottom = fraction(expr)
    return bottom
1122
1123
def fraction_expand(expr, **hints):
    """Expand ``expr`` with the ``frac=True`` hint plus any extra ``hints``."""
    expander = expr.expand
    return expander(frac=True, **hints)
1126
1127
def numer_expand(expr, **hints):
    """Expand only the numerator of ``expr`` (``numer=True`` hint plus
    ``hints``) and divide it by the untouched denominator."""
    top, bottom = fraction(expr)
    expanded_top = top.expand(numer=True, **hints)
    return expanded_top / bottom
1131
1132
def denom_expand(expr, **hints):
    """Divide the untouched numerator of ``expr`` by its denominator
    expanded with the ``denom=True`` hint plus ``hints``."""
    top, bottom = fraction(expr)
    expanded_bottom = bottom.expand(denom=True, **hints)
    return top / expanded_bottom
1136
1137
# Verb-first aliases for the three expansion helpers defined above.
expand_numer, expand_denom, expand_fraction = (
    numer_expand, denom_expand, fraction_expand)
1141
1142
def split_surds(expr):
    """
    Separate a sum of surds into a part sharing a common squared factor
    and the remainder.

    ``expr`` should be a sum of terms whose squares are positive rationals.
    The squared surds are grouped with :func:`_split_gcd`.  When every
    squared surd lands in the gcd group (and there are at least two of
    them), the quotients by that gcd are split once more and the extra
    common factor is folded into ``g``.  Returns ``(g, a, b)`` where ``a``
    holds the grouped terms with ``sqrt(g)`` divided out and ``b`` holds
    all other terms, so that (for positive squared surds)
    ``expr == sqrt(g)*a + b``.

    Examples
    ========

    >>> from sympy import sqrt
    >>> from sympy.simplify.radsimp import split_surds
    >>> split_surds(3*sqrt(3) + sqrt(5)/7 + sqrt(6) + sqrt(10) + sqrt(15))
    (3, sqrt(2) + sqrt(5) + 3, sqrt(5)/7 + sqrt(10))
    """
    terms = sorted(expr.args, key=default_sort_key)
    pairs = [term.as_coeff_Mul() for term in terms]
    squares = sorted((rest**2 for _, rest in pairs if rest.is_Pow),
                     key=default_sort_key)
    g, group, leftover = _split_gcd(*squares)
    scale = g
    if len(group) >= 2 and not leftover:
        # Only a shared factor was found; divide it out and split again.
        reduced = [q for q in (s/g for s in group) if q != 1]
        scale = g*_split_gcd(*reduced)[0]
    in_group, others = [], []
    for coeff, rest in pairs:
        if rest.is_Pow and rest.exp == S.Half and rest.base in group:
            in_group.append(coeff*sqrt(rest.base/scale))
        else:
            others.append(coeff*rest)
    return scale, Add(*in_group), Add(*others)
1182
1183
1184def _split_gcd(*a):
1185    """
1186    Split the list of integers ``a`` into a list of integers, ``a1`` having
1187    ``g = gcd(a1)``, and a list ``a2`` whose elements are not divisible by
1188    ``g``.  Returns ``g, a1, a2``.
1189
1190    Examples
1191    ========
1192
1193    >>> from sympy.simplify.radsimp import _split_gcd
1194    >>> _split_gcd(55, 35, 22, 14, 77, 10)
1195    (5, [55, 35, 10], [22, 14, 77])
1196    """
1197    g = a[0]
1198    b1 = [g]
1199    b2 = []
1200    for x in a[1:]:
1201        g1 = gcd(g, x)
1202        if g1 == 1:
1203            b2.append(x)
1204        else:
1205            g = g1
1206            b1.append(x)
1207    return g, b1, b2
1208