"""Tests for SymPy's ``subs`` substitution machinery.

Covers plain symbol/number replacement, dict/iterable/set substitution
arguments (and how their ordering affects results), ``simultaneous=True``,
commutative and noncommutative products, ``Wild`` symbols, derivatives and
``Subs`` objects, function-class replacement, and a collection of
regression tests for specific issues.
"""
from sympy import (
    Symbol, Wild, sin, cos, exp, sqrt, pi, Function, Derivative,
    Integer, Eq, symbols, Add, I, Float, log, Rational,
    Lambda, atan2, cse, cot, tan, S, Tuple, Basic, Dict,
    Piecewise, oo, Mul, factor, nsimplify, zoo, Subs, RootOf,
    AccumBounds, Matrix, zeros, ZeroMatrix)
from sympy.core.basic import _aresame
from sympy.testing.pytest import XFAIL
from sympy.abc import a, x, y, z, t


def test_subs():
    n3 = Rational(3)
    e = x
    e = e.subs(x, n3)
    assert e == Rational(3)

    e = 2*x
    assert e == 2*x
    e = e.subs(x, n3)
    assert e == Rational(6)


def test_subs_Matrix():
    z = zeros(2)
    z1 = ZeroMatrix(2, 2)
    assert (x*y).subs({x:z, y:0}) in [z, z1]
    assert (x*y).subs({y:z, x:0}) == 0
    assert (x*y).subs({y:z, x:0}, simultaneous=True) in [z, z1]
    assert (x + y).subs({x: z, y: z}, simultaneous=True) in [z, z1]
    assert (x + y).subs({x: z, y: z}) in [z, z1]

    # Issue #15528
    assert Mul(Matrix([[3]]), x).subs(x, 2.0) == Matrix([[6.0]])
    # Does not raise a TypeError, see comment on the MatAdd postprocessor
    assert Add(Matrix([[3]]), x).subs(x, 2.0) == Add(Matrix([[3]]), 2.0)


def test_subs_AccumBounds():
    e = x
    e = e.subs(x, AccumBounds(1, 3))
    assert e == AccumBounds(1, 3)

    e = 2*x
    e = e.subs(x, AccumBounds(1, 3))
    assert e == AccumBounds(2, 6)

    e = x + x**2
    e = e.subs(x, AccumBounds(-1, 1))
    assert e == AccumBounds(-1, 2)


def test_trigonometric():
    n3 = Rational(3)
    e = (sin(x)**2).diff(x)
    assert e == 2*sin(x)*cos(x)
    e = e.subs(x, n3)
    assert e == 2*cos(n3)*sin(n3)

    e = (sin(x)**2).diff(x)
    assert e == 2*sin(x)*cos(x)
    e = e.subs(sin(x), cos(x))
    assert e == 2*cos(x)**2

    assert exp(pi).subs(exp, sin) == 0
    assert cos(exp(pi)).subs(exp, sin) == 1

    i = Symbol('i', integer=True)
    zoo = S.ComplexInfinity
    assert tan(x).subs(x, pi/2) is zoo
    assert cot(x).subs(x, pi) is zoo
    assert cot(i*x).subs(x, pi) is zoo
    assert tan(i*x).subs(x, pi/2) == tan(i*pi/2)
    assert tan(i*x).subs(x, pi/2).subs(i, 1) is zoo
    o = Symbol('o', odd=True)
    assert tan(o*x).subs(x, pi/2) == tan(o*pi/2)


def test_powers():
    assert sqrt(1 - sqrt(x)).subs(x, 4) == I
    assert (sqrt(1 - x**2)**3).subs(x, 2) == - 3*I*sqrt(3)
    assert (x**Rational(1, 3)).subs(x, 27) == 3
    assert (x**Rational(1, 3)).subs(x, -27) == 3*(-1)**Rational(1, 3)
    assert ((-x)**Rational(1, 3)).subs(x, 27) == 3*(-1)**Rational(1, 3)
    n = Symbol('n', negative=True)
    assert (x**n).subs(x, 0) is S.ComplexInfinity
    assert exp(-1).subs(S.Exp1, 0) is S.ComplexInfinity
    assert (x**(4.0*y)).subs(x**(2.0*y), n) == n**2.0
    assert (2**(x + 2)).subs(2, 3) == 3**(x + 3)


def test_logexppow():   # no eval()
    x = Symbol('x', real=True)
    w = Symbol('w')
    e = (3**(1 + x) + 2**(1 + x))/(3**x + 2**x)
    assert e.subs(2**x, w) != e
    assert e.subs(exp(x*log(Rational(2))), w) != e


def test_bug():
    x1 = Symbol('x1')
    x2 = Symbol('x2')
    y = x1*x2
    assert y.subs(x1, Float(3.0)) == Float(3.0)*x2


def test_subbug1():
    # see that they don't fail
    (x**x).subs(x, 1)
    (x**x).subs(x, 1.0)


def test_subbug2():
    # Ensure this does not cause infinite recursion
    assert Float(7.7).epsilon_eq(abs(x).subs(x, -7.7))


def test_dict_set():
    a, b, c = map(Wild, 'abc')

    f = 3*cos(4*x)
    r = f.match(a*cos(b*x))
    assert r == {a: 3, b: 4}
    e = a/b*sin(b*x)
    assert e.subs(r) == r[a]/r[b]*sin(r[b]*x)
    assert e.subs(r) == 3*sin(4*x) / 4
    # A set of (old, new) pairs is accepted as well as a dict.
    s = set(r.items())
    assert e.subs(s) == r[a]/r[b]*sin(r[b]*x)
    assert e.subs(s) == 3*sin(4*x) / 4

    # NOTE(review): these repeat the first two dict-based assertions
    # verbatim; presumably they check that the set-based calls above left
    # `r` and `e` unchanged -- confirm intent.
    assert e.subs(r) == r[a]/r[b]*sin(r[b]*x)
    assert e.subs(r) == 3*sin(4*x) / 4
    assert x.subs(Dict((x, 1))) == 1


def test_dict_ambigous():   # see issue 3566
    f = x*exp(x)
    g = z*exp(z)

    df = {x: y, exp(x): y}
    dg = {z: y, exp(z): y}

    assert f.subs(df) == y**2
    assert g.subs(dg) == y**2

    # and this is how order can affect the result
    assert f.subs(x, y).subs(exp(x), y) == y*exp(y)
    assert f.subs(exp(x), y).subs(x, y) == y**2

    # length of args and count_ops are the same so
    # default_sort_key resolves ordering...if one
    # doesn't want this result then an unordered
    # sequence should not be used.
    e = 1 + x*y
    assert e.subs({x: y, y: 2}) == 5
    # here, there are no obviously clashing keys or values
    # but the results depend on the order
    assert exp(x/2 + y).subs({exp(y + 1): 2, x: 2}) == exp(y + 1)


def test_deriv_sub_bug3():
    f = Function('f')
    pat = Derivative(f(x), x, x)
    assert pat.subs(y, y**2) == Derivative(f(x), x, x)
    assert pat.subs(y, y**2) != Derivative(f(x), x)


def test_equality_subs1():
    f = Function('f')
    eq = Eq(f(x)**2, x)
    res = Eq(Integer(16), x)
    assert eq.subs(f(x), 4) == res


def test_equality_subs2():
    f = Function('f')
    eq = Eq(f(x)**2, 16)
    assert bool(eq.subs(f(x), 3)) is False
    assert bool(eq.subs(f(x), 4)) is True


def test_issue_3742():
    e = sqrt(x)*exp(y)
    assert e.subs(sqrt(x), 1) == exp(y)


def test_subs_dict1():
    assert (1 + x*y).subs(x, pi) == 1 + pi*y
    assert (1 + x*y).subs({x: pi, y: 2}) == 1 + 2*pi

    c2, c3, q1p, q2p, c1, s1, s2, s3 = symbols('c2 c3 q1p q2p c1 s1 s2 s3')
    test = (c2**2*q2p*c3 + c1**2*s2**2*q2p*c3 + s1**2*s2**2*q2p*c3
            - c1**2*q1p*c2*s3 - s1**2*q1p*c2*s3)
    # NOTE(review): c3 only appears to the first power in `test`, so the
    # c3**3 key never matches (the expected value keeps c3 untouched);
    # possibly c3**2 was intended -- confirm.
    assert (test.subs({c1**2: 1 - s1**2, c2**2: 1 - s2**2, c3**3: 1 - s3**2})
        == c3*q2p*(1 - s2**2) + c3*q2p*s2**2*(1 - s1**2)
        - c2*q1p*s3*(1 - s1**2) + c3*q2p*s1**2*s2**2 - c2*q1p*s3*s1**2)


def test_mul():
    x, y, z, a, b, c = symbols('x y z a b c')
    A, B, C = symbols('A B C', commutative=0)
    assert (x*y*z).subs(z*x, y) == y**2
    assert (z*x).subs(1/x, z) == 1
    assert (x*y/z).subs(1/z, a) == a*x*y
    assert (x*y/z).subs(x/z, a) == a*y
    assert (x*y/z).subs(y/z, a) == a*x
    assert (x*y/z).subs(x/z, 1/a) == y/a
    assert (x*y/z).subs(x, 1/a) == y/(z*a)
    assert (2*x*y).subs(5*x*y, z) != z*Rational(2, 5)
    assert (x*y*A).subs(x*y, a) == a*A
    assert (x**2*y**(x*Rational(3, 2))).subs(x*y**(x/2), 2) == 4*y**(x/2)
    assert (x*exp(x*2)).subs(x*exp(x), 2) == 2*exp(x)
    assert ((x**(2*y))**3).subs(x**y, 2) == 64
    assert (x*A*B).subs(x*A, y) == y*B
    assert (x*y*(1 + x)*(1 + x*y)).subs(x*y, 2) == 6*(1 + x)
    # NOTE(review): this only checks that the result is truthy; it does
    # not pin the substituted value.
    assert ((1 + A*B)*A*B).subs(A*B, x*A*B)
    assert (x*a/z).subs(x/z, A) == a*A
    assert (x**3*A).subs(x**2*A, a) == a*x
    assert (x**2*A*B).subs(x**2*B, a) == a*A
    assert (x**2*A*B).subs(x**2*A, a) == a*B
    assert (b*A**3/(a**3*c**3)).subs(a**4*c**3*A**3/b**4, z) == \
        b*A**3/(a**3*c**3)
    assert (6*x).subs(2*x, y) == 3*y
    assert (y*exp(x*Rational(3, 2))).subs(y*exp(x), 2) == 2*exp(x/2)
    assert (A**2*B*A**2*B*A**2).subs(A*B*A, C) == A*C**2*A
    assert (x*A**3).subs(x*A, y) == y*A**2
    assert (x**2*A**3).subs(x*A, y) == y**2*A
    assert (x*A**3).subs(x*A, B) == B*A**2
    assert (x*A*B*A*exp(x*A*B)).subs(x*A, B) == B**2*A*exp(B*B)
    assert (x**2*A*B*A*exp(x*A*B)).subs(x*A, B) == B**3*exp(B**2)
    assert (x**3*A*exp(x*A*B)*A*exp(x*A*B)).subs(x*A, B) == \
        x*B*exp(B**2)*B*exp(B**2)
    assert (x*A*B*C*A*B).subs(x*A*B, C) == C**2*A*B
    assert (-I*a*b).subs(a*b, 2) == -2*I

    # issue 6361
    assert (-8*I*a).subs(-2*a, 1) == 4*I
    assert (-I*a).subs(-a, 1) == I

    # issue 6441
    assert (4*x**2).subs(2*x, y) == y**2
    assert (2*4*x**2).subs(2*x, y) == 2*y**2
    assert (-x**3/9).subs(-x/3, z) == -z**2*x
    assert (-x**3/9).subs(x/3, z) == -z**2*x
    assert (-2*x**3/9).subs(x/3, z) == -2*x*z**2
    assert (-2*x**3/9).subs(-x/3, z) == -2*x*z**2
    assert (-2*x**3/9).subs(-2*x, z) == z*x**2/9
    assert (-2*x**3/9).subs(2*x, z) == -z*x**2/9
    assert (2*(3*x/5/7)**2).subs(3*x/5, z) == 2*(Rational(1, 7))**2*z**2
    assert (4*x).subs(-2*x, z) == 4*x  # try to keep subs literal


def test_subs_simple():
    a = symbols('a', commutative=True)
    x = symbols('x', commutative=False)

    assert (2*a).subs(1, 3) == 2*a
    assert (2*a).subs(2, 3) == 3*a
    assert (2*a).subs(a, 3) == 6
    assert sin(2).subs(1, 3) == sin(2)
    assert sin(2).subs(2, 3) == sin(3)
    assert sin(a).subs(a, 3) == sin(3)

    assert (2*x).subs(1, 3) == 2*x
    assert (2*x).subs(2, 3) == 3*x
    assert (2*x).subs(x, 3) == 6
    assert sin(x).subs(x, 3) == sin(3)


def test_subs_constants():
    a, b = symbols('a b', commutative=True)
    x, y = symbols('x y', commutative=False)

    assert (a*b).subs(2*a, 1) == a*b
    assert (1.5*a*b).subs(a, 1) == 1.5*b
    assert (2*a*b).subs(2*a, 1) == b
    assert (2*a*b).subs(4*a, 1) == 2*a*b

    assert (x*y).subs(2*x, 1) == x*y
    assert (1.5*x*y).subs(x, 1) == 1.5*y
    assert (2*x*y).subs(2*x, 1) == y
    assert (2*x*y).subs(4*x, 1) == 2*x*y


def test_subs_commutative():
    a, b, c, d, K = symbols('a b c d K', commutative=True)

    assert (a*b).subs(a*b, K) == K
    assert (a*b*a*b).subs(a*b, K) == K**2
    assert (a*a*b*b).subs(a*b, K) == K**2
    assert (a*b*c*d).subs(a*b*c, K) == d*K
    assert (a*b**c).subs(a, K) == K*b**c
    assert (a*b**c).subs(b, K) == a*K**c
    assert (a*b**c).subs(c, K) == a*b**K
    assert (a*b*c*b*a).subs(a*b, K) == c*K**2
    assert (a**3*b**2*a).subs(a*b, K) == a**2*K**2


def test_subs_noncommutative():
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    alpha = symbols('alpha', commutative=True)
    someint = symbols('someint', commutative=True, integer=True)

    assert (x*y).subs(x*y, L) == L
    assert (w*y*x).subs(x*y, L) == w*y*x
    assert (w*x*y*z).subs(x*y, L) == w*L*z
    assert (x*y*x*y).subs(x*y, L) == L**2
    assert (x*x*y).subs(x*y, L) == x*L
    assert (x*x*y*y).subs(x*y, L) == x*L*y
    assert (w*x*y).subs(x*y*z, L) == w*x*y
    assert (x*y**z).subs(x, L) == L*y**z
    assert (x*y**z).subs(y, L) == x*L**z
    assert (x*y**z).subs(z, L) == x*y**L
    assert (w*x*y*z*x*y).subs(x*y*z, L) == w*L*x*y
    assert (w*x*y*y*w*x*x*y*x*y*y*x*y).subs(x*y, L) == w*L*y*w*x*L**2*y*L

    # Check fractional power substitutions. It should not do
    # substitutions that choose a value for noncommutative log,
    # or inverses that don't already appear in the expressions.
    assert (x*x*x).subs(x*x, L) == L*x
    assert (x*x*x*y*x*x*x*x).subs(x*x, L) == L*x*y*L**2
    for p in range(1, 5):
        for k in range(10):
            assert (y * x**k).subs(x**p, L) == y * L**(k//p) * x**(k % p)
    assert (x**Rational(3, 2)).subs(x**S.Half, L) == x**Rational(3, 2)
    assert (x**S.Half).subs(x**S.Half, L) == L
    assert (x**Rational(-1, 2)).subs(x**S.Half, L) == x**Rational(-1, 2)
    assert (x**Rational(-1, 2)).subs(x**Rational(-1, 2), L) == L

    assert (x**(2*someint)).subs(x**someint, L) == L**2
    assert (x**(2*someint + 3)).subs(x**someint, L) == L**2*x**3
    assert (x**(3*someint + 3)).subs(x**someint, L) == L**3*x**3
    assert (x**(3*someint)).subs(x**(2*someint), L) == L * x**someint
    assert (x**(4*someint)).subs(x**(2*someint), L) == L**2
    assert (x**(4*someint + 1)).subs(x**(2*someint), L) == L**2 * x
    assert (x**(4*someint)).subs(x**(3*someint), L) == L * x**someint
    assert (x**(4*someint + 1)).subs(x**(3*someint), L) == L * x**(someint + 1)

    assert (x**(2*alpha)).subs(x**alpha, L) == x**(2*alpha)
    assert (x**(2*alpha + 2)).subs(x**2, L) == x**(2*alpha + 2)
    assert ((2*z)**alpha).subs(z**alpha, y) == (2*z)**alpha
    assert (x**(2*someint*alpha)).subs(x**someint, L) == x**(2*someint*alpha)
    assert (x**(2*someint + alpha)).subs(x**someint, L) == x**(2*someint + alpha)

    # This could in principle be substituted, but is not currently
    # because it requires recognizing that someint**2 is divisible by
    # someint.
    assert (x**(someint**2 + 3)).subs(x**someint, L) == x**(someint**2 + 3)

    # alpha**z := exp(log(alpha) z) is usually well-defined
    assert (4**z).subs(2**z, y) == y**2

    # Negative powers
    assert (x**(-1)).subs(x**3, L) == x**(-1)
    assert (x**(-2)).subs(x**3, L) == x**(-2)
    assert (x**(-3)).subs(x**3, L) == L**(-1)
    assert (x**(-4)).subs(x**3, L) == L**(-1) * x**(-1)
    assert (x**(-5)).subs(x**3, L) == L**(-1) * x**(-2)

    assert (x**(-1)).subs(x**(-3), L) == x**(-1)
    assert (x**(-2)).subs(x**(-3), L) == x**(-2)
    assert (x**(-3)).subs(x**(-3), L) == L
    assert (x**(-4)).subs(x**(-3), L) == L * x**(-1)
    assert (x**(-5)).subs(x**(-3), L) == L * x**(-2)

    assert (x**1).subs(x**(-3), L) == x
    assert (x**2).subs(x**(-3), L) == x**2
    assert (x**3).subs(x**(-3), L) == L**(-1)
    assert (x**4).subs(x**(-3), L) == L**(-1) * x
    assert (x**5).subs(x**(-3), L) == L**(-1) * x**2


def test_subs_basic_funcs():
    a, b, c, d, K = symbols('a b c d K', commutative=True)
    w, x, y, z, L = symbols('w x y z L', commutative=False)

    assert (x + y).subs(x + y, L) == L
    assert (x - y).subs(x - y, L) == L
    assert (x/y).subs(x, L) == L/y
    assert (x**y).subs(x, L) == L**y
    assert (x**y).subs(y, L) == x**L
    assert ((a - c)/b).subs(b, K) == (a - c)/K
    assert (exp(x*y - z)).subs(x*y, L) == exp(L - z)
    assert (a*exp(x*y - w*z) + b*exp(x*y + w*z)).subs(z, 0) == \
        a*exp(x*y) + b*exp(x*y)
    assert ((a - b)/(c*d - a*b)).subs(c*d - a*b, K) == (a - b)/K
    assert (w*exp(a*b - c)*x*y/4).subs(x*y, L) == w*exp(a*b - c)*L/4


def test_subs_wild():
    R, S, T, U = symbols('R S T U', cls=Wild)

    assert (R*S).subs(R*S, T) == T
    assert (S*R).subs(R*S, T) == T
    assert (R + S).subs(R + S, T) == T
    assert (R**S).subs(R, T) == T**S
    assert (R**S).subs(S, T) == R**T
    assert (R*S**T).subs(R, U) == U*S**T
    assert (R*S**T).subs(S, U) == R*U**T
    assert (R*S**T).subs(T, U) == R*S**U


def test_subs_mixed():
    a, b, c, d, K = symbols('a b c d K', commutative=True)
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    R, S, T, U = symbols('R S T U', cls=Wild)

    assert (a*x*y).subs(x*y, L) == a*L
    assert (a*b*x*y*x).subs(x*y, L) == a*b*L*x
    assert (R*x*y*exp(x*y)).subs(x*y, L) == R*L*exp(L)
    assert (a*x*y*y*x - x*y*z*exp(a*b)).subs(x*y, L) == a*L*y*x - L*z*exp(a*b)
    e = c*y*x*y*x**(R*S - a*b) - T*(a*R*b*S)
    assert e.subs(x*y, L).subs(a*b, K).subs(R*S, U) == \
        c*y*L*x**(U - K) - T*(U*K)


def test_division():
    a, b, c = symbols('a b c', commutative=True)
    x, y, z = symbols('x y z', commutative=True)

    assert (1/a).subs(a, c) == 1/c
    assert (1/a**2).subs(a, c) == 1/c**2
    assert (1/a**2).subs(a, -2) == Rational(1, 4)
    assert (-(1/a**2)).subs(a, -2) == Rational(-1, 4)

    assert (1/x).subs(x, z) == 1/z
    assert (1/x**2).subs(x, z) == 1/z**2
    assert (1/x**2).subs(x, -2) == Rational(1, 4)
    assert (-(1/x**2)).subs(x, -2) == Rational(-1, 4)

    # issue 5360
    assert (1/x).subs(x, 0) == 1/S.Zero


def test_add():
    a, b, c, d, x, y, t = symbols('a b c d x y t')

    assert (a**2 - b - c).subs(a**2 - b, d) in [d - c, a**2 - b - c]
    assert (a**2 - c).subs(a**2 - c, d) == d
    assert (a**2 - b - c).subs(a**2 - c, d) in [d - b, a**2 - b - c]
    assert (a**2 - x - c).subs(a**2 - c, d) in [d - x, a**2 - x - c]
    assert (a**2 - b - sqrt(a)).subs(a**2 - sqrt(a), c) == c - b
    assert (a + b + exp(a + b)).subs(a + b, c) == c + exp(c)
    assert (c + b + exp(c + b)).subs(c + b, a) == a + exp(a)
    assert (a + b + c + d).subs(b + c, x) == a + d + x
    assert (a + b + c + d).subs(-b - c, x) == a + d - x
    assert ((x + 1)*y).subs(x + 1, t) == t*y
    assert ((-x - 1)*y).subs(x + 1, t) == -t*y
    assert ((x - 1)*y).subs(x + 1, t) == y*(t - 2)
    assert ((-x + 1)*y).subs(x + 1, t) == y*(-t + 2)

    # this should work every time:
    e = a**2 - b - c
    assert e.subs(Add(*e.args[:2]), d) == d + e.args[2]
    assert e.subs(a**2 - c, d) == d - b

    # the fallback should recognize when a change has
    # been made; while .1 == Rational(1, 10) they are not the same
    # and the change should be made
    assert (0.1 + a).subs(0.1, Rational(1, 10)) == Rational(1, 10) + a

    e = (-x*(-y + 1) - y*(y - 1))
    ans = (-x*(x) - y*(-x)).expand()
    assert e.subs(-y + 1, x) == ans

    # Test issue 18747
    assert (exp(x) + cos(x)).subs(x, oo) == oo
    assert Add(*[AccumBounds(-1, 1), oo]) == oo
    assert Add(*[oo, AccumBounds(-1, 1)]) == oo


def test_subs_issue_4009():
    assert (I*Symbol('a')).subs(1, 2) == I*Symbol('a')


def test_functions_subs():
    f, g = symbols('f g', cls=Function)
    l = Lambda((x, y), sin(x) + y)
    assert (g(y, x) + cos(x)).subs(g, l) == sin(y) + x + cos(x)
    assert (f(x)**2).subs(f, sin) == sin(x)**2
    assert (f(x, y)).subs(f, log) == log(x, y)
    assert (f(x, y)).subs(f, sin) == f(x, y)
    assert (sin(x) + atan2(x, y)).subs([[atan2, f], [sin, g]]) == \
        f(x, y) + g(x)
    assert (g(f(x + y, x))).subs([[f, l], [g, exp]]) == exp(x + sin(x + y))


def test_derivative_subs():
    f = Function('f')
    g = Function('g')
    assert Derivative(f(x), x).subs(f(x), y) != 0
    # need xreplace to put the function back, see #13803
    assert Derivative(f(x), x).subs(f(x), y).xreplace({y: f(x)}) == \
        Derivative(f(x), x)
    # issues 5085, 5037
    assert cse(Derivative(f(x), x) + f(x))[1][0].has(Derivative)
    assert cse(Derivative(f(x, y), x) +
               Derivative(f(x, y), y))[1][0].has(Derivative)
    eq = Derivative(g(x), g(x))
    assert eq.subs(g, f) == Derivative(f(x), f(x))
    assert eq.subs(g(x), f(x)) == Derivative(f(x), f(x))
    assert eq.subs(g, cos) == Subs(Derivative(y, y), y, cos(x))


def test_derivative_subs2():
    f_func, g_func = symbols('f g', cls=Function)
    f, g = f_func(x, y, z), g_func(x, y, z)
    assert Derivative(f, x, y).subs(Derivative(f, x, y), g) == g
    assert Derivative(f, y, x).subs(Derivative(f, x, y), g) == g
    assert Derivative(f, x, y).subs(Derivative(f, x), g) == Derivative(g, y)
    assert Derivative(f, x, y).subs(Derivative(f, y), g) == Derivative(g, x)
    assert (Derivative(f, x, y, z).subs(
                Derivative(f, x, z), g) == Derivative(g, y))
    assert (Derivative(f, x, y, z).subs(
                Derivative(f, z, y), g) == Derivative(g, x))
    assert (Derivative(f, x, y, z).subs(
                Derivative(f, z, y, x), g) == g)

    # Issue 9135
    assert (Derivative(f, x, x, y).subs(
                Derivative(f, y, y), g) == Derivative(f, x, x, y))
    assert (Derivative(f, x, y, y, z).subs(
                Derivative(f, x, y, y, y), g) == Derivative(f, x, y, y, z))

    assert Derivative(f, x, y).subs(Derivative(f_func(x), x, y), g) == Derivative(f, x, y)


def test_derivative_subs3():
    dex = Derivative(exp(x), x)
    assert Derivative(dex, x).subs(dex, exp(x)) == dex
    assert dex.subs(exp(x), dex) == Derivative(exp(x), x, x)


def test_issue_5284():
    A, B = symbols('A B', commutative=False)
    assert (x*A).subs(x**2*A, B) == x*A
    assert (A**2).subs(A**3, B) == A**2
    assert (A**6).subs(A**3, B) == B**2


def test_subs_iter():
    assert x.subs(reversed([[x, y]])) == y
    it = iter([[x, y]])
    assert x.subs(it) == y
    assert x.subs(Tuple((x, y))) == y


def test_subs_dict():
    a, b, c, d, e = symbols('a b c d e')

    assert (2*x + y + z).subs(dict(x=1, y=2)) == 4 + z

    l = [(sin(x), 2), (x, 1)]
    assert (sin(x)).subs(l) == \
           (sin(x)).subs(dict(l)) == 2
    assert sin(x).subs(reversed(l)) == sin(1)

    expr = sin(2*x) + sqrt(sin(2*x))*cos(2*x)*sin(exp(x)*x)
    reps = dict([
               (sin(2*x), c),
               (sqrt(sin(2*x)), a),
               (cos(2*x), b),
               (exp(x), e),
               (x, d),
    ])
    assert expr.subs(reps) == c + a*b*sin(d*e)

    l = [(x, 3), (y, x**2)]
    assert (x + y).subs(l) == 3 + x**2
    assert (x + y).subs(reversed(l)) == 12

    # If changes are made to convert lists into dictionaries and do
    # a dictionary-lookup replacement, these tests will help to catch
    # some logical errors that might occur
    l = [(y, z + 2), (1 + z, 5), (z, 2)]
    assert (y - 1 + 3*x).subs(l) == 5 + 3*x
    l = [(y, z + 2), (z, 3)]
    assert (y - 2).subs(l) == 3


def test_no_arith_subs_on_floats():
    assert (x + 3).subs(x + 3, a) == a
    assert (x + 3).subs(x + 2, a) == a + 1

    assert (x + y + 3).subs(x + 3, a) == a + y
    assert (x + y + 3).subs(x + 2, a) == a + y + 1

    assert (x + 3.0).subs(x + 3.0, a) == a
    assert (x + 3.0).subs(x + 2.0, a) == x + 3.0

    assert (x + y + 3.0).subs(x + 3.0, a) == a + y
    assert (x + y + 3.0).subs(x + 2.0, a) == x + y + 3.0


def test_issue_5651():
    a, b, c, K = symbols('a b c K', commutative=True)
    assert (a/(b*c)).subs(b*c, K) == a/K
    assert (a/(b**2*c**3)).subs(b*c, K) == a/(c*K**2)
    assert (1/(x*y)).subs(x*y, 2) == S.Half
    assert ((1 + x*y)/(x*y)).subs(x*y, 1) == 2
    assert (x*y*z).subs(x*y, 2) == 2*z
    assert ((1 + x*y)/(x*y)/z).subs(x*y, 1) == 2/z


def test_issue_6075():
    assert Tuple(1, True).subs(1, 2) == Tuple(2, True)


def test_issue_6079():
    # since x + 2.0 == x + 2 we can't do a simple equality test
    assert _aresame((x + 2.0).subs(2, 3), x + 2.0)
    assert _aresame((x + 2.0).subs(2.0, 3), x + 3)
    assert not _aresame(x + 2, x + 2.0)
    assert not _aresame(Basic(cos, 1), Basic(cos, 1.))
    assert _aresame(cos, cos)
    assert not _aresame(1, S.One)
    assert not _aresame(x, symbols('x', positive=True))


def test_issue_4680():
    # The key here is the string 'N'; the assertion checks it is matched
    # against Symbol('N').
    N = Symbol('N')
    assert N.subs(dict(N=3)) == 3


def test_issue_6158():
    assert (x - 1).subs(1, y) == x - y
    assert (x - 1).subs(-1, y) == x + y
    assert (x - oo).subs(oo, y) == x - y
    assert (x - oo).subs(-oo, y) == x + y


def test_Function_subs():
    f, g, h, i = symbols('f g h i', cls=Function)
    p = Piecewise((g(f(x, y)), x < -1), (g(x), x <= 1))
    assert p.subs(g, h) == Piecewise((h(f(x, y)), x < -1), (h(x), x <= 1))
    assert (f(y) + g(x)).subs({f: h, g: i}) == i(x) + h(y)


def test_simultaneous_subs():
    reps = {x: 0, y: 0}
    assert (x/y).subs(reps) != (y/x).subs(reps)
    assert (x/y).subs(reps, simultaneous=True) == \
        (y/x).subs(reps, simultaneous=True)
    reps = reps.items()
    assert (x/y).subs(reps) != (y/x).subs(reps)
    assert (x/y).subs(reps, simultaneous=True) == \
        (y/x).subs(reps, simultaneous=True)
    assert Derivative(x, y, z).subs(reps, simultaneous=True) == \
        Subs(Derivative(0, y, z), y, 0)


def test_issue_6419_6421():
    assert (1/(1 + x/y)).subs(x/y, x) == 1/(1 + x)
    assert (-2*I).subs(2*I, x) == -x
    assert (-I*x).subs(I*x, x) == -x
    assert (-3*I*y**4).subs(3*I*y**2, x) == -x*y**2


def test_issue_6559():
    assert (-12*x + y).subs(-x, 1) == 12 + y
    # though this involves cse it generated a failure in Mul._eval_subs
    x0, x1 = symbols('x0 x1')
    e = -log(-12*sqrt(2) + 17)/24 - log(-2*sqrt(2) + 3)/12 + sqrt(2)/3
    # XXX modify cse so x1 is eliminated and x0 = -sqrt(2)?
    assert cse(e) == (
        [(x0, sqrt(2))], [x0/3 - log(-12*x0 + 17)/24 - log(-2*x0 + 3)/12])


def test_issue_5261():
    x = symbols('x', real=True)
    e = I*x
    assert exp(e).subs(exp(x), y) == y**I
    assert (2**e).subs(2**x, y) == y**I
    eq = (-2)**e
    assert eq.subs((-2)**x, y) == eq


def test_issue_6923():
    assert (-2*x*sqrt(2)).subs(2*x, y) == -sqrt(2)*y


def test_2arg_hack():
    # hack2=True exercises the special "2-arg hack" handling described in
    # the docstring of test_mul2 below.
    N = Symbol('N', commutative=False)
    ans = Mul(2, y + 1, evaluate=False)
    assert (2*x*(y + 1)).subs(x, 1, hack2=True) == ans
    assert (2*(y + 1 + N)).subs(N, 0, hack2=True) == ans


@XFAIL
def test_mul2():
    """When this fails, remove things labelled "2-arg hack"
    1) remove special handling in the fallback of subs that
    was added in the same commit as this test
    2) remove the special handling in Mul.flatten
    """
    assert (2*(x + 1)).is_Mul


def test_noncommutative_subs():
    x,y = symbols('x,y', commutative=False)
    assert (x*y*x).subs([(x, x*y), (y, x)], simultaneous=True) == (x*y*x**2*y)


def test_issue_2877():
    f = Float(2.0)
    assert (x + f).subs({f: 2}) == x + 2

    def r(a, b, c):
        return factor(a*x**2 + b*x + c)
    e = r(5.0/6, 10, 5)
    assert nsimplify(e) == 5*x**2/6 + 10*x + 5


def test_issue_5910():
    t = Symbol('t')
    assert (1/(1 - t)).subs(t, 1) is zoo
    n = t
    d = t - 1
    assert (n/d).subs(t, 1) is zoo
    assert (-n/-d).subs(t, 1) is zoo


def test_issue_5217():
    s = Symbol('s')
    z = (1 - 2*x*x)
    w = (1 + 2*x*x)
    q = 2*x*x*2*y*y
    sub = {2*x*x: s}
    assert w.subs(sub) == 1 + s
    assert z.subs(sub) == 1 - s
    assert q == 4*x**2*y**2
    assert q.subs(sub) == 2*y**2*s


def test_issue_10829():
    assert (4**x).subs(2**x, y) == y**2
    assert (9**x).subs(3**x, y) == y**2


def test_pow_eval_subs_no_cache():
    # Tests pull request 9376 is working
    from sympy.core.cache import clear_cache

    s = 1/sqrt(x**2)
    # This bug only appeared when the cache was turned off.
    # We need to approximate running this test without the cache.
    # This creates approximately the same situation.
    clear_cache()

    # This used to fail with a wrong result.
    # It incorrectly returned 1/sqrt(x**2) before this pull request.
    result = s.subs(sqrt(x**2), y)
    assert result == 1/y


def test_RootOf_issue_10092():
    x = Symbol('x', real=True)
    eq = x**3 - 17*x**2 + 81*x - 118
    r = RootOf(eq, 0)
    assert (x < r).subs(x, r) is S.false


def test_issue_8886():
    from sympy.physics.mechanics import ReferenceFrame as R
    # if something can't be sympified we assume that it
    # doesn't play well with SymPy and disallow the
    # substitution
    v = R('A').x
    assert x.subs(x, v) == x
    assert v.subs(v, x) == v
    assert v.__eq__(x) is False


def test_issue_12657():
    # treat -oo like the atom that it is
    reps = [(-oo, 1), (oo, 2)]
    assert (x < -oo).subs(reps) == (x < 1)
    assert (x < -oo).subs(list(reversed(reps))) == (x < 1)
    reps = [(-oo, 2), (oo, 1)]
    assert (x < oo).subs(reps) == (x < 1)
    assert (x < oo).subs(list(reversed(reps))) == (x < 1)


def test_recurse_Application_args():
    F = Lambda((x, y), exp(2*x + 3*y))
    f = Function('f')
    A = f(x, f(x, x))
    C = F(x, F(x, x))
    assert A.subs(f, F) == A.replace(f, F) == C


def test_Subs_subs():
    assert Subs(x*y, x, x).subs(x, y) == Subs(x*y, x, y)
    assert Subs(x*y, x, x + 1).subs(x, y) == \
        Subs(x*y, x, y + 1)
    assert Subs(x*y, y, x + 1).subs(x, y) == \
        Subs(y**2, y, y + 1)
    a = Subs(x*y*z, (y, x, z), (x + 1, x + z, x))
    b = Subs(x*y*z, (y, x, z), (x + 1, y + z, y))
    assert a.subs(x, y) == b and \
        a.doit().subs(x, y) == a.subs(x, y).doit()
    f = Function('f')
    g = Function('g')
    assert Subs(2*f(x, y) + g(x), f(x, y), 1).subs(y, 2) == Subs(
        2*f(x, y) + g(x), (f(x, y), y), (1, 2))


def test_issue_13333():
    # The replacement values are strings; the assertions check they are
    # parsed into the rational 1/2.
    eq = 1/x
    assert eq.subs(dict(x='1/2')) == 2
    assert eq.subs(dict(x='(1/2)')) == 2


def test_issue_15234():
    x, y = symbols('x y', real=True)
    p = 6*x**5 + x**4 - 4*x**3 + 4*x**2 - 2*x + 3
    p_subbed = 6*x**5 - 4*x**3 - 2*x + y**4 + 4*y**2 + 3
    assert p.subs([(x**i, y**i) for i in [2, 4]]) == p_subbed
    x, y = symbols('x y', complex=True)
    p = 6*x**5 + x**4 - 4*x**3 + 4*x**2 - 2*x + 3
    p_subbed = 6*x**5 - 4*x**3 - 2*x + y**4 + 4*y**2 + 3
    assert p.subs([(x**i, y**i) for i in [2, 4]]) == p_subbed


def test_issue_6976():
    x, y = symbols('x y')
    assert (sqrt(x)**3 + sqrt(x) + x + x**2).subs(sqrt(x), y) == \
        y**4 + y**3 + y**2 + y
    assert (x**4 + x**3 + x**2 + x + sqrt(x)).subs(x**2, y) == \
        sqrt(x) + x**3 + x + y**2 + y
    assert x.subs(x**3, y) == x
    assert x.subs(x**Rational(1, 3), y) == y**3

    # More substitutions are possible with nonnegative symbols
    x, y = symbols('x y', nonnegative=True)
    assert (x**4 + x**3 + x**2 + x + sqrt(x)).subs(x**2, y) == \
        y**Rational(1, 4) + y**Rational(3, 2) + sqrt(y) + y**2 + y
    assert x.subs(x**3, y) == y**Rational(1, 3)


def test_issue_11746():
    assert (1/x).subs(x**2, 1) == 1/x
    assert (1/(x**3)).subs(x**2, 1) == x**(-3)
    assert (1/(x**4)).subs(x**2, 1) == 1
    assert (1/(x**3)).subs(x**4, 1) == x**(-3)
    assert (1/(y**5)).subs(x**5, 1) == y**(-5)


def test_issue_17823():
    from sympy.physics.mechanics import dynamicsymbols
    q1, q2 = dynamicsymbols('q1, q2')
    expr = q1.diff().diff()**2*q1 + q1.diff()*q2.diff()
    reps={q1: a, q1.diff(): a*x*y, q1.diff().diff(): z}
    assert expr.subs(reps) == a*x*y*Derivative(q2, t) + a*z**2


def test_issue_19326():
    x, y = [i(t) for i in map(Function, 'xy')]
    assert (x*y).subs({x: 1 + x, y: x}) == (1 + x)*x


def test_issue_19558():
    e = (7*x*cos(x) - 12*log(x)**3)*(-log(x)**4 + 2*sin(x) + 1)**2/ \
        (2*(x*cos(x) - 2*log(x)**3)*(3*log(x)**4 - 7*sin(x) + 3)**2)

    assert e.subs(x, oo) == AccumBounds(-oo, oo)
    assert (sin(x) + cos(x)).subs(x, oo) == AccumBounds(-2, 2)