from sympy import (symbols, sin, Matrix, Interval, Piecewise, Sum, lambdify,
                   Expr, sqrt)
from sympy.testing.pytest import raises

from sympy.printing.tensorflow import TensorflowPrinter
from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter


x, y, z = symbols("x,y,z")
i, a, b = symbols("i,a,b")
j, c, d = symbols("j,c,d")


def test_basic():
    assert lambdarepr(x*y) == "x*y"
    assert lambdarepr(x + y) in ["y + x", "x + y"]
    assert lambdarepr(x**y) == "x**y"


def test_matrix():
    # Print a Matrix whose element is printed differently by the
    # LambdaPrinter than by the StrPrinter: the Piecewise element must come
    # out in the LambdaPrinter's "((x))" form.
    p = Piecewise((x, True), evaluate=False)
    A = Matrix([p])
    assert lambdarepr(A) == "ImmutableDenseMatrix([[((x))]])"


def _check_piecewise_repr(expr, expected):
    """Assert ``lambdarepr(expr)`` is syntactically valid and equals *expected*.

    The printed code is first passed to ``eval`` as the body of
    ``lambda x: ...``; an unbalanced number of parentheses makes that raise
    ``SyntaxError``. The lambda is only built, never called, so names other
    than ``x`` appearing in the code are not resolved.
    """
    code = lambdarepr(expr)
    eval("lambda x: " + code)
    assert code == expected


def test_piecewise():
    # Every case goes through _check_piecewise_repr, which eval()s the
    # lambdarepr() output so a wrong number of parentheses is reported as a
    # SyntaxError before the string comparison.

    _check_piecewise_repr(
        Piecewise((x, True), evaluate=False),
        "((x))")

    _check_piecewise_repr(
        Piecewise((x, x < 0)),
        "((x) if (x < 0) else None)")

    _check_piecewise_repr(
        Piecewise(
            (1, x < 1),
            (2, x < 2),
            (0, True)
        ),
        "((1) if (x < 1) else (2) if (x < 2) else (0))")

    _check_piecewise_repr(
        Piecewise(
            (1, x < 1),
            (2, x < 2),
        ),
        "((1) if (x < 1) else (2) if (x < 2) else None)")

    _check_piecewise_repr(
        Piecewise(
            (x, x < 1),
            (x**2, Interval(3, 4, True, False).contains(x)),
            (0, True),
        ),
        "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))")

    _check_piecewise_repr(
        Piecewise(
            (x**2, x < 0),
            (x, x < 1),
            (2 - x, x >= 1),
            (0, True), evaluate=False
        ),
        "((x**2) if (x < 0) else (x) if (x < 1)"
        " else (2 - x) if (x >= 1) else (0))")

    _check_piecewise_repr(
        Piecewise(
            (x**2, x < 0),
            (x, x < 1),
            (2 - x, x >= 1), evaluate=False
        ),
        "((x**2) if (x < 0) else (x) if (x < 1)"
        " else (2 - x) if (x >= 1) else None)")

    _check_piecewise_repr(
        Piecewise(
            (1, x >= 1),
            (2, x >= 2),
            (3, x >= 3),
            (4, x >= 4),
            (5, x >= 5),
            (6, True)
        ),
        "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"
        " else (4) if (x >= 4) else (5) if (x >= 5) else (6))")

    _check_piecewise_repr(
        Piecewise(
            (1, x <= 1),
            (2, x <= 2),
            (3, x <= 3),
            (4, x <= 4),
            (5, x <= 5),
            (6, True)
        ),
        "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"
        " else (4) if (x <= 4) else (5) if (x <= 5) else (6))")

    _check_piecewise_repr(
        Piecewise(
            (1, x > 1),
            (2, x > 2),
            (3, x > 3),
            (4, x > 4),
            (5, x > 5),
            (6, True)
        ),
        "((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"
        " else (4) if (x > 4) else (5) if (x > 5) else (6))")

    _check_piecewise_repr(
        Piecewise(
            (1, x < 1),
            (2, x < 2),
            (3, x < 3),
            (4, x < 4),
            (5, x < 5),
            (6, True)
        ),
        "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"
        " else (4) if (x < 4) else (5) if (x < 5) else (6))")

    # Nested Piecewise: the inner one is wrapped in its own parentheses.
    _check_piecewise_repr(
        Piecewise(
            (Piecewise(
                (1, x > 0),
                (2, True)
            ), y > 0),
            (3, True)
        ),
        "((((1) if (x > 0) else (2))) if (y > 0) else (3))")


def test_sum__1():
    # Check the printed generator expression, then check that the lambdified
    # function evaluates to the same result as the symbolic sum.
    s = Sum(x ** i, (i, a, b))
    code = lambdarepr(s)
    assert code == "(builtins.sum(x**i for i in range(a, b+1)))"

    args = x, a, b
    f = lambdify(args, s)
    v = 2, 3, 8
    assert f(*v) == s.subs(zip(args, v)).doit()


def test_sum__2():
    s = Sum(i * x, (i, a, b))
    code = lambdarepr(s)
    assert code == "(builtins.sum(i*x for i in range(a, b+1)))"

    args = x, a, b
    f = lambdify(args, s)
    v = 2, 3, 8
    assert f(*v) == s.subs(zip(args, v)).doit()


def test_multiple_sums():
    # Two summation limits become two chained ``for`` clauses.
    s = Sum(i * x + j, (i, a, b), (j, c, d))

    code = lambdarepr(s)
    assert code == "(builtins.sum(i*x + j for i in range(a, b+1) for j in range(c, d+1)))"

    args = x, a, b, c, d
    f = lambdify(args, s)
    vals = 2, 3, 4, 5, 6
    f_ref = s.subs(zip(args, vals)).doit()
    f_res = f(*vals)
    assert f_res == f_ref


def test_sqrt():
    prntr = LambdaPrinter({'standard': 'python2'})
    assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)'
    assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1./2.)'
    prntr = LambdaPrinter({'standard': 'python3'})
    assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'


def test_settings():
    raises(TypeError, lambda: lambdarepr(sin(x), method="garbage"))


def test_numexpr():
    # test ITE rewrite as Piecewise
    from sympy.logic.boolalg import ITE
    expr = ITE(x > 0, True, False, evaluate=False)
    assert NumExprPrinter().doprint(expr) == \
        "evaluate('where((x > 0), True, False)', truediv=True)"


class CustomPrintedObject(Expr):
    """Expr returning a fixed marker string from each printer-specific hook.

    Used by test_printmethod to check that a printer dispatches to the
    object's own ``_<printmethod>`` hook.
    """

    def _lambdacode(self, printer):
        return 'lambda'

    def _tensorflowcode(self, printer):
        return 'tensorflow'

    def _numpycode(self, printer):
        return 'numpy'

    def _numexprcode(self, printer):
        return 'numexpr'

    def _mpmathcode(self, printer):
        return 'mpmath'


def test_printmethod():
    # In each case, printmethod is called to test
    # its working

    obj = CustomPrintedObject()
    assert LambdaPrinter().doprint(obj) == 'lambda'
    assert TensorflowPrinter().doprint(obj) == 'tensorflow'
    assert NumExprPrinter().doprint(obj) == "evaluate('numexpr', truediv=True)"

    assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \
        "evaluate('where((x >= 0), y, z)', truediv=True)"