1import pytest
2
3from diofant import (Add, Basic, Derivative, Dict, E, Eq, Float, Function, I,
4                     Integer, Lambda, Min, Piecewise, Rational, RootOf, Subs,
5                     Symbol, Tuple, Wild, abc, atan2, cbrt, cos, cot, cse, exp,
6                     factor, false, log, nsimplify, oo, pi, sin, sqrt, symbols,
7                     tan, zoo)
8from diofant.abc import a, b, c, d, e, t, x, y, z
9from diofant.core.basic import _aresame
10from diofant.core.cache import clear_cache
11
12
# This module only defines test functions; a star-import exposes nothing.
__all__ = tuple()
14
15
def test_subs():
    """Substituting a number for a symbol evaluates the expression."""
    three = Integer(3)

    bare = x.subs({x: three})
    assert bare == 3

    doubled = 2*x
    assert doubled == 2*x
    doubled = doubled.subs({x: three})
    assert doubled == 6
26
27
def test_trigonometric():
    """Substitution into trigonometric expressions, including their poles."""
    three = Integer(3)

    deriv = (sin(x)**2).diff(x)
    assert deriv == 2*sin(x)*cos(x)
    at_three = deriv.subs({x: three})
    assert at_three == 2*cos(three)*sin(three)

    deriv = (sin(x)**2).diff(x)
    assert deriv == 2*sin(x)*cos(x)
    swapped = deriv.subs({sin(x): cos(x)})
    assert swapped == 2*cos(x)**2

    k = Symbol('i', integer=True)
    assert tan(x).subs({x: pi/2}) is zoo
    assert cot(x).subs({x: pi}) is zoo
    assert cot(k*x).subs({x: pi}) is zoo
    half_turn = tan(k*x).subs({x: pi/2})
    assert half_turn == tan(k*pi/2)
    assert tan(k*x).subs({x: pi/2}).subs({k: 1}) is zoo
    odd = Symbol('o', odd=True)
    assert tan(odd*x).subs({x: pi/2}) == tan(odd*pi/2)
48
49
def test_powers():
    """Substitution into radicals, negative powers and power-valued keys."""
    assert sqrt(1 - sqrt(x)).subs({x: 4}) == I
    assert (sqrt(1 - x**2)**3).subs({x: 2}) == -3*I*sqrt(3)
    cube_root = cbrt(x)
    assert cube_root.subs({x: 27}) == 3
    assert cube_root.subs({x: -27}) == 3*cbrt(-1)
    assert cbrt(-x).subs({x: 27}) == 3*cbrt(-1)

    neg = Symbol('n', negative=True)
    assert (x**neg).subs({x: 0}) is zoo
    assert exp(-1).subs({E: 0}) is zoo
    assert (x**(4.0*y)).subs({x**(2.0*y): neg}) == neg**2.0
    assert (2**(x + 2)).subs({2: 3}) == 3**(x + 3)

    # issue sympy/sympy#10829
    for base, root in ((4, 2), (9, 3)):
        assert (base**x).subs({root**x: y}) == y**2

    # issue sympy/sympy#6923
    assert (-2*x*sqrt(2)).subs({2*x: y}) == -sqrt(2)*y
68
69
def test_logexppow():   # no eval()
    """Replacing 2**x (or its exp/log spelling) must change the expression."""
    xr = Symbol('x', extended_real=True)
    w = Symbol('w')
    ratio = (3**(1 + xr) + 2**(1 + xr))/(3**xr + 2**xr)
    for old in (2**xr, exp(xr*log(2))):
        assert ratio.subs({old: w}) != ratio
76
77
def test_bug():
    """A Float substituted into a product becomes its coefficient."""
    first, second = Symbol('x1'), Symbol('x2')
    product = first*second
    substituted = product.subs({first: Float(3.0)})
    assert substituted == Float(3.0)*second
83
84
def test_subbug1():
    """x**x with x -> 1 (exact or float) must not raise."""
    power = x**x
    for value in (1, 1.0):
        power.subs({x: value})
89
90
def test_subbug2():
    """abs(x) with x -> -7.7 must terminate (no infinite recursion)."""
    expected = Float(7.7)
    substituted = abs(x).subs({x: -7.7})
    assert expected.epsilon_eq(substituted)
94
95
def test_dict_set():
    """subs accepts a match() dict, a set of pairs and a diofant Dict."""
    wa, wb = map(Wild, 'ab')

    matched = (3*cos(4*x)).match(wa*cos(wb*x))
    assert matched == {wa: 3, wb: 4}

    template = wa/wb*sin(wb*x)
    assert template.subs(matched) == matched[wa]/matched[wb]*sin(matched[wb]*x)
    assert template.subs(matched) == 3*sin(4*x) / 4

    pairs = set(matched.items())
    assert template.subs(pairs) == matched[wa]/matched[wb]*sin(matched[wb]*x)
    assert template.subs(pairs) == 3*sin(4*x) / 4

    # repeat the dict-based checks after a set was built from the dict
    assert template.subs(matched) == matched[wa]/matched[wb]*sin(matched[wb]*x)
    assert template.subs(matched) == 3*sin(4*x) / 4
    assert x.subs(Dict((x, 1))) == 1
112
113
def test_dict_ambigous():   # see issue sympy/sympy#3566
    """Clashing keys in an unordered mapping are resolved canonically."""
    ysym = Symbol('y')
    zsym = Symbol('z')

    fx = x*exp(x)
    gz = zsym*exp(zsym)

    reps_x = {x: ysym, exp(x): ysym}
    reps_z = {zsym: ysym, exp(zsym): ysym}

    assert fx.subs(reps_x) == ysym**2
    assert gz.subs(reps_z) == ysym**2

    # applied one at a time, the order of the two replacements matters
    assert fx.subs({x: ysym}).subs({exp(x): ysym}) == ysym*exp(ysym)
    assert fx.subs({exp(x): ysym}).subs({x: ysym}) == ysym**2

    # Length of args and count_ops are the same, so default_sort_key
    # resolves the ordering. If this result is unwanted, pass an ordered
    # sequence of pairs instead of an unordered mapping.
    bilinear = 1 + x*ysym
    assert bilinear.subs({x: ysym, ysym: 2}) == 5
    # No keys or values clash outright here, yet the result still
    # depends on the order in which the replacements are applied.
    assert exp(x/2 + ysym).subs({exp(ysym + 1): 2, x: 2}) == exp(ysym + 1)
140
141
def test_deriv_sub_bug3():
    """Substituting a symbol absent from a Derivative leaves it unchanged."""
    absent = Symbol('y')
    f = Function('f')
    second = Derivative(f(x), x, x)
    result = second.subs({absent: absent**2})
    assert result == Derivative(f(x), x, x)
    assert result != Derivative(f(x), x)
148
149
def test_equality_subs1():
    """subs is applied to the left-hand side of an Eq."""
    f = Function('f')
    sym = abc.x
    relation = Eq(f(sym)**2, sym)
    expected = Eq(Integer(16), sym)
    assert relation.subs({f(sym): 4}) == expected
156
157
def test_equality_subs2():
    """An Eq becomes a definite boolean once both sides are numbers."""
    f = Function('f')
    sym = abc.x
    relation = Eq(f(sym)**2, 16)
    for value, truth in ((3, False), (4, True)):
        assert bool(relation.subs({f(sym): value})) is truth
164
165
def test_sympyissue_3742():
    """A sqrt factor of a product can be replaced directly."""
    ysym = Symbol('y')
    product = sqrt(x)*exp(ysym)
    assert product.subs({sqrt(x): 1}) == exp(ysym)
171
172
def test_subs_dict1():
    """Dict substitution with several keys, including squared-symbol keys."""
    line = 1 + x*y
    assert line.subs({x: pi}) == 1 + pi*y
    assert line.subs({x: pi, y: 2}) == 1 + 2*pi

    c2, c3, q1p, q2p, c1, s1, s2, s3 = symbols('c2 c3 q1p q2p c1 s1 s2 s3')
    expr = (c2**2*q2p*c3 + c1**2*s2**2*q2p*c3 + s1**2*s2**2*q2p*c3
            - c1**2*q1p*c2*s3 - s1**2*q1p*c2*s3)
    # NOTE(review): the c3**3 key matches nothing here, since c3 only
    # appears to the first power in expr; possibly meant to be c3**2.
    reps = {c1**2: 1 - s1**2, c2**2: 1 - s2**2, c3**3: 1 - s3**2}
    result = expr.subs(reps)
    assert result == (c3*q2p*(1 - s2**2) + c3*q2p*s2**2*(1 - s1**2)
                      - c2*q1p*s3*(1 - s1**2) + c3*q2p*s1**2*s2**2
                      - c2*q1p*s3*s1**2)
183
184
def test_mul():
    """Product keys in Mul: partial, scaled, power and noncommutative matches."""
    A, B, C = symbols('A B C', commutative=False)
    assert (x*y*z).subs({z*x: y}) == y**2
    assert (z*x).subs({1/x: z}) == z*x
    assert (x*y/z).subs({1/z: a}) == a*x*y
    assert (x*y/z).subs({x/z: a}) == a*y
    assert (x*y/z).subs({y/z: a}) == a*x
    assert (x*y/z).subs({x/z: 1/a}) == y/a
    assert (x*y/z).subs({x: 1/a}) == y/(z*a)
    assert (2*x*y).subs({5*x*y: z}) != 2*z/5
    assert (x*y*A).subs({x*y: a}) == a*A
    assert Subs(x*y*A, (x*y, a)).is_commutative is False
    assert (x**2*y**(3*x/2)).subs({x*y**(x/2): 2}) == 4*y**(x/2)
    assert (x*exp(x*2)).subs({x*exp(x): 2}) == 2*exp(x)
    assert ((x**(2*y))**3).subs({x**y: 2}) == 64
    assert (x*A*B).subs({x*A: y}) == y*B
    assert (x*y*(1 + x)*(1 + x*y)).subs({x*y: 2}) == 6*(1 + x)
    # NOTE(review): this only checks that the substitution completes and
    # returns a truthy object; the resulting expression is not pinned down.
    assert ((1 + A*B)*A*B).subs({A*B: x*A*B})
    assert (x*a/z).subs({x/z: A}) == a*A
    assert Subs(x*a/z, (x/z, A)).is_commutative is False
    assert (x**3*A).subs({x**2*A: a}) == a*x
    assert (x**2*A*B).subs({x**2*B: a}) == a*A
    assert (x**2*A*B).subs({x**2*A: a}) == a*B
    assert (b*A**3/(a**3*c**3)).subs({a**4*c**3*A**3/b**4: z}) == \
        b*A**3/(a**3*c**3)
    assert (6*x).subs({2*x: y}) == 3*y
    # (a byte-identical copy of the next assertion was removed)
    assert (y*exp(3*x/2)).subs({y*exp(x): 2}) == 2*exp(x/2)
    assert (A**2*B*A**2*B*A**2).subs({A*B*A: C}) == A*C**2*A
    assert (x*A**3).subs({x*A: y}) == y*A**2
    assert (x**2*A**3).subs({x*A: y}) == y**2*A
    assert (x*A**3).subs({x*A: B}) == B*A**2
    assert (x*A*B*A*exp(x*A*B)).subs({x*A: B}) == B**2*A*exp(B*B)
    assert (x**2*A*B*A*exp(x*A*B)).subs({x*A: B}) == B**3*exp(B**2)
    assert (x**3*A*exp(x*A*B)*A*exp(x*A*B)).subs({x*A: B}) == \
        x*B*exp(B**2)*B*exp(B**2)
    assert (x*A*B*C*A*B).subs({x*A*B: C}) == C**2*A*B
    assert (-I*a*b).subs({a*b: 2}) == -2*I

    # issue sympy/sympy#6361
    assert (-8*I*a).subs({-2*a: 1}) == 4*I
    assert (-I*a).subs({-a: 1}) == I

    # issue sympy/sympy#6441
    assert (4*x**2).subs({2*x: y}) == y**2
    assert (2*4*x**2).subs({2*x: y}) == 2*y**2
    assert (-x**3/9).subs({-x/3: z}) == -z**2*x
    assert (-x**3/9).subs({x/3: z}) == -z**2*x
    assert (-2*x**3/9).subs({x/3: z}) == -2*x*z**2
    assert (-2*x**3/9).subs({-x/3: z}) == -2*x*z**2
    assert (-2*x**3/9).subs({-2*x: z}) == z*x**2/9
    assert (-2*x**3/9).subs({2*x: z}) == -z*x**2/9
    assert (2*(3*x/5/7)**2).subs({3*x/5: z}) == 2*Rational(1, 7)**2*z**2
    assert (4*x).subs({-2*x: z}) == 4*x  # try keep subs literal

    assert (A**2*B**2).subs({A*B**3: C}) == A**2*B**2
    assert (A**Rational(5, 3)*B**3).subs({sqrt(A)*B: C}) == A**Rational(5, 3)*B**3
    assert (A**2*B**2*A).subs({A**2*B*A: C}) == A**2*B**2*A
243
244
def test_subs_simple():
    """Numbers and symbols substitute in commutative and noncommutative products."""
    comm = symbols('a', commutative=True)
    noncomm = symbols('x', commutative=False)

    for sym in (comm, noncomm):
        doubled = 2*sym
        assert doubled.subs({1: 3}) == doubled
        assert doubled.subs({2: 3}) == 3*sym
        assert doubled.subs({sym: 3}) == 6

    assert sin(2).subs({1: 3}) == sin(2)
    assert sin(2).subs({2: 3}) == sin(3)
    for sym in (comm, noncomm):
        assert sin(sym).subs({sym: 3}) == sin(3)
260
261
def test_subs_constants():
    """A numeric coefficient in a product key must match exactly."""
    ca, cb = symbols('a b', commutative=True)
    nx, ny = symbols('x y', commutative=False)

    for left, right in ((ca, cb), (nx, ny)):
        assert (left*right).subs({2*left: 1}) == left*right
        assert (1.5*left*right).subs({left: 1}) == 1.5*right
        assert (2*left*right).subs({2*left: 1}) == right
        assert (2*left*right).subs({4*left: 1}) == 2*left*right
275
276
def test_subs_commutative():
    """Commutative product keys match regardless of factor order or repetition."""
    p, q, r, s, K = symbols('a b c d K', commutative=True)
    pq = p*q

    assert pq.subs({pq: K}) == K
    assert (p*q*p*q).subs({pq: K}) == K**2
    assert (p*p*q*q).subs({pq: K}) == K**2
    assert (p*q*r*s).subs({p*q*r: K}) == s*K
    power = p*q**r
    assert power.subs({p: K}) == K*q**r
    assert power.subs({q: K}) == p*K**r
    assert power.subs({r: K}) == p*q**K
    assert (p*q*r*q*p).subs({pq: K}) == r*K**2
    assert (p**3*q**2*p).subs({pq: K}) == p**2*K**2
289
290
def test_subs_noncommutative():
    """Noncommutative product keys only match contiguous, ordered factors."""
    nw, nx, ny, nz, L = symbols('w x y z L', commutative=False)
    xy = nx*ny

    assert xy.subs({xy: L}) == L
    assert (nw*ny*nx).subs({xy: L}) == nw*ny*nx
    assert (nw*nx*ny*nz).subs({xy: L}) == nw*L*nz
    assert (nx*ny*nx*ny).subs({xy: L}) == L**2
    assert (nx*nx*ny).subs({xy: L}) == nx*L
    assert (nx*nx*ny*ny).subs({xy: L}) == nx*L*ny
    assert (nw*nx*ny).subs({nx*ny*nz: L}) == nw*nx*ny
    assert (nx*ny**nz).subs({nx: L}) == L*ny**nz
    assert (nx*ny**nz).subs({ny: L}) == nx*L**nz
    assert (nx*ny**nz).subs({nz: L}) == nx*ny**L
    assert (nw*nx*ny*nz*nx*ny).subs({nx*ny*nz: L}) == nw*L*nx*ny
    long_word = nw*nx*ny*ny*nw*nx*nx*ny*nx*ny*ny*nx*ny
    assert long_word.subs({xy: L}) == nw*L*ny*nw*nx*L**2*ny*L

    # issue sympy/sympy#5284
    P, Q = symbols('A B', commutative=False)
    cx = symbols('x')
    assert (cx*P).subs({cx**2*P: Q}) == cx*P
    assert (P**2).subs({P**3: Q}) == P**2
    assert (P**6).subs({P**3: Q}) == Q**2
313
314
def test_subs_basic_funcs():
    """Substitution through Add, Pow, division and exp with mixed commutativity."""
    p, q, r, s, K = symbols('a b c d K', commutative=True)
    nw, nx, ny, nz, L = symbols('w x y z L', commutative=False)

    assert (nx + ny).subs({nx + ny: L}) == L
    assert (nx - ny).subs({nx - ny: L}) == L
    assert (nx/ny).subs({nx: L}) == L/ny
    assert (nx**ny).subs({nx: L}) == L**ny
    assert (nx**ny).subs({ny: L}) == nx**L
    assert ((p - r)/q).subs({q: K}) == (p - r)/K
    assert exp(nx*ny - nz).subs({nx*ny: L}) == exp(L - nz)
    two_exps = p*exp(nx*ny - nw*nz) + q*exp(nx*ny + nw*nz)
    assert two_exps.subs({nz: 0}) == p*exp(nx*ny) + q*exp(nx*ny)
    assert ((p - q)/(r*s - p*q)).subs({r*s - p*q: K}) == (p - q)/K
    assert (nw*exp(p*q - r)*nx*ny/4).subs({nx*ny: L}) == nw*exp(p*q - r)*L/4
330
331
def test_subs_wild():
    """Wild symbols behave like ordinary symbols as subs keys and values."""
    wr, ws, wt, wu = symbols('R S T U', cls=Wild)

    prod = wr*ws
    assert prod.subs({prod: wt}) == wt
    assert (ws*wr).subs({prod: wt}) == wt
    assert (wr + ws).subs({wr + ws: wt}) == wt
    assert (wr**ws).subs({wr: wt}) == wt**ws
    assert (wr**ws).subs({ws: wt}) == wr**wt
    tower = wr*ws**wt
    assert tower.subs({wr: wu}) == wu*ws**wt
    assert tower.subs({ws: wu}) == wr*wu**wt
    assert tower.subs({wt: wu}) == wr*ws**wu
343
344
def test_subs_mixed():
    """Commutative, noncommutative and Wild symbols mixed in one expression."""
    p, q, r, K = symbols('a b c K', commutative=True)
    nx, ny, nz, L = symbols('x y z L', commutative=False)
    wr, ws, wt, wu = symbols('R S T U', cls=Wild)
    collapse = {nx*ny: L}

    assert (p*nx*ny).subs(collapse) == p*L
    assert (p*q*nx*ny*nx).subs(collapse) == p*q*L*nx
    assert (wr*nx*ny*exp(nx*ny)).subs(collapse) == wr*L*exp(L)
    assert (p*nx*ny*ny*nx - nx*ny*nz*exp(p*q)).subs(collapse) == \
        p*L*ny*nx - L*nz*exp(p*q)

    expr = r*ny*nx*ny*nx**(wr*ws - p*q) - wt*(p*wr*q*ws)
    step = expr.subs(collapse)
    step = step.subs({p*q: K})
    step = step.subs({wr*ws: wu})
    assert step == r*ny*L*nx**(wu - K) - wt*(wu*K)
357
358
def test_division():
    """Reciprocals and their negations substitute cleanly."""
    ca, cc = symbols('a c', commutative=True)
    cx, cz = symbols('x z', commutative=True)

    for var, repl in ((ca, cc), (cx, cz)):
        assert (1/var).subs({var: repl}) == 1/repl
        assert (1/var**2).subs({var: repl}) == 1/repl**2
        assert (1/var**2).subs({var: -2}) == Rational(1, 4)
        assert (-(1/var**2)).subs({var: -2}) == -Rational(1, 4)

    # issue sympy/sympy#5360
    assert (1/cx).subs({cx: 0}) == 1/Integer(0)
375
376
def test_add():
    """Partial-sum keys in Add, including negated and shifted variants."""
    quad = a**2 - b - c
    assert quad.subs({a**2 - b: d}) in [d - c, a**2 - b - c]
    assert (a**2 - c).subs({a**2 - c: d}) == d
    assert quad.subs({a**2 - c: d}) in [d - b, a**2 - b - c]
    assert (a**2 - x - c).subs({a**2 - c: d}) in [d - x, a**2 - x - c]
    assert (a**2 - b - sqrt(a)).subs({a**2 - sqrt(a): c}) == c - b
    assert (a + b + exp(a + b)).subs({a + b: c}) == c + exp(c)
    assert (c + b + exp(c + b)).subs({c + b: a}) == a + exp(a)
    total = a + b + c + d
    assert total.subs({b + c: x}) == a + d + x
    assert total.subs({-b - c: x}) == a + d - x

    shift = {x + 1: t}
    assert ((x + 1)*y).subs(shift) == t*y
    assert ((-x - 1)*y).subs(shift) == -t*y
    assert ((x - 1)*y).subs(shift) == y*(t - 2)
    assert ((-x + 1)*y).subs(shift) == y*(-t + 2)

    # this should work every time:
    first_two = Add(*quad.args[:2])
    assert quad.subs({first_two: d}) == d + quad.args[2]
    assert quad.subs({a**2 - c: d}) == d - b

    # The fallback must notice that a change was made: although
    # 0.1 == Rational(1, 10), the two are not the same object, so the
    # replacement has to happen.
    assert (0.1 + a).subs({0.1: Rational(1, 10)}) == Rational(1, 10) + a

    poly = (-x*(-y + 1) - y*(y - 1))
    expanded = (-x*x - y*(-x)).expand()
    assert poly.subs({-y + 1: x}) == expanded
405
406
def test_subs_sympyissue_4009():
    """Replacing the number 1 must leave I*a untouched."""
    sym = Symbol('a')
    assert (I*sym).subs({1: 2}) == I*sym
409
410
def test_functions_subs():
    """Function classes can be swapped for Lambdas, builtins, and vice versa."""
    f, g = symbols('f g', cls=Function)
    lam = Lambda((x, y), sin(x) + y)

    assert (g(y, x) + cos(x)).subs({g: lam}) == sin(y) + x + cos(x)
    assert (f(x)**2).subs({f: sin}) == sin(x)**2
    assert f(x, y).subs({f: log}) == log(x, y)
    # left unchanged: presumably because sin accepts only one argument
    assert f(x, y).subs({f: sin}) == f(x, y)
    assert (sin(x) + atan2(x, y)).subs({atan2: f, sin: g}) == f(x, y) + g(x)
    nested = g(f(x + y, x))
    assert nested.subs({f: lam, g: Lambda(x, exp(x))}) == exp(x + sin(x + y))
421
422
def test_derivative_subs():
    """Replacing f(x) inside Derivative round-trips; cse keeps derivatives."""
    ysym = Symbol('y')
    f = Function('f')
    dfx = Derivative(f(x), x)

    replaced = dfx.subs({f(x): ysym})
    assert replaced != 0
    assert replaced.subs({ysym: f(x)}) == dfx

    # issues sympy/sympy#5085, sympy/sympy#5037
    reduced = cse(dfx + f(x))[1]
    assert reduced[0].has(Derivative)
    reduced = cse(Derivative(f(x, ysym), x) + Derivative(f(x, ysym), ysym))[1]
    assert reduced[0].has(Derivative)
433
434
def test_derivative_subs2():
    """Derivative keys match by their variables, in any order, also partially."""
    f_func, g_func = symbols('f g', cls=Function)
    f, g = f_func(x, y, z), g_func(x, y, z)

    def d_f(*wrt):
        # shorthand for an unevaluated derivative of f
        return Derivative(f, *wrt)

    assert d_f(x, y).subs({d_f(x, y): g}) == g
    assert d_f(y, x).subs({d_f(x, y): g}) == g
    assert d_f(x, y).subs({d_f(x): g}) == Derivative(g, y)
    assert d_f(x, y).subs({d_f(y): g}) == Derivative(g, x)
    assert d_f(x, y, z).subs({d_f(x, z): g}) == Derivative(g, y)
    assert d_f(x, y, z).subs({d_f(z, y): g}) == Derivative(g, x)
    assert d_f(x, y, z).subs({d_f(z, y, x): g}) == g
    second = Derivative(sin(x), (x, 2))
    assert second.subs({Derivative(sin(x), f_func(x)): g_func}) == second

    # issue sympy/sympy#9135
    assert d_f(x, x, y).subs({d_f(y, y): g}) == d_f(x, x, y)
    assert d_f(x, y, y, z).subs({d_f(x, y, y, y): g}) == d_f(x, y, y, z)

    assert d_f(x, y).subs({Derivative(f_func(x), x, y): g}) == d_f(x, y)
453
454
def test_derivative_subs3():
    """A Derivative used both as a key and as a replacement value."""
    sym = Symbol('x')
    inner = Derivative(exp(sym), sym)
    outer = Derivative(inner, sym)
    assert outer.subs({inner: exp(sym)}) == inner
    assert inner.subs({exp(sym): inner}) == Derivative(exp(sym), sym, sym)
460
461
def test_subs_iter():
    """Any iterable of (old, new) pairs is accepted, not just dicts."""
    for pairs in (reversed([[x, y]]), iter([[x, y]]), Tuple((x, y))):
        assert x.subs(pairs) == y
467
468
def test_subs_dict():
    """Dicts vs. ordered sequences; sequences are applied left to right."""
    assert (2*x + y + z).subs({x: 1, y: 2}) == 4 + z

    seq = [(sin(x), 2), (x, 1)]
    assert sin(x).subs(seq) == sin(x).subs(dict(seq)) == 2
    assert sin(x).subs(reversed(seq)) == sin(1)

    expr = sin(2*x) + sqrt(sin(2*x))*cos(2*x)*sin(exp(x)*x)
    reps = {sin(2*x): c, sqrt(sin(2*x)): a, cos(2*x): b,
            exp(x): e, x: d}
    assert expr.subs(reps) == c + a*b*sin(d*e)

    seq = [(x, 3), (y, x**2)]
    assert (x + y).subs(seq) == 3 + x**2
    assert (x + y).subs(reversed(seq)) == 12

    # If sequences are ever converted into dicts for a lookup-based
    # replacement, these checks should catch logical errors that such a
    # change might introduce.
    seq = [(y, z + 2), (1 + z, 5), (z, 2)]
    assert (y - 1 + 3*x).subs(seq) == 5 + 3*x
    seq = [(y, z + 2), (z, 3)]
    assert (y - 2).subs(seq) == 3
493
494
def test_no_arith_subs_on_floats():
    """Integer constants may be split to match a key; Float constants may not."""
    for base in (x, x + y):
        extra = base - x  # 0 for base == x, y for base == x + y
        assert (base + 3).subs({x + 3: a}) == a + extra
        assert (base + 3).subs({x + 2: a}) == a + extra + 1

    for base in (x, x + y):
        extra = base - x
        assert (base + 3.0).subs({x + 3.0: a}) == a + extra
        assert (base + 3.0).subs({x + 2.0: a}) == base + 3.0
507
508
def test_sympyissue_5651():
    """Product keys also match inside denominators and reciprocals."""
    p, q, r, K = symbols('a b c K', commutative=True)
    assert (p/(q*r)).subs({q*r: K}) == p/K
    assert (p/(q**2*r**3)).subs({q*r: K}) == p/(r*K**2)

    xy = x*y
    assert (1/xy).subs({xy: 2}) == Rational(1, 2)
    assert ((1 + xy)/xy).subs({xy: 1}) == 2
    assert (xy*z).subs({xy: 2}) == 2*z
    assert ((1 + xy)/xy/z).subs({xy: 1}) == 2/z
517
518
def test_sympyissue_6075():
    """Tuple substitution leaves the boolean element alone."""
    pair = Tuple(1, True)
    assert pair.subs({1: 2}) == Tuple(2, True)
521
522
def test_sympyissue_6079():
    """_aresame distinguishes Integer/Float and differing assumptions."""
    shifted = x + 2.0
    # since x + 2.0 == x + 2, a plain equality test cannot be used here
    assert _aresame(shifted.subs({2: 3}), shifted)
    assert _aresame(shifted.subs({2.0: 3}), x + 3)
    assert not _aresame(x + 2, shifted)
    assert not _aresame(Basic(cos, 1), Basic(cos, 1.))
    assert _aresame(cos, cos)
    assert not _aresame(1, Integer(1))
    assert not _aresame(x, symbols('x', positive=True))
532
533
def test_sympyissue_4680():
    """A symbol named N can be substituted like any other."""
    big_n = Symbol('N')
    assert big_n.subs({big_n: 3}) == 3
537
538
def test_sympyissue_6158():
    """A subtracted constant matches both itself and its negation."""
    for shift in (1, oo):
        assert (x - shift).subs({shift: y}) == x - y
        assert (x - shift).subs({-shift: y}) == x + y
544
545
def test_Function_subs():
    """Function classes are replaced inside Piecewise and sums."""
    f, g, h, i = symbols('f g h i', cls=Function)
    pw = Piecewise((g(f(x, y)), x < -1), (g(x), x <= 1))
    swapped = pw.subs({g: h})
    assert swapped == Piecewise((h(f(x, y)), x < -1), (h(x), x <= 1))
    assert (f(y) + g(x)).subs({f: h, g: i}) == i(x) + h(y)
551
552
def test_simultaneous_subs():
    """Compare sequential and simultaneous=True substitution of several zeros."""
    zeros = {x: 0, y: 0}
    for reps in (zeros, zeros.items()):
        assert (x/y).subs(reps) != (y/x).subs(reps)
        assert ((x/y).subs(reps, simultaneous=True) ==
                (y/x).subs(reps, simultaneous=True))
    assert (Derivative(x, y, z).subs(zeros.items(), simultaneous=True) ==
            Subs(Derivative(0, y, z), (y, 0)))
564
565
def test_sympyissue_6419_6421():
    """Keys match through a leading negative sign and inside a reciprocal."""
    assert (1/(1 + x/y)).subs({x/y: x}) == 1/(1 + x)
    two_i = 2*I
    assert (-two_i).subs({two_i: x}) == -x
    i_x = I*x
    assert (-i_x).subs({i_x: x}) == -x
    assert (-3*I*y**4).subs({3*I*y**2: x}) == -x*y**2
571
572
def test_sympyissue_6559():
    """A negated-symbol key and a cse call that once broke Mul._eval_subs."""
    assert (-12*x + y).subs({-x: 1}) == 12 + y
    # though this involves cse, it generated a failure in Mul._eval_subs
    x0 = Symbol('x0')
    expr = -log(-12*sqrt(2) + 17)/24 - log(-2*sqrt(2) + 3)/12 + sqrt(2)/3
    # XXX modify cse so x1 is eliminated and x0 = -sqrt(2)?
    replacements, reduced = cse(expr)
    assert replacements == [(x0, sqrt(2))]
    assert reduced == [x0/3 - log(-12*x0 + 17)/24 - log(-2*x0 + 3)/12]
581
582
def test_sympyissue_5261():
    """Imaginary exponents: exp/2** keys rewrite, a (-2)** key does not."""
    xr = symbols('x', extended_real=True)
    ix = I*xr
    assert exp(ix).subs({exp(xr): y}) == y**I
    assert (2**ix).subs({2**xr: y}) == y**I
    neg_base = (-2)**ix
    assert neg_base.subs({(-2)**xr: y}) == neg_base
590
591
@pytest.mark.xfail
def test_mul2():
    """Marked xfail: 2*(x + 1) is currently expected not to be a Mul.

    Once this starts passing, remove the code labelled "2-arg hack":

    1) the special handling in the subs fallback that was added in the
       same commit as this test;
    2) the special handling in Mul.flatten.
    """
    doubled_sum = 2*(x + 1)
    assert doubled_sum.is_Mul
600
601
def test_noncommutative_subs():
    """Simultaneous swap of noncommutative symbols inside a product."""
    p, q = symbols('x, y', commutative=False)
    swapped = (p*q*p).subs({p: p*q, q: p}, simultaneous=True)
    assert swapped == (p*q*p**2*q)
605
606
def test_sympyissue_2877():
    """A Float key is replaced exactly, and nsimplify recovers the rationals."""
    two = Float(2.0)
    assert (x + two).subs({two: 2}) == x + 2

    def quadratic(lead, mid, const):
        return factor(lead*x**2 + mid*x + const)

    factored = quadratic(5/6, 10, 5)
    assert nsimplify(factored) == 5*x**2/6 + 10*x + 5
615
616
def test_sympyissue_5910():
    """Hitting a pole through substitution yields zoo."""
    ts = Symbol('t')
    assert (1/(1 - ts)).subs({ts: 1}) == zoo
    num, den = ts, ts - 1
    assert (num/den).subs({ts: 1}) == zoo
    assert (-num/-den).subs({ts: 1}) == zoo
624
625
def test_sympyissue_5217():
    """A 2*x**2 key matches inside sums and inside 4*x**2*y**2."""
    s = Symbol('s')
    minus = 1 - 2*x*x
    plus = 1 + 2*x*x
    quartic = 2*x*x*2*y*y
    sub = {2*x*x: s}
    assert plus.subs(sub) == 1 + s
    assert minus.subs(sub) == 1 - s
    assert quartic == 4*x**2*y**2
    assert quartic.subs(sub) == 2*y**2*s
636
637
def test_pow_eval_subs_no_cache():
    """Pow substitution must work for an expression built before clear_cache()."""
    # The expression is created *before* the cache is wiped: the bug only
    # appeared when the cache was turned off.
    inverse_root = 1/sqrt(x**2)
    clear_cache()

    # This used to return 1/sqrt(x**2) unchanged instead of 1/y.
    substituted = inverse_root.subs({sqrt(x**2): y})
    assert substituted == 1/y
647
648
def test_diofantissue_124():
    """exp(n*x) with integer n is rewritten through an exp(x) key."""
    k = Symbol('n', integer=True)
    scaled = exp(k*x)
    assert scaled.subs({exp(x): x}) == x**k
652
653
def test_sympyissue_11159():
    """Replacing E by an expression that itself contains E."""
    squared = E*E
    assert squared.subs({E: squared}) == E**4
658
659
def test_RootOf_sympyissue_10092():
    """x < r with x -> r evaluates to false when r is a RootOf."""
    xr = Symbol('x', real=True)
    cubic = xr**3 - 17*xr**2 + 81*xr - 118
    root = RootOf(cubic, 0)
    assert (xr < root).subs({xr: root}) is false
665
666
def test_sympyissue_11746():
    """An x**2 key must not turn 1/x into 1."""
    xr = Symbol('x', real=True)
    reciprocal = 1/xr
    assert reciprocal.subs({xr**2: 1}) != 1
670
671
def test_diff_subs():
    """Subs keys match up to the bound-variable name, and nothing more."""
    # issue diofant/diofant#376
    f = symbols('f', cls=Function)
    via_x = Subs(Derivative(f(x), x), (x, y*z))
    via_t = Subs(Derivative(f(t), t), (t, y*z))
    assert (x*via_x).subs({via_t: y}) == x*y
    # a different evaluation point, or a higher derivative, must not match
    for unmatched in (Subs(Derivative(f(t), t), (t, y*z**2)),
                      Subs(Derivative(f(t), t, t), (t, y*z))):
        assert unmatched.subs({via_t: y}) == unmatched
682
683
def test_aresame_ordering():
    """A Min key matches regardless of the order of its arguments."""
    expr = Min(3, z)
    assert expr.subs({Min(z, 3): 3}) == 3
688
689
def test_underscores():
    """Underscore-prefixed symbols get an extra leading underscore inside Subs."""
    u0, u1 = symbols('_0 _1')
    wrapped = Subs(u0 + u1, (u0, 1), (u1, 0))
    assert wrapped._expr == Symbol('__0') + Symbol('__1')
694