1import pytest
2
3from diofant import (I, Integer, Mul, O, Pow, Rational, Symbol, cbrt, cos, exp,
4                     expand, expand_multinomial, expand_power_base, log, pi,
5                     sin, sqrt)
6from diofant.abc import x, y, z
7from diofant.simplify.radsimp import expand_numer
8from diofant.utilities.randtest import verify_numerically
9
10
11__all__ = ()
12
13
def test_expand_no_log():
    """Check that expand(log=False) leaves log(x**n) terms as they are."""
    log4 = log(x**4)
    log3 = log(x**3)

    squared = ((1 + log4)**2).expand(log=False)
    assert squared == 1 + 2*log4 + log4**2

    product = ((1 + log4)*(1 + log3)).expand(log=False)
    assert product == 1 + log4 + log3 + log4*log3
19
20
def test_expand_no_multinomial():
    """Check that multinomial=False keeps (1 + x)**4 unexpanded."""
    quartic = (1 + x)**4
    result = ((1 + x)*(1 + quartic)).expand(multinomial=False)
    assert result == 1 + x + quartic + x*quartic
24
25
def test_expand_negative_integer_powers():
    """Check expand() on negative integer powers of x + y."""
    inverse_square = (x + y)**(-2)
    assert inverse_square.expand() == 1 / (2*x*y + x**2 + y**2)
    assert inverse_square.expand(multinomial=False) == (x + y)**(-2)

    inverse_cube = (x + y)**(-3)
    assert inverse_cube.expand() == 1 / (3*x*x*y + 3*x*y*y + x**3 + y**3)
    assert inverse_cube.expand(multinomial=False) == (x + y)**(-3)

    # product of a positive and a negative power of the same base
    mixed = (x + y)**2 * (x + y)**(-4)
    assert mixed.expand() == 1 / (2*x*y + x**2 + y**2)
    assert mixed.expand(multinomial=False) == (x + y)**(-2)
36
37
def test_expand_non_commutative():
    """Check expand() on expressions involving non-commutative symbols."""
    A, B, C = (Symbol(name, commutative=False) for name in 'ABC')
    a, b = Symbol('a'), Symbol('b')
    i = Symbol('i', integer=True)
    n, m = (Symbol(name, negative=True, finite=True) for name in 'nm')
    p = Symbol('p', polar=True)
    # shares the name 'p' with the polar symbol but is declared non-polar
    nonpolar = Symbol('p', polar=False)

    # distribution must keep each factor on its original side
    product = C*(A + B)
    assert product.expand() == C*A + C*B
    assert product.expand() != A*C + B*C

    square = (A + B)**2
    assert square.expand() == A**2 + A*B + B*A + B**2
    cube = (A + B)**3
    assert cube.expand() == (A**2*B + B**2*A + A*B**2 + B*A**2 +
                             A**3 + B**3 + A*B*A + B*A*B)

    # issue sympy/sympy#6219
    conjugate = a*A*B*A**-1
    assert (conjugate**2).expand() == a**2*A*B**2/A
    # Note that (a*A*B*A**-1)**2 is automatically converted to a**2*(A*B*A**-1)**2
    assert (conjugate**2).expand(deep=False) == a**2*(A*B*A**-1)**2
    assert (conjugate**2).expand() == a**2*(A*B**2*A**-1)
    assert (conjugate**2).expand(force=True) == a**2*A*B**2*A**(-1)
    assert ((a*A*B)**2).expand() == a**2*A*B*A*B
    assert ((a*A)**2).expand() == a**2*A**2
    assert ((a*A*B)**i).expand() == a**i*(A*B)**i
    nested = a*A*(B*(A*B/A)**2)
    assert (nested**i).expand() == a**i*(A*B*A*B**2/A)**i

    # issue sympy/sympy#6558
    assert (A*B*(A*B)**-1).expand() == A*B*(A*B)**-1
    assert ((a*A)**i).expand() == a**i*A**i
    assert (conjugate**3).expand() == a**3*A*B**3/A
    longer = a*A*B*A*B/A
    assert (longer**3).expand() == \
        a**3*A*B*(A*B**2)*(A*B**2)*A*B*A**(-1)
    assert (longer**-3).expand() == \
        a**-3*(A*B*(A*B**2)*(A*B**2)*A*B*A**(-1))**-1
    assert ((a*b*A*B*A**-1)**i).expand() == a**i*b**i*(A*B/A)**i
    assert ((a*(a*b)**i)**i).expand() == a**i*a**(i**2)*b**(i**2)

    # power whose base is an unevaluated product containing a*(1/a)
    unevaluated_base = Mul(a, 1/a, A, B, evaluate=False)
    unevaluated = Pow(unevaluated_base, Integer(2), evaluate=False)
    assert unevaluated.expand() == A*B*A*B

    # roots and symbolic exponents
    assert sqrt(a*(A*b)**i).expand() == sqrt(a*b**i*A**i)
    assert (sqrt(-a)**a).expand() == sqrt(-a)**a
    assert expand((-2*n)**(i/3)) == 2**(i/3)*(-n)**(i/3)
    assert expand((-2*n*m)**(i/a)) == (-2)**(i/a)*(-n)**(i/a)*(-m)**(i/a)
    assert expand((-2*a*p)**b) == 2**b*p**b*(-a)**b
    assert expand((-2*a*nonpolar)**b) == 2**b*(-a*nonpolar)**b
    assert expand(sqrt(A*B)) == sqrt(A*B)
    assert expand(sqrt(-2*a*b)) == sqrt(2)*sqrt(-a*b)
85
86
def test_expand_radicals():
    """Check expand() on integer powers of sqrt(x + y) and cbrt(x + y)."""
    square_root = sqrt(x + y)

    assert (square_root**1).expand() == square_root
    assert (square_root**3).expand() == x*square_root + y*square_root
    assert (square_root**5).expand() == \
        x**2*square_root + 2*x*y*square_root + y**2*square_root

    # the same powers in the denominator
    assert (1/square_root**1).expand() == 1/square_root
    assert (1/square_root**3).expand() == 1/(x*square_root + y*square_root)
    assert (1/square_root**5).expand() == \
        1/(x**2*square_root + 2*x*y*square_root + y**2*square_root)

    cube_root = cbrt(x + y)

    assert (cube_root**1).expand() == cube_root
    assert (cube_root**2).expand() == cube_root**2
    assert (cube_root**4).expand() == x*cube_root + y*cube_root
    assert (cube_root**5).expand() == x*cube_root**2 + y*cube_root**2
    assert (cube_root**7).expand() == \
        x**2*cube_root + 2*x*y*cube_root + y**2*cube_root
105
106
def test_expand_modulus():
    """Check the modulus hint of expand(), including invalid moduli."""
    eleventh_power = (x + y)**11
    assert eleventh_power.expand(modulus=11) == x**11 + y**11

    with_radical = (x + sqrt(2)*y)**11
    assert with_radical.expand(modulus=11) == x**11 + 10*sqrt(2)*y**11

    assert (x + y/2).expand(modulus=1) == y/2

    # a zero modulus and a symbolic modulus must both raise ValueError
    for bad_modulus in (0, x):
        with pytest.raises(ValueError):
            ((x + y)**11).expand(modulus=bad_modulus)
114
115
def test_sympyissue_5743():
    """Regression test: expand() on products containing sqrt(x + y) twice."""
    root = sqrt(x + y)

    first = (x*root*(1 + root)).expand()
    assert first == x**2 + x*y + x*root

    second = (x*root*(1 + x*root)).expand()
    assert second == x**3 + x**2*y + x*root
121
122
def test_expand_frac():
    """Check the frac, numer and denom hints of expand(), and expand_numer()."""
    ratio = (x + y)*y/x/(x + 1)

    assert expand(ratio, frac=True) == (x*y + y**2)/(x**2 + x)
    assert expand(ratio, numer=True) == (x*y + y**2)/(x*(x + 1))
    assert expand(ratio, denom=True) == y*(x + y)/(x**2 + x)

    # with multinomial=False the squared numerator is returned unchanged
    squared_over_y = (x + 1)**2/y
    assert expand_numer(squared_over_y, multinomial=False) == squared_over_y
132
133
def test_sympyissue_6121():
    """Regression test: expand(complex=True) returns a truthy result."""
    denominator = 4*pi**Rational(3, 2)*sqrt(x)
    eq = -I*exp(-3*I*pi/4)/denominator
    expanded = eq.expand(complex=True)
    assert expanded  # does not give oo recursion
137
138
def test_expand_power_base():
    """Check expand_power_base() with and without the force/deep hints."""
    triple = x*y*z
    assert expand_power_base(triple**4) == x**4*y**4*z**4
    assert expand_power_base(triple**x).is_Pow
    assert expand_power_base(triple**x, force=True) == x**x*y**x*z**x
    assert expand_power_base((x*(y*z)**2)**3) == x**3*y**6*z**6

    # base containing a function of a product
    trig_power = (sin((x*y)**2)*y)**z
    assert expand_power_base(trig_power).is_Pow
    forced = expand_power_base(trig_power, force=True)
    assert forced == sin((x*y)**2)**z*y**z
    deepened = expand_power_base(trig_power, deep=True)
    assert deepened == (sin(x**2*y**2)*y)**z

    # integer powers of exponentials
    assert expand_power_base(exp(x)**2) == exp(2*x)
    assert expand_power_base((exp(x)*exp(y))**2) == exp(2*x)*exp(2*y)

    nested_exp = (exp((x*y)**z)*exp(y))**2
    assert expand_power_base(nested_exp) == exp(2*(x*y)**z)*exp(2*y)
    assert expand_power_base(nested_exp, deep=True, force=True) == \
        exp(2*x**z*y**z)*exp(2*y)

    # symbolic power of a product of exponentials
    exp_power = (exp(x)*exp(y))**z
    assert expand_power_base(exp_power).is_Pow
    assert expand_power_base(exp_power, force=True) == exp(x)**z*exp(y)**z
162
163
def test_expand_arit():
    """Check expand() on arithmetic expressions.

    Covers products and powers of sums, a truncated series, the
    power_base/power_exp/deep/force hints, roots of negative arguments,
    a trigonometric expansion and a larger polynomial product.
    """
    a = Symbol('a')
    b = Symbol('b', positive=True)
    c = Symbol('c')

    # kept separate from the positive symbol ``p`` defined further down
    five = Integer(5)
    e = (a + b)*c
    assert e == c*(a + b)
    assert (e.expand() - a*c - b*c) == 0
    e = (a + b)*(a + b)
    assert e == (a + b)**2
    assert e.expand() == 2*a*b + a**2 + b**2
    e = (a + b)*(a + b)**2
    assert e == (a + b)**3
    assert e.expand() == 3*b*a**2 + 3*a*b**2 + a**3 + b**3
    e = (a + b)*(a + c)*(b + c)
    assert e == (a + c)*(a + b)*(b + c)
    assert e.expand() == 2*a*b*c + b*a**2 + c*a**2 + b*c**2 + a*c**2 + c*b**2 + a*b**2
    e = (a + 1)**five
    assert e == (1 + a)**5
    assert e.expand() == 1 + 5*a + 10*a**2 + 10*a**3 + 5*a**4 + a**5
    e = (a + b + c)*(a + c + five)
    assert e == (5 + a + c)*(a + b + c)
    assert e.expand() == 5*a + 5*b + 5*c + 2*a*c + b*c + a*b + a**2 + c**2

    # local ``x`` is used for the rest of the function
    x = Symbol('x')
    s = exp(x*x) - 1
    e = s.series(x)/x**2
    assert e.expand() == 1 + x**2/2 + O(x**4)

    e = (x*(y + z))**(x*(y + z))*(x + y)
    assert e.expand(power_exp=False, power_base=False) == \
        x*(x*y + x*z)**(x*y + x*z) + y*(x*y + x*z)**(x*y + x*z)
    assert e.expand(power_exp=False, power_base=False, deep=False) == \
        x*(x*(y + z))**(x*(y + z)) + y*(x*(y + z))**(x*(y + z))
    e = (x*(y + z))**z
    # either form is accepted
    assert e.expand(power_base=True, mul=True, deep=True) in [
        x**z*(y + z)**z,
        (x*y + x*z)**z,
    ]
    assert ((2*y)**z).expand() == 2**z*y**z
    p = Symbol('p', positive=True)
    assert sqrt(-x).expand().is_Pow
    assert sqrt(-x).expand(force=True) == I*sqrt(x)
    assert ((2*y*p)**z).expand() == 2**z*p**z*y**z
    assert ((2*y*p*x)**z).expand() == 2**z*p**z*(x*y)**z
    assert ((2*y*p*x)**z).expand(force=True) == 2**z*p**z*x**z*y**z
    assert ((2*y*p*-pi)**z).expand() == 2**z*pi**z*p**z*(-y)**z
    assert ((2*y*p*-pi*x)**z).expand() == 2**z*pi**z*p**z*(-x*y)**z
    n = Symbol('n', negative=True, finite=True)
    m = Symbol('m', negative=True, finite=True)
    assert ((-2*x*y*n)**z).expand() == 2**z*(-n)**z*(x*y)**z
    assert ((-2*x*y*n*m)**z).expand() == 2**z*(-m)**z*(-n)**z*(-x*y)**z
    # issue sympy/sympy#5482
    assert sqrt(-2*x*n) == sqrt(2)*sqrt(-n)*sqrt(x)
    # issue sympy/sympy#5605 (2)
    assert (cos(x + y)**2).expand(trig=True) in [
        (-sin(x)*sin(y) + cos(x)*cos(y))**2,
        sin(x)**2*sin(y)**2 - 2*sin(x)*sin(y)*cos(x)*cos(y) + cos(x)**2*cos(y)**2
    ]

    # Check that this isn't too slow
    W = 1
    for i in range(1, 21):
        W = W * (x - i)
    W = W.expand()
    assert W.has(-1672280820*x**15)
230
231
def test_power_expand():
    """Exercise Pow.expand() on sums in the base and in the exponent."""
    a, b = Symbol('a'), Symbol('b')

    binomial_square = (a + b)**2
    assert binomial_square.expand() == a**2 + b**2 + 2*a*b

    nested_square = (1 + 2*(1 + a))**2
    assert nested_square.expand() == 9 + 4*(a**2) + 12*a

    sum_exponent = 2**(a + b)
    assert sum_exponent.expand() == 2**a*2**b

    A, B = (Symbol(name, commutative=False) for name in 'AB')
    # non-commutative exponent: the result equals the input
    assert (2**(A + B)).expand() == 2**(A + B)
    # non-commutative base with a commutative sum as exponent: result differs
    assert (A**(a + b)).expand() != A**(a + b)
249
250
def test_expand_multinomial():
    """Check expand_multinomial() on powers of a sum containing O(z)."""
    base = x + 1 + O(z)

    squared = expand_multinomial(base**2)
    assert squared == 1 + 2*x + x**2 + O(z)

    cubed = expand_multinomial(base**3)
    assert cubed == 1 + 3*x + 3*x**2 + x**3 + O(z)
255
256
def test_sympyissues_5919_6830():
    """Regression tests for expand() and expand_multinomial()."""
    # issue sympy/sympy#5919
    shifted = -1 + 1/x
    quotient = shifted/x/(-shifted)**2 - 1/shifted/x
    assert expand(quotient) == 1/(x**2 - 2*x + 1) - 1/(x - 2 + 1/x) - 1/(-x + 1)

    # issue sympy/sympy#6830
    # expanded powers of (1 + x) and (x + y) that recur in the expected values
    x_square = x**2 + 2*x + 1
    x_fourth = x**4 + 4*x**3 + 6*x**2 + 4*x + 1
    xy_square = x**2 + 2*x*y + y**2

    square = (1 + x)**2
    assert expand_multinomial((1 + x*square)**2) == (
        x**2*x_fourth + 2*x*x_square + 1)
    assert expand_multinomial((1 + (y + x)*square)**2) == (
        2*((x + y)*x_square) + xy_square*x_fourth + 1)

    A = Symbol('A', commutative=False)
    # expanded powers of (1 + A) that recur in the expected values
    nc_square = 1 + 2*A + A**2
    nc_fourth = 1 + 4*A + 6*A**2 + 4*A**3 + A**4
    nc_sixth = 1 + 6*A + 15*A**2 + 20*A**3 + 15*A**4 + 6*A**5 + A**6

    square = (1 + A)**2
    assert expand_multinomial((1 + x*square)**2) == (
        x**2*nc_fourth + 2*x*nc_square + 1)
    assert expand_multinomial((1 + (y + x)*square)**2) == (
        (x + y)*nc_square*2 + xy_square*nc_fourth + 1)
    assert expand_multinomial((1 + (y + x)*square)**3) == (
        (x + y)*nc_square*3 + xy_square*nc_fourth*3 +
        (x**3 + 3*x**2*y + 3*x*y**2 + y**3)*nc_sixth + 1)

    # unevaluated powers
    unevaluated = Pow((x + 1)*((A + 1)**2), 2, evaluate=False)
    # - in this case the base is not an Add so no further
    #   expansion is done
    assert expand_multinomial(unevaluated) == x_square*nc_fourth
    # - but here, the expanded base *is* an Add so it gets expanded
    unevaluated = Pow(((A + 1)**2), 2, evaluate=False)
    assert expand_multinomial(unevaluated) == nc_fourth

    # coverage
    def expansion_agrees(re_part, im_part, power):
        """Compare (re_part + I*im_part)**power with its expansion."""
        value = (re_part + I*im_part)**power
        return verify_numerically(value, expand_multinomial(value))

    cases = [(re_part, im_part, power)
             for re_part in [2, Rational(1, 2)]
             for im_part in [3, Rational(1, 3)]
             for power in range(2, 6)]
    for case in cases:
        assert expansion_agrees(*case)

    # substituting an order term before or after expansion gives the same result
    cubic = (sin(x) + y)**3
    order_term = {y: O(x**4)}
    substituted_first = expand_multinomial(cubic.subs(order_term))
    substituted_last = expand_multinomial(cubic).subs(order_term)
    assert substituted_first == substituted_last == sin(x)**3 + O(x**6)

    assert expand_multinomial(3**(x + y + 3)) == 27*3**(x + y)
306
307
def test_expand_log():
    """Check that a log difference in a positive symbol expands to zero."""
    t = Symbol('t', positive=True)
    difference = log(t**2) - log(t**2/4) - 2*log(2)
    # after first expansion, -2*log(2) + log(4); then 0 after second
    assert expand(difference) == 0
312