""" Generic Rules for SymPy

This file assumes knowledge of Basic and little else.
"""
from sympy.utilities.iterables import sift
from .util import new

# Functions that create rules

def rm_id(isid, new=new):
    """ Create a rule to remove identities.

    isid - fn :: x -> Bool --- whether or not this element is an identity.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(1, 0, 2))
    Basic(1, 2)
    >>> remove_zeros(Basic(0, 0)) # If only identities then we keep one
    Basic(0)

    See Also
    ========

    unpack
    """
    def ident_remove(expr):
        """ Remove identities """
        ids = list(map(isid, expr.args))
        # Count the identities once; the total selects the branch below.
        num_ids = sum(ids)
        if num_ids == 0:  # No identities. Common case
            return expr
        elif num_ids != len(ids):  # there is at least one non-identity
            return new(expr.__class__,
                       *[arg for arg, x in zip(expr.args, ids) if not x])
        else:
            # Every arg is an identity: keep only the first one instead of
            # building a container with no args at all.
            return new(expr.__class__, expr.args[0])

    return ident_remove

def glom(key, count, combine):
    """ Create a rule to conglomerate identical args.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x

    >>> key = lambda x: x.as_coeff_Mul()[1]
    >>> count = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)

    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How do key, count and combine work together? ``key`` picks the part
    that identifies a group, ``count`` picks how much of it an arg
    contributes, and ``combine`` rebuilds one arg from a total and a key:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        # Group the args by key; with the example above the args
        # x, -x, 3*x share key x and the args 2, 3 share key 1.
        groups = sift(expr.args, key)
        # Total the count of every arg within each group.
        counts = {k: sum(map(count, args)) for k, args in groups.items()}
        # Rebuild a single arg per group from its key and total count.
        newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
        if set(newargs) != set(expr.args):
            return new(type(expr), *newargs)
        else:
            # The args are unchanged (as a set): return the original object.
            return expr

    return conglomerate

def sort(key, new=new):
    """ Create a rule to sort by a key function.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(3, 1, 2))
    Basic(1, 2, 3)
    """

    def sort_rl(expr):
        """ Rebuild ``expr`` with its args sorted by ``key`` """
        return new(expr.__class__, *sorted(expr.args, key=key))
    return sort_rl

def distribute(A, B):
    """ Turns an A containing Bs into a B of As

    where A, B are container types

    Only the first arg of type B is distributed over; if no arg is a B
    the expression is returned unchanged.

    Examples
    ========

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """

    def distribute_rl(expr):
        for i, arg in enumerate(expr.args):
            if isinstance(arg, B):
                head, tail = expr.args[:i], expr.args[i+1:]
                # Put each of the B's args in the B's position, wrap each
                # result in an A, and collect them all in a new B.
                return B(*[A(*(head + (sub,) + tail)) for sub in arg.args])
        return expr
    return distribute_rl

def subs(a, b):
    """ Replace expressions exactly

    Create a rule that returns ``b`` when the expression it is given
    compares equal (``==``) to ``a``, and otherwise returns the expression
    unchanged. Only the whole expression is compared; the rule does not
    descend into its args.
    """
    def subs_rl(expr):
        if expr == a:
            return b
        else:
            return expr
    return subs_rl

# Functions that are rules

def unpack(expr):
    """ Rule to unpack singleton args

    Examples
    ========

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic
    >>> unpack(Basic(2))
    2
    """
    if len(expr.args) == 1:
        return expr.args[0]
    else:
        return expr

def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only args whose class is exactly the class of ``expr`` are spliced in
    (subclasses are left alone), and only one level is flattened: nested
    args of those args are not examined.
    """
    cls = expr.__class__
    args = []
    for arg in expr.args:
        if arg.__class__ == cls:
            args.extend(arg.args)
        else:
            args.append(arg)
    return new(cls, *args)

def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    This function recursively calls constructors in the expression tree.
    This forces canonicalization and removes ugliness introduced by the use of
    Basic.__new__

    Atoms are returned as they are; every other node is rebuilt with
    ``expr.func`` from its recursively rebuilt args.
    """
    if expr.is_Atom:
        return expr
    else:
        return expr.func(*list(map(rebuild, expr.args)))