1""" Generic Rules for SymPy
2
3This file assumes knowledge of Basic and little else.
4"""
5from sympy.utilities.iterables import sift
6from .util import new
7
8# Functions that create rules
9
def rm_id(isid, new=new):
    """ Create a rule to remove identities.

    isid - fn :: x -> Bool  --- whether or not this element is an identity.
    new  - constructor used to build the result from a class and its args.

    The returned rule drops every argument for which ``isid`` is truthy.
    If every argument is an identity, the first one is kept so that the
    expression is not left empty. If no argument is an identity, the
    original expression object is returned unchanged.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(1, 0, 2))
    Basic(1, 2)
    >>> remove_zeros(Basic(0, 0)) # If only identities then we keep one
    Basic(0)

    See Also
    ========

    unpack
    """
    def ident_remove(expr):
        """ Remove identities """
        # Normalise each predicate result to a real bool. Counting with
        # ``sum`` only worked when ``isid`` returned exactly 0/1; a truthy
        # value such as 2 would miscount.
        ids = [bool(isid(arg)) for arg in expr.args]
        if not any(ids):            # No identities. Common case
            return expr
        elif not all(ids):          # there is at least one non-identity
            return new(expr.__class__,
                       *[arg for arg, x in zip(expr.args, ids) if not x])
        else:                       # only identities: keep the first one
            return new(expr.__class__, expr.args[0])

    return ident_remove
41
def glom(key, count, combine, new=new):
    """ Create a rule to conglomerate identical args.

    key     - fn :: arg -> k          --- args with equal keys are merged.
    count   - fn :: arg -> c          --- the amount one arg contributes.
    combine - fn :: (c, k) -> arg     --- rebuild one arg from a total.
    new     - constructor used to build the result from a class and its
              args (same role as in ``rm_id``, ``sort`` and ``flatten``).

    The args of an expression are grouped by ``key``. The ``count`` values
    of each group are summed, and ``combine`` turns each (total, key) pair
    back into a single arg. A new expression of the same type is built only
    when the set of resulting args differs from the set of original args.
    Otherwise the original expression is returned.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x

    >>> key     = lambda x: x.as_coeff_Mul()[1]
    >>> count   = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)

    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    The helper functions used above behave as follows:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        groups = sift(expr.args, key)
        totals = {k: sum(map(count, args)) for k, args in groups.items()}
        newargs = [combine(total, k) for k, total in totals.items()]
        # Compare as sets so that a mere reordering of otherwise unchanged
        # args does not trigger a rebuild.
        if set(newargs) != set(expr.args):
            return new(type(expr), *newargs)
        else:
            return expr

    return conglomerate
80
def sort(key, new=new):
    """ Build a rule that orders an expression's args by a key function.

    key - fn :: arg -> comparable  --- passed as ``key=`` to ``sorted``.
    new - constructor used to build the result from a class and its args.

    The rule always builds a fresh expression of the same class, with the
    args in ascending ``key`` order.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(3, 1, 2))
    Basic(1, 2, 3)
    """

    def sort_rl(expr):
        ordered_args = sorted(expr.args, key=key)
        cls = expr.__class__
        return new(cls, *ordered_args)

    return sort_rl
97
def distribute(A, B):
    """ Turns an A containing Bs into a B of As

    where A, B are container types.

    Only the first argument of the expression that is an instance of ``B``
    is distributed. For each of that B's args, an ``A`` is built with the
    arg in place of the B, keeping the other args of the expression in
    their original positions. Those As are wrapped in a ``B``. Both
    constructors are called directly, so any evaluation they perform
    applies. If no argument is a ``B``, the expression is returned
    unchanged.

    Examples
    ========

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """

    def distribute_rl(expr):
        args = expr.args
        for i, arg in enumerate(args):
            if isinstance(arg, B):
                head, tail = args[:i], args[i + 1:]
                # One A per term of the B, with that term substituted at
                # the B's position.
                return B(*[A(*(head + (term,) + tail)) for term in arg.args])
        return expr
    return distribute_rl
121
def subs(a, b):
    """ Create a rule that replaces an expression exactly.

    The returned rule maps ``expr`` to ``b`` when ``expr == a``. Any other
    expression is returned as-is. Sub-expressions are not searched.
    """
    def subs_rl(expr):
        return b if expr == a else expr
    return subs_rl
130
131# Functions that are rules
132
def unpack(expr):
    """ Rule to unpack singleton args

    If the expression has exactly one argument, that argument is returned.
    Otherwise the expression itself is returned.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic
    >>> unpack(Basic(2))
    2
    """
    args = expr.args
    return args[0] if len(args) == 1 else expr
145
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only direct children whose class equals the expression's own class are
    spliced in, and only one level is flattened. The result is always a
    fresh expression built with ``new``.
    """
    cls = expr.__class__
    flat = []
    for child in expr.args:
        flat.extend(child.args if child.__class__ == cls else (child,))
    return new(cls, *flat)
156
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Atoms are returned as-is. Every other node is reconstructed bottom-up
    by calling ``expr.func`` on its rebuilt args. Re-running the
    constructors forces canonicalization and removes ugliness introduced by
    the use of Basic.__new__
    """
    if not expr.is_Atom:
        return expr.func(*[rebuild(sub) for sub in expr.args])
    return expr
171