1from sympy import (
2    Symbol, Wild, sin, cos, exp, sqrt, pi, Function, Derivative,
3    Integer, Eq, symbols, Add, I, Float, log, Rational,
4    Lambda, atan2, cse, cot, tan, S, Tuple, Basic, Dict,
5    Piecewise, oo, Mul, factor, nsimplify, zoo, Subs, RootOf,
6    AccumBounds, Matrix, zeros, ZeroMatrix)
7from sympy.core.basic import _aresame
8from sympy.testing.pytest import XFAIL
9from sympy.abc import a, x, y, z, t
10
def test_subs():
    """Replacing ``x`` by a rational turns ``x`` and ``2*x`` into numbers."""
    three = Rational(3)
    assert x.subs(x, three) == Rational(3)

    doubled = 2*x
    assert doubled == 2*x
    assert doubled.subs(x, three) == Rational(6)
21
22
def test_subs_Matrix():
    """Substituting matrices into scalar sums and products."""
    zero_mat = zeros(2)
    zero_expr = ZeroMatrix(2, 2)
    either = [zero_mat, zero_expr]

    assert (x*y).subs({x: zero_mat, y: 0}) in either
    assert (x*y).subs({y: zero_mat, x: 0}) == 0
    assert (x*y).subs({y: zero_mat, x: 0}, simultaneous=True) in either
    assert (x + y).subs({x: zero_mat, y: zero_mat}, simultaneous=True) in either
    assert (x + y).subs({x: zero_mat, y: zero_mat}) in either

    # Issue #15528
    scaled = Mul(Matrix([[3]]), x).subs(x, 2.0)
    assert scaled == Matrix([[6.0]])
    # Does not raise a TypeError, see comment on the MatAdd postprocessor
    shifted = Add(Matrix([[3]]), x).subs(x, 2.0)
    assert shifted == Add(Matrix([[3]]), 2.0)
36
def test_subs_AccumBounds():
    """Substituting an ``AccumBounds`` interval evaluates interval arithmetic."""
    cases = [
        (x, AccumBounds(1, 3), AccumBounds(1, 3)),
        (2*x, AccumBounds(1, 3), AccumBounds(2, 6)),
        (x + x**2, AccumBounds(-1, 1), AccumBounds(-1, 2)),
    ]
    for expr, interval, expected in cases:
        assert expr.subs(x, interval) == expected
49
50
def test_trigonometric():
    """Substitution into trigonometric expressions, including at poles."""
    three = Rational(3)
    deriv = (sin(x)**2).diff(x)
    assert deriv == 2*sin(x)*cos(x)
    assert deriv.subs(x, three) == 2*cos(three)*sin(three)
    assert deriv.subs(sin(x), cos(x)) == 2*cos(x)**2

    # the function class itself can be substituted
    assert exp(pi).subs(exp, sin) == 0
    assert cos(exp(pi)).subs(exp, sin) == 1

    cinf = S.ComplexInfinity
    k = Symbol('i', integer=True)
    odd = Symbol('o', odd=True)
    assert tan(x).subs(x, pi/2) is cinf
    assert cot(x).subs(x, pi) is cinf
    assert cot(k*x).subs(x, pi) is cinf
    # with a symbolic integer (or odd) factor the result stays unevaluated
    assert tan(k*x).subs(x, pi/2) == tan(k*pi/2)
    assert tan(k*x).subs(x, pi/2).subs(k, 1) is cinf
    assert tan(odd*x).subs(x, pi/2) == tan(odd*pi/2)
75
76
def test_powers():
    """Substitution into powers, radicals and zero bases."""
    neg = Symbol('n', negative=True)
    cube_root_m1 = (-1)**Rational(1, 3)
    cases = [
        (sqrt(1 - sqrt(x)), x, 4, I),
        (sqrt(1 - x**2)**3, x, 2, -3*I*sqrt(3)),
        (x**Rational(1, 3), x, 27, 3),
        (x**Rational(1, 3), x, -27, 3*cube_root_m1),
        ((-x)**Rational(1, 3), x, 27, 3*cube_root_m1),
        (x**(4.0*y), x**(2.0*y), neg, neg**2.0),
        (2**(x + 2), 2, 3, 3**(x + 3)),
    ]
    for expr, old, new, expected in cases:
        assert expr.subs(old, new) == expected

    # these must return the zoo singleton itself
    assert (x**neg).subs(x, 0) is S.ComplexInfinity
    assert exp(-1).subs(S.Exp1, 0) is S.ComplexInfinity
88
89
def test_logexppow():   # no eval()
    """Replacing ``2**x`` (or its exp/log form) changes the expression."""
    xr = Symbol('x', real=True)
    w = Symbol('w')
    ratio = (3**(1 + xr) + 2**(1 + xr))/(3**xr + 2**xr)
    for pattern in (2**xr, exp(xr*log(Rational(2)))):
        assert ratio.subs(pattern, w) != ratio
96
97
def test_bug():
    """A Float substituted into a product keeps the remaining factor."""
    first, second = Symbol('x1'), Symbol('x2')
    three = Float(3.0)
    assert (first*second).subs(first, three) == three*second
103
104
def test_subbug1():
    """``x**x`` at ``x = 1`` (exact and Float) evaluates without error.

    These calls historically failed. Besides checking that they no longer
    raise, pin the evaluated values so a silent wrong answer is caught too.
    """
    assert (x**x).subs(x, 1) == 1
    assert (x**x).subs(x, 1.0) == 1.0
109
110
def test_subbug2():
    """``abs(x).subs(x, -7.7)`` must terminate (no infinite recursion)."""
    value = abs(x).subs(x, -7.7)
    assert Float(7.7).epsilon_eq(value)
114
115
def test_dict_set():
    """``subs`` accepts a dict, a set of pairs and a SymPy ``Dict``."""
    wa, wb, _ = map(Wild, 'abc')

    found = (3*cos(4*x)).match(wa*cos(wb*x))
    assert found == {wa: 3, wb: 4}

    target = wa/wb*sin(wb*x)
    symbolic = found[wa]/found[wb]*sin(found[wb]*x)
    numeric = 3*sin(4*x) / 4
    # dict, then set of items, then the dict again (it must be unchanged)
    for reps in (found, set(found.items()), found):
        assert target.subs(reps) == symbolic
        assert target.subs(reps) == numeric

    assert x.subs(Dict((x, 1))) == 1
132
133
def test_dict_ambigous():   # see issue 3566
    """Dict substitution with overlapping keys (issue 3566)."""
    for var in (x, z):
        assert (var*exp(var)).subs({var: y, exp(var): y}) == y**2

    fx = x*exp(x)
    # and this is how order can affect the result
    assert fx.subs(x, y).subs(exp(x), y) == y*exp(y)
    assert fx.subs(exp(x), y).subs(x, y) == y**2

    # length of args and count_ops are the same so
    # default_sort_key resolves ordering...if one
    # doesn't want this result then an unordered
    # sequence should not be used.
    assert (1 + x*y).subs({x: y, y: 2}) == 5
    # here, there are no obviously clashing keys or values
    # but the results depend on the order
    assert exp(x/2 + y).subs({exp(y + 1): 2, x: 2}) == exp(y + 1)
157
158
def test_deriv_sub_bug3():
    """Substituting an absent symbol leaves a second derivative intact."""
    f = Function('f')
    result = Derivative(f(x), x, x).subs(y, y**2)
    assert result == Derivative(f(x), x, x)
    assert result != Derivative(f(x), x)
164
165
def test_equality_subs1():
    """Substitution is applied inside an ``Eq``."""
    f = Function('f')
    assert Eq(f(x)**2, x).subs(f(x), 4) == Eq(Integer(16), x)
171
172
def test_equality_subs2():
    """An ``Eq`` reduces to a boolean once both sides are numeric."""
    f = Function('f')
    eq = Eq(f(x)**2, 16)
    for value, truth in ((3, False), (4, True)):
        assert bool(eq.subs(f(x), value)) is truth
178
179
def test_issue_3742():
    """``sqrt(x)`` is replaced as a unit inside a product."""
    product = sqrt(x)*exp(y)
    assert product.subs(sqrt(x), 1) == exp(y)
183
184
def test_subs_dict1():
    """Dict substitution of squared symbols inside a sum of products."""
    assert (1 + x*y).subs(x, pi) == 1 + pi*y
    assert (1 + x*y).subs({x: pi, y: 2}) == 1 + 2*pi

    c1, c2, c3, s1, s2, s3, q1p, q2p = symbols('c1 c2 c3 s1 s2 s3 q1p q2p')
    expr = (c2**2*q2p*c3 + c1**2*s2**2*q2p*c3 + s1**2*s2**2*q2p*c3
            - c1**2*q1p*c2*s3 - s1**2*q1p*c2*s3)
    # NOTE(review): c3 appears only to the first power in ``expr``, so the
    # c3**3 entry never matches; it may have been intended as c3**2.
    reps = {c1**2: 1 - s1**2, c2**2: 1 - s2**2, c3**3: 1 - s3**2}
    expected = (c3*q2p*(1 - s2**2) + c3*q2p*s2**2*(1 - s1**2)
                - c2*q1p*s3*(1 - s1**2) + c3*q2p*s1**2*s2**2
                - c2*q1p*s3*s1**2)
    assert expr.subs(reps) == expected
195
196
def test_mul():
    """Substitution of (partial) products inside ``Mul``.

    Covers commutative and noncommutative factors, powers of the pattern,
    exponentials and numeric coefficients (issues 6361 and 6441).
    """
    x, y, z, a, b, c = symbols('x y z a b c')
    A, B, C = symbols('A B C', commutative=0)
    assert (x*y*z).subs(z*x, y) == y**2
    assert (z*x).subs(1/x, z) == 1
    assert (x*y/z).subs(1/z, a) == a*x*y
    assert (x*y/z).subs(x/z, a) == a*y
    assert (x*y/z).subs(y/z, a) == a*x
    assert (x*y/z).subs(x/z, 1/a) == y/a
    assert (x*y/z).subs(x, 1/a) == y/(z*a)
    assert (2*x*y).subs(5*x*y, z) != z*Rational(2, 5)
    assert (x*y*A).subs(x*y, a) == a*A
    assert (x**2*y**(x*Rational(3, 2))).subs(x*y**(x/2), 2) == 4*y**(x/2)
    assert (x*exp(x*2)).subs(x*exp(x), 2) == 2*exp(x)
    assert ((x**(2*y))**3).subs(x**y, 2) == 64
    assert (x*A*B).subs(x*A, y) == y*B
    assert (x*y*(1 + x)*(1 + x*y)).subs(x*y, 2) == 6*(1 + x)
    # NOTE(review): this only asserts that the result is truthy; the exact
    # noncommutative result is not pinned.
    assert ((1 + A*B)*A*B).subs(A*B, x*A*B)
    assert (x*a/z).subs(x/z, A) == a*A
    assert (x**3*A).subs(x**2*A, a) == a*x
    assert (x**2*A*B).subs(x**2*B, a) == a*A
    assert (x**2*A*B).subs(x**2*A, a) == a*B
    assert (b*A**3/(a**3*c**3)).subs(a**4*c**3*A**3/b**4, z) == \
        b*A**3/(a**3*c**3)
    assert (6*x).subs(2*x, y) == 3*y
    # (a second, byte-identical copy of the next assertion was removed)
    assert (y*exp(x*Rational(3, 2))).subs(y*exp(x), 2) == 2*exp(x/2)
    assert (A**2*B*A**2*B*A**2).subs(A*B*A, C) == A*C**2*A
    assert (x*A**3).subs(x*A, y) == y*A**2
    assert (x**2*A**3).subs(x*A, y) == y**2*A
    assert (x*A**3).subs(x*A, B) == B*A**2
    assert (x*A*B*A*exp(x*A*B)).subs(x*A, B) == B**2*A*exp(B*B)
    assert (x**2*A*B*A*exp(x*A*B)).subs(x*A, B) == B**3*exp(B**2)
    assert (x**3*A*exp(x*A*B)*A*exp(x*A*B)).subs(x*A, B) == \
        x*B*exp(B**2)*B*exp(B**2)
    assert (x*A*B*C*A*B).subs(x*A*B, C) == C**2*A*B
    assert (-I*a*b).subs(a*b, 2) == -2*I

    # issue 6361
    assert (-8*I*a).subs(-2*a, 1) == 4*I
    assert (-I*a).subs(-a, 1) == I

    # issue 6441
    assert (4*x**2).subs(2*x, y) == y**2
    assert (2*4*x**2).subs(2*x, y) == 2*y**2
    assert (-x**3/9).subs(-x/3, z) == -z**2*x
    assert (-x**3/9).subs(x/3, z) == -z**2*x
    assert (-2*x**3/9).subs(x/3, z) == -2*x*z**2
    assert (-2*x**3/9).subs(-x/3, z) == -2*x*z**2
    assert (-2*x**3/9).subs(-2*x, z) == z*x**2/9
    assert (-2*x**3/9).subs(2*x, z) == -z*x**2/9
    assert (2*(3*x/5/7)**2).subs(3*x/5, z) == 2*(Rational(1, 7))**2*z**2
    assert (4*x).subs(-2*x, z) == 4*x  # try keep subs literal
250
251
def test_subs_simple():
    """Number and symbol replacement for commutative and noncommutative symbols."""
    comm = symbols('a', commutative=True)
    nc = symbols('x', commutative=False)

    for sym in (comm, nc):
        assert (2*sym).subs(1, 3) == 2*sym
        assert (2*sym).subs(2, 3) == 3*sym
        assert (2*sym).subs(sym, 3) == 6
        assert sin(sym).subs(sym, 3) == sin(3)

    assert sin(2).subs(1, 3) == sin(2)
    assert sin(2).subs(2, 3) == sin(3)
267
268
def test_subs_constants():
    """Coefficient-bearing patterns must match the literal coefficient."""
    comm_pair = symbols('a b', commutative=True)
    nc_pair = symbols('x y', commutative=False)

    for p, q in (comm_pair, nc_pair):
        assert (p*q).subs(2*p, 1) == p*q
        assert (1.5*p*q).subs(p, 1) == 1.5*q
        assert (2*p*q).subs(2*p, 1) == q
        assert (2*p*q).subs(4*p, 1) == 2*p*q
282
283
def test_subs_commutative():
    """Replacing commutative (sub)products and power components by ``K``."""
    a, b, c, d, K = symbols('a b c d K', commutative=True)

    table = [
        (a*b, a*b, K),
        (a*b*a*b, a*b, K**2),
        (a*a*b*b, a*b, K**2),
        (a*b*c*d, a*b*c, d*K),
        (a*b**c, a, K*b**c),
        (a*b**c, b, a*K**c),
        (a*b**c, c, a*b**K),
        (a*b*c*b*a, a*b, c*K**2),
        (a**3*b**2*a, a*b, a**2*K**2),
    ]
    for expr, old, expected in table:
        assert expr.subs(old, K) == expected
296
297
def test_subs_noncommutative():
    """Substitution of noncommutative products and powers.

    Checks that only contiguous runs of factors are replaced. It also checks
    integer-exponent splitting for powers and that negative powers are only
    matched against patterns with a compatible sign.
    """
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    # commutative symbol with no integer assumption
    alpha = symbols('alpha', commutative=True)
    # commutative integer symbol, used in symbolic exponents
    someint = symbols('someint', commutative=True, integer=True)

    # contiguous subproducts only; order of factors matters
    assert (x*y).subs(x*y, L) == L
    assert (w*y*x).subs(x*y, L) == w*y*x
    assert (w*x*y*z).subs(x*y, L) == w*L*z
    assert (x*y*x*y).subs(x*y, L) == L**2
    assert (x*x*y).subs(x*y, L) == x*L
    assert (x*x*y*y).subs(x*y, L) == x*L*y
    assert (w*x*y).subs(x*y*z, L) == w*x*y
    assert (x*y**z).subs(x, L) == L*y**z
    assert (x*y**z).subs(y, L) == x*L**z
    assert (x*y**z).subs(z, L) == x*y**L
    assert (w*x*y*z*x*y).subs(x*y*z, L) == w*L*x*y
    assert (w*x*y*y*w*x*x*y*x*y*y*x*y).subs(x*y, L) == w*L*y*w*x*L**2*y*L

    # Check fractional power substitutions. It should not do
    # substitutions that choose a value for noncommutative log,
    # or inverses that don't already appear in the expressions.
    assert (x*x*x).subs(x*x, L) == L*x
    assert (x*x*x*y*x*x*x*x).subs(x*x, L) == L*x*y*L**2
    # x**k with x**p replaced: k//p copies of L and a remainder x**(k % p)
    for p in range(1, 5):
        for k in range(10):
            assert (y * x**k).subs(x**p, L) == y * L**(k//p) * x**(k % p)
    assert (x**Rational(3, 2)).subs(x**S.Half, L) == x**Rational(3, 2)
    assert (x**S.Half).subs(x**S.Half, L) == L
    assert (x**Rational(-1, 2)).subs(x**S.Half, L) == x**Rational(-1, 2)
    assert (x**Rational(-1, 2)).subs(x**Rational(-1, 2), L) == L

    # symbolic integer exponents are split when the pattern divides them
    assert (x**(2*someint)).subs(x**someint, L) == L**2
    assert (x**(2*someint + 3)).subs(x**someint, L) == L**2*x**3
    assert (x**(3*someint + 3)).subs(x**someint, L) == L**3*x**3
    assert (x**(3*someint)).subs(x**(2*someint), L) == L * x**someint
    assert (x**(4*someint)).subs(x**(2*someint), L) == L**2
    assert (x**(4*someint + 1)).subs(x**(2*someint), L) == L**2 * x
    assert (x**(4*someint)).subs(x**(3*someint), L) == L * x**someint
    assert (x**(4*someint + 1)).subs(x**(3*someint), L) == L * x**(someint + 1)

    # exponents that are not known to be integers are left alone
    assert (x**(2*alpha)).subs(x**alpha, L) == x**(2*alpha)
    assert (x**(2*alpha + 2)).subs(x**2, L) == x**(2*alpha + 2)
    assert ((2*z)**alpha).subs(z**alpha, y) == (2*z)**alpha
    assert (x**(2*someint*alpha)).subs(x**someint, L) == x**(2*someint*alpha)
    assert (x**(2*someint + alpha)).subs(x**someint, L) == x**(2*someint + alpha)

    # This could in principle be substituted, but is not currently
    # because it requires recognizing that someint**2 is divisible by
    # someint.
    assert (x**(someint**2 + 3)).subs(x**someint, L) == x**(someint**2 + 3)

    # alpha**z := exp(log(alpha) z) is usually well-defined
    assert (4**z).subs(2**z, y) == y**2

    # Negative powers
    assert (x**(-1)).subs(x**3, L) == x**(-1)
    assert (x**(-2)).subs(x**3, L) == x**(-2)
    assert (x**(-3)).subs(x**3, L) == L**(-1)
    assert (x**(-4)).subs(x**3, L) == L**(-1) * x**(-1)
    assert (x**(-5)).subs(x**3, L) == L**(-1) * x**(-2)

    # negative-power pattern against negative powers
    assert (x**(-1)).subs(x**(-3), L) == x**(-1)
    assert (x**(-2)).subs(x**(-3), L) == x**(-2)
    assert (x**(-3)).subs(x**(-3), L) == L
    assert (x**(-4)).subs(x**(-3), L) == L * x**(-1)
    assert (x**(-5)).subs(x**(-3), L) == L * x**(-2)

    # negative-power pattern against positive powers
    assert (x**1).subs(x**(-3), L) == x
    assert (x**2).subs(x**(-3), L) == x**2
    assert (x**3).subs(x**(-3), L) == L**(-1)
    assert (x**4).subs(x**(-3), L) == L**(-1) * x
    assert (x**5).subs(x**(-3), L) == L**(-1) * x**2
370
371
def test_subs_basic_funcs():
    """Substitution through sums, quotients, powers and ``exp``."""
    a, b, c, d, K = symbols('a b c d K', commutative=True)
    w, x, y, z, L = symbols('w x y z L', commutative=False)

    checks = (
        (x + y, x + y, L, L),
        (x - y, x - y, L, L),
        (x/y, x, L, L/y),
        (x**y, x, L, L**y),
        (x**y, y, L, x**L),
        ((a - c)/b, b, K, (a - c)/K),
        (exp(x*y - z), x*y, L, exp(L - z)),
        (a*exp(x*y - w*z) + b*exp(x*y + w*z), z, 0,
         a*exp(x*y) + b*exp(x*y)),
        ((a - b)/(c*d - a*b), c*d - a*b, K, (a - b)/K),
        (w*exp(a*b - c)*x*y/4, x*y, L, w*exp(a*b - c)*L/4),
    )
    for expr, old, new, expected in checks:
        assert expr.subs(old, new) == expected
387
388
def test_subs_wild():
    """``Wild`` symbols are substituted like ordinary symbols."""
    R, S, T, U = symbols('R S T U', cls=Wild)

    checks = (
        (R*S, R*S, T, T),
        (S*R, R*S, T, T),
        (R + S, R + S, T, T),
        (R**S, R, T, T**S),
        (R**S, S, T, R**T),
        (R*S**T, R, U, U*S**T),
        (R*S**T, S, U, R*U**T),
        (R*S**T, T, U, R*S**U),
    )
    for expr, old, new, expected in checks:
        assert expr.subs(old, new) == expected
400
401
def test_subs_mixed():
    """Mixed commutative, noncommutative and Wild substitutions."""
    a, b, c, d, K = symbols('a b c d K', commutative=True)
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    R, S, T, U = symbols('R S T U', cls=Wild)
    xy = x*y

    assert (a*x*y).subs(xy, L) == a*L
    assert (a*b*x*y*x).subs(xy, L) == a*b*L*x
    assert (R*x*y*exp(x*y)).subs(xy, L) == R*L*exp(L)
    assert (a*x*y*y*x - x*y*z*exp(a*b)).subs(xy, L) == \
        a*L*y*x - L*z*exp(a*b)

    expr = c*y*x*y*x**(R*S - a*b) - T*(a*R*b*S)
    chained = expr.subs(xy, L).subs(a*b, K).subs(R*S, U)
    assert chained == c*y*L*x**(U - K) - T*(U*K)
414
415
def test_division():
    """Substitution into reciprocals and negated reciprocal squares."""
    a, c = symbols('a c', commutative=True)
    x, z = symbols('x z', commutative=True)

    for var, other in ((a, c), (x, z)):
        assert (1/var).subs(var, other) == 1/other
        assert (1/var**2).subs(var, other) == 1/other**2
        assert (1/var**2).subs(var, -2) == Rational(1, 4)
        assert (-(1/var**2)).subs(var, -2) == Rational(-1, 4)

    # issue 5360
    assert (1/x).subs(x, 0) == 1/S.Zero
432
433
def test_add():
    """Substitution of partial sums inside ``Add`` (and via ``Mul``)."""
    a, b, c, d, x, y, t = symbols('a b c d x y t')

    # a partial-sum replacement may or may not be made; both are accepted
    assert (a**2 - b - c).subs(a**2 - b, d) in [d - c, a**2 - b - c]
    assert (a**2 - c).subs(a**2 - c, d) == d
    assert (a**2 - b - c).subs(a**2 - c, d) in [d - b, a**2 - b - c]
    assert (a**2 - x - c).subs(a**2 - c, d) in [d - x, a**2 - x - c]
    assert (a**2 - b - sqrt(a)).subs(a**2 - sqrt(a), c) == c - b
    assert (a + b + exp(a + b)).subs(a + b, c) == c + exp(c)
    assert (c + b + exp(c + b)).subs(c + b, a) == a + exp(a)
    total = a + b + c + d
    assert total.subs(b + c, x) == a + d + x
    assert total.subs(-b - c, x) == a + d - x

    # x + 1 replaced inside linear factors multiplied by y
    for lin, expected in ((x + 1, t), (-x - 1, -t),
                          (x - 1, t - 2), (-x + 1, -t + 2)):
        assert (lin*y).subs(x + 1, t) == y*expected

    # this should work every time:
    e = a**2 - b - c
    assert e.subs(Add(*e.args[:2]), d) == d + e.args[2]
    assert e.subs(a**2 - c, d) == d - b

    # the fallback should recognize when a change has
    # been made; while .1 == Rational(1, 10) they are not the same
    # and the change should be made
    assert (0.1 + a).subs(0.1, Rational(1, 10)) == Rational(1, 10) + a

    e = (-x*(-y + 1) - y*(y - 1))
    expanded = (-x*(x) - y*(-x)).expand()
    assert e.subs(-y + 1, x) == expanded

    # issue 18747
    assert (exp(x) + cos(x)).subs(x, oo) == oo
    assert Add(AccumBounds(-1, 1), oo) == oo
    assert Add(oo, AccumBounds(-1, 1)) == oo
469
def test_subs_issue_4009():
    """Replacing the number 1 leaves ``I*a`` untouched."""
    expr = I*Symbol('a')
    assert expr.subs(1, 2) == expr
472
473
def test_functions_subs():
    """Undefined functions can be replaced by Lambdas or other functions."""
    f, g = symbols('f g', cls=Function)
    lam = Lambda((x, y), sin(x) + y)

    assert (g(y, x) + cos(x)).subs(g, lam) == sin(y) + x + cos(x)
    assert (f(x)**2).subs(f, sin) == sin(x)**2
    assert f(x, y).subs(f, log) == log(x, y)
    # presumably left alone because sin does not take two arguments
    assert f(x, y).subs(f, sin) == f(x, y)
    assert (sin(x) + atan2(x, y)).subs([[atan2, f], [sin, g]]) == \
        f(x, y) + g(x)
    assert g(f(x + y, x)).subs([[f, lam], [g, exp]]) == exp(x + sin(x + y))
484
485
def test_derivative_subs():
    """Substitution into ``Derivative`` objects and their variables."""
    f, g = Function('f'), Function('g')

    replaced = Derivative(f(x), x).subs(f(x), y)
    assert replaced != 0
    # need xreplace to put the function back, see #13803
    assert replaced.xreplace({y: f(x)}) == Derivative(f(x), x)

    # issues 5085, 5037
    assert cse(Derivative(f(x), x) + f(x))[1][0].has(Derivative)
    assert cse(Derivative(f(x, y), x) +
               Derivative(f(x, y), y))[1][0].has(Derivative)

    dg = Derivative(g(x), g(x))
    assert dg.subs(g, f) == Derivative(f(x), f(x))
    assert dg.subs(g(x), f(x)) == Derivative(f(x), f(x))
    assert dg.subs(g, cos) == Subs(Derivative(y, y), y, cos(x))
501
502
def test_derivative_subs2():
    """Replacing a lower-order derivative inside a higher-order one."""
    f_func, g_func = symbols('f g', cls=Function)
    f, g = f_func(x, y, z), g_func(x, y, z)
    D = Derivative

    table = [
        (D(f, x, y), D(f, x, y), g),
        (D(f, y, x), D(f, x, y), g),
        (D(f, x, y), D(f, x), D(g, y)),
        (D(f, x, y), D(f, y), D(g, x)),
        (D(f, x, y, z), D(f, x, z), D(g, y)),
        (D(f, x, y, z), D(f, z, y), D(g, x)),
        (D(f, x, y, z), D(f, z, y, x), g),
        # Issue 9135: patterns needing more derivatives than present
        (D(f, x, x, y), D(f, y, y), D(f, x, x, y)),
        (D(f, x, y, y, z), D(f, x, y, y, y), D(f, x, y, y, z)),
        (D(f, x, y), D(f_func(x), x, y), D(f, x, y)),
    ]
    for expr, old, expected in table:
        assert expr.subs(old, g) == expected
524
525
def test_derivative_subs3():
    """Substituting a derivative for its own argument and vice versa."""
    d_exp = Derivative(exp(x), x)
    outer = Derivative(d_exp, x)
    assert outer.subs(d_exp, exp(x)) == d_exp
    assert d_exp.subs(exp(x), d_exp) == Derivative(exp(x), x, x)
530
531
def test_issue_5284():
    """Noncommutative power patterns replace only exact multiples."""
    A, B = symbols('A B', commutative=False)
    for expr, old, expected in ((x*A, x**2*A, x*A),
                                (A**2, A**3, A**2),
                                (A**6, A**3, B**2)):
        assert expr.subs(old, B) == expected
537
538
def test_subs_iter():
    """``subs`` accepts iterators and ``Tuple`` of pairs."""
    pairs = [[x, y]]
    assert x.subs(reversed(pairs)) == y
    assert x.subs(iter(pairs)) == y
    assert x.subs(Tuple((x, y))) == y
544
545
def test_subs_dict():
    """Dict and ordered-sequence substitution, including string keys."""
    a, b, c, d, e = symbols('a b c d e')

    assert (2*x + y + z).subs(dict(x=1, y=2)) == 4 + z

    pairs = [(sin(x), 2), (x, 1)]
    assert sin(x).subs(pairs) == sin(x).subs(dict(pairs)) == 2
    assert sin(x).subs(reversed(pairs)) == sin(1)

    expr = sin(2*x) + sqrt(sin(2*x))*cos(2*x)*sin(exp(x)*x)
    reps = {
        sin(2*x): c,
        sqrt(sin(2*x)): a,
        cos(2*x): b,
        exp(x): e,
        x: d,
    }
    assert expr.subs(reps) == c + a*b*sin(d*e)

    seq = [(x, 3), (y, x**2)]
    assert (x + y).subs(seq) == 3 + x**2
    assert (x + y).subs(reversed(seq)) == 12

    # If changes are made to convert lists into dictionaries and do
    # a dictionary-lookup replacement, these tests will help to catch
    # some logical errors that might occur
    assert (y - 1 + 3*x).subs([(y, z + 2), (1 + z, 5), (z, 2)]) == 5 + 3*x
    assert (y - 2).subs([(y, z + 2), (z, 3)]) == 3
577
578
def test_no_arith_subs_on_floats():
    """Integer constants allow offset matching; Float constants do not."""
    table = (
        (x + 3, x + 3, a),
        (x + 3, x + 2, a + 1),
        (x + y + 3, x + 3, a + y),
        (x + y + 3, x + 2, a + y + 1),
        # with Float constants only exact matches are replaced
        (x + 3.0, x + 3.0, a),
        (x + 3.0, x + 2.0, x + 3.0),
        (x + y + 3.0, x + 3.0, a + y),
        (x + y + 3.0, x + 2.0, x + y + 3.0),
    )
    for expr, old, expected in table:
        assert expr.subs(old, a) == expected
591
592
def test_issue_5651():
    """Products are matched inside denominators."""
    a, b, c, K = symbols('a b c K', commutative=True)
    assert (a/(b*c)).subs(b*c, K) == a/K
    assert (a/(b**2*c**3)).subs(b*c, K) == a/(c*K**2)

    xy = x*y
    for expr, value, expected in (
        (1/(x*y), 2, S.Half),
        ((1 + x*y)/(x*y), 1, 2),
        (x*y*z, 2, 2*z),
        ((1 + x*y)/(x*y)/z, 1, 2/z),
    ):
        assert expr.subs(xy, value) == expected
601
602
def test_issue_6075():
    """Numbers in a ``Tuple`` are replaced while ``True`` is kept."""
    before = Tuple(1, True)
    assert before.subs(1, 2) == Tuple(2, True)
605
606
def test_issue_6079():
    """``_aresame`` separates equal-comparing objects of different types."""
    # since x + 2.0 == x + 2 we can't do a simple equality test
    assert _aresame((x + 2.0).subs(2, 3), x + 2.0)
    assert _aresame((x + 2.0).subs(2.0, 3), x + 3)

    distinct = (
        (x + 2, x + 2.0),
        (Basic(cos, 1), Basic(cos, 1.)),
        (1, S.One),
        (x, symbols('x', positive=True)),
    )
    for left, right in distinct:
        assert not _aresame(left, right)
    assert _aresame(cos, cos)
616
617
def test_issue_4680():
    """A string key in the subs dict is sympified to the matching symbol."""
    N = Symbol('N')
    assert N.subs({'N': 3}) == 3
621
622
def test_issue_6158():
    """Both a constant and its negation can be the substitution target."""
    for const in (1, oo):
        expr = x - const
        assert expr.subs(const, y) == x - y
        assert expr.subs(-const, y) == x + y
628
629
def test_Function_subs():
    """Function classes are replaced inside ``Piecewise`` and sums."""
    f, g, h, i = symbols('f g h i', cls=Function)
    pw = Piecewise((g(f(x, y)), x < -1), (g(x), x <= 1))
    expected = Piecewise((h(f(x, y)), x < -1), (h(x), x <= 1))
    assert pw.subs(g, h) == expected
    assert (f(y) + g(x)).subs({f: h, g: i}) == i(x) + h(y)
635
636
def test_simultaneous_subs():
    """``simultaneous=True`` makes the result independent of key order."""
    as_dict = {x: 0, y: 0}
    as_items = as_dict.items()
    for reps in (as_dict, as_items):
        assert (x/y).subs(reps) != (y/x).subs(reps)
        assert (x/y).subs(reps, simultaneous=True) == \
            (y/x).subs(reps, simultaneous=True)
    assert Derivative(x, y, z).subs(as_items, simultaneous=True) == \
        Subs(Derivative(0, y, z), y, 0)
648
649
def test_issue_6419_6421():
    """Patterns match through negation and imaginary coefficients."""
    assert (1/(1 + x/y)).subs(x/y, x) == 1/(1 + x)
    for expr, old, expected in ((-2*I, 2*I, -x),
                                (-I*x, I*x, -x),
                                (-3*I*y**4, 3*I*y**2, -x*y**2)):
        assert expr.subs(old, x) == expected
655
656
def test_issue_6559():
    """Negated-symbol subs in a sum, and a ``cse`` regression."""
    assert (-12*x + y).subs(-x, 1) == 12 + y

    # though this involves cse it generated a failure in Mul._eval_subs
    x0 = Symbol('x0')
    expr = -log(-12*sqrt(2) + 17)/24 - log(-2*sqrt(2) + 3)/12 + sqrt(2)/3
    # XXX modify cse so x1 is eliminated and x0 = -sqrt(2)?
    expected = ([(x0, sqrt(2))],
                [x0/3 - log(-12*x0 + 17)/24 - log(-2*x0 + 3)/12])
    assert cse(expr) == expected
665
666
def test_issue_5261():
    """Powers with exponent ``I*x`` for real ``x``."""
    xr = symbols('x', real=True)
    ix = I*xr
    assert exp(ix).subs(exp(xr), y) == y**I
    assert (2**ix).subs(2**xr, y) == y**I
    # with a negative base the expression is left unchanged
    neg_base = (-2)**ix
    assert neg_base.subs((-2)**xr, y) == neg_base
674
675
def test_issue_6923():
    """``2*x`` is found inside a product with an irrational coefficient."""
    expr = -2*x*sqrt(2)
    assert expr.subs(2*x, y) == -sqrt(2)*y
678
679
def test_2arg_hack():
    """With ``hack2=True`` the result stays an unevaluated ``2*(y + 1)``."""
    N = Symbol('N', commutative=False)
    expected = Mul(2, y + 1, evaluate=False)
    for expr, old, new in ((2*x*(y + 1), x, 1), (2*(y + 1 + N), N, 0)):
        assert expr.subs(old, new, hack2=True) == expected
685
686
@XFAIL
def test_mul2():
    """Expected to fail until ``2*(x + 1)`` stays an unevaluated ``Mul``.

    When this starts passing, remove the code labelled "2-arg hack":
    1) the special handling in the fallback of subs that was added in the
       same commit as this test;
    2) the special handling in Mul.flatten.
    """
    product = 2*(x + 1)
    assert product.is_Mul
695
696
def test_noncommutative_subs():
    """Simultaneous subs of noncommutative symbols keeps factor order."""
    p, q = symbols('x,y', commutative=False)
    swaps = [(p, p*q), (q, p)]
    assert (p*q*p).subs(swaps, simultaneous=True) == (p*q*p**2*q)
700
701
def test_issue_2877():
    """A Float key matches in subs; nsimplify recovers exact coefficients."""
    two = Float(2.0)
    assert (x + two).subs({two: 2}) == x + 2

    def quadratic(p, q, r):
        return factor(p*x**2 + q*x + r)

    approx = quadratic(5.0/6, 10, 5)
    assert nsimplify(approx) == 5*x**2/6 + 10*x + 5
710
711
def test_issue_5910():
    """A vanishing denominator gives the ``zoo`` singleton."""
    t = Symbol('t')
    num, den = t, t - 1
    for expr in (1/(1 - t), num/den, -num/-den):
        assert expr.subs(t, 1) is zoo
719
720
def test_issue_5217():
    """``2*x**2`` is matched in sums and inside a larger product."""
    s = Symbol('s')
    sub = {2*x*x: s}
    assert (1 + 2*x*x).subs(sub) == 1 + s
    assert (1 - 2*x*x).subs(sub) == 1 - s

    prod = 2*x*x*2*y*y
    assert prod == 4*x**2*y**2
    assert prod.subs(sub) == 2*y**2*s
731
732
def test_issue_10829():
    """``(b**2)**x`` is rewritten in terms of ``b**x``."""
    for base in (2, 3):
        assert ((base**2)**x).subs(base**x, y) == y**2
736
737
def test_pow_eval_subs_no_cache():
    """Regression test: ``sqrt(x**2)`` is replaced in ``1/sqrt(x**2)`` even
    after the cache has been cleared."""
    # Tests pull request 9376 is working
    from sympy.core.cache import clear_cache

    # built BEFORE clearing the cache; the order of these statements matters
    s = 1/sqrt(x**2)
    # This bug only appeared when the cache was turned off.
    # We need to approximate running this test without the cache.
    # This creates approximately the same situation.
    clear_cache()

    # This used to fail with a wrong result.
    # It incorrectly returned 1/sqrt(x**2) before this pull request.
    result = s.subs(sqrt(x**2), y)
    assert result == 1/y
752
753
def test_RootOf_issue_10092():
    """``x < r`` becomes ``false`` when ``x`` is replaced by the root ``r``."""
    xr = Symbol('x', real=True)
    cubic = xr**3 - 17*xr**2 + 81*xr - 118
    root = RootOf(cubic, 0)
    assert (xr < root).subs(xr, root) is S.false
759
760
def test_issue_8886():
    """Non-sympifiable objects are neither substituted in nor out."""
    from sympy.physics.mechanics import ReferenceFrame
    # if something can't be sympified we assume that it
    # doesn't play well with SymPy and disallow the
    # substitution
    vec = ReferenceFrame('A').x
    assert x.subs(x, vec) == x
    assert vec.subs(vec, x) == vec
    assert vec.__eq__(x) is False
770
771
def test_issue_12657():
    """``-oo`` is treated as an atom, independent of replacement order."""
    # treat -oo like the atom that it is
    for bound, reps in ((-oo, [(-oo, 1), (oo, 2)]),
                        (oo, [(-oo, 2), (oo, 1)])):
        assert (x < bound).subs(reps) == (x < 1)
        assert (x < bound).subs(list(reversed(reps))) == (x < 1)
780
781
def test_recurse_Application_args():
    """Replacing a function by a Lambda also rewrites nested applications."""
    F = Lambda((x, y), exp(2*x + 3*y))
    f = Function('f')
    nested = f(x, f(x, x))
    expected = F(x, F(x, x))
    assert nested.subs(f, F) == nested.replace(f, F) == expected
788
789
def test_Subs_subs():
    """Substitution into ``Subs`` objects touches only free symbols."""
    assert Subs(x*y, x, x).subs(x, y) == Subs(x*y, x, y)
    assert Subs(x*y, x, x + 1).subs(x, y) == Subs(x*y, x, y + 1)
    assert Subs(x*y, y, x + 1).subs(x, y) == Subs(y**2, y, y + 1)

    before = Subs(x*y*z, (y, x, z), (x + 1, x + z, x))
    after = Subs(x*y*z, (y, x, z), (x + 1, y + z, y))
    assert before.subs(x, y) == after
    assert before.doit().subs(x, y) == before.subs(x, y).doit()

    f, g = Function('f'), Function('g')
    inner = 2*f(x, y) + g(x)
    assert Subs(inner, f(x, y), 1).subs(y, 2) == Subs(
        inner, (f(x, y), y), (1, 2))
804
805
def test_issue_13333():
    """String values in the subs dict are sympified."""
    reciprocal = 1/x
    for text in ('1/2', '(1/2)'):
        assert reciprocal.subs({'x': text}) == 2
810
811
def test_issue_15234():
    """Even powers are substituted for real and complex symbols alike."""
    for flags in (dict(real=True), dict(complex=True)):
        x, y = symbols('x y', **flags)
        poly = 6*x**5 + x**4 - 4*x**3 + 4*x**2 - 2*x + 3
        expected = 6*x**5 - 4*x**3 - 2*x + y**4 + 4*y**2 + 3
        assert poly.subs([(x**i, y**i) for i in [2, 4]]) == expected
821
822
def test_issue_6976():
    """Power substitution with and without nonnegativity assumptions."""
    px, py = symbols('x y')
    assert (sqrt(px)**3 + sqrt(px) + px + px**2).subs(sqrt(px), py) == \
        py**4 + py**3 + py**2 + py
    assert (px**4 + px**3 + px**2 + px + sqrt(px)).subs(px**2, py) == \
        sqrt(px) + px**3 + px + py**2 + py
    assert px.subs(px**3, py) == px
    assert px.subs(px**Rational(1, 3), py) == py**3

    # More substitutions are possible with nonnegative symbols
    nx, ny = symbols('x y', nonnegative=True)
    assert (nx**4 + nx**3 + nx**2 + nx + sqrt(nx)).subs(nx**2, ny) == \
        ny**Rational(1, 4) + ny**Rational(3, 2) + sqrt(ny) + ny**2 + ny
    assert nx.subs(nx**3, ny) == ny**Rational(1, 3)
837
838
def test_issue_11746():
    """Reciprocal powers are only replaced by exact multiples of the pattern."""
    table = (
        (1/x, x**2, 1/x),
        (1/(x**3), x**2, x**(-3)),
        (1/(x**4), x**2, 1),
        (1/(x**3), x**4, x**(-3)),
        (1/(y**5), x**5, y**(-5)),
    )
    for expr, old, expected in table:
        assert expr.subs(old, 1) == expected
845
846
def test_issue_17823():
    """A dynamic symbol and its derivatives can be replaced together."""
    from sympy.physics.mechanics import dynamicsymbols
    q1, q2 = dynamicsymbols('q1, q2')
    q1d = q1.diff()
    q1dd = q1d.diff()
    expr = q1dd**2*q1 + q1d*q2.diff()
    reps = {q1: a, q1d: a*x*y, q1dd: z}
    assert expr.subs(reps) == a*x*y*Derivative(q2, t) + a*z**2
853
854
def test_issue_19326():
    """Dict subs of applied functions whose values reference each other."""
    fx, fy = [F(t) for F in map(Function, 'xy')]
    assert (fx*fy).subs({fx: 1 + fx, fy: fx}) == (1 + fx)*fx
858
def test_issue_19558():
    """Substituting ``oo`` into oscillating expressions gives AccumBounds."""
    numer = (7*x*cos(x) - 12*log(x)**3)*(-log(x)**4 + 2*sin(x) + 1)**2
    denom = 2*(x*cos(x) - 2*log(x)**3)*(3*log(x)**4 - 7*sin(x) + 3)**2
    e = numer/denom

    assert e.subs(x, oo) == AccumBounds(-oo, oo)
    assert (sin(x) + cos(x)).subs(x, oo) == AccumBounds(-2, 2)
865