"""Tests for tools for manipulating large commutative expressions."""
2
3import pytest
4
5from diofant import (Add, Basic, Dict, Expr, Float, I, Integer, Integral,
6                     Interval, Mul, O, Rational, Sum, Symbol, Tuple, cbrt,
7                     collect, cos, exp, oo, root, simplify, sin, sqrt, symbols)
8from diofant.abc import a, b, t, x, y, z
9from diofant.core.coreerrors import NonCommutativeExpression
10from diofant.core.exprtools import (Factors, Term, _gcd_terms, decompose_power,
11                                    factor_nc, factor_terms, gcd_terms)
12from diofant.core.function import _mexpand
13from diofant.core.mul import _keep_coeff
14from diofant.simplify.cse_opts import sub_pre
15
16
# Test module: an empty ``__all__`` means ``from ... import *`` exports nothing.
__all__ = ()
18
19
def test_decompose_power():
    """``decompose_power`` returns (base, integer) with the integer
    coefficient of the exponent pulled out of the power."""
    cases = [
        (x, (x, 1)),
        (x**2, (x, 2)),
        (x**(2*y), (x**y, 2)),
        (x**(2*y/3), (x**(y/3), 2)),
    ]
    for expr, expected in cases:
        assert decompose_power(expr) == expected
25
26
def test_Factors():
    """Exercise construction, arithmetic, gcd/lcm and normalization of ``Factors``."""
    # Construction: no argument, an empty dict and Integer(1) are all equivalent.
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == x**2*y**3*sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    # A negative value is represented with an explicit -1 factor.
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    # Equal Factors must hash equally.
    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    # NOTE: ``a`` and ``b`` below shadow the symbols imported from
    # diofant.abc; within this test they are Factors instances.
    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({      y: 4, z: 3, t: 10})

    # Exponents of matching bases add under multiplication.
    assert a.mul(b) == a*b == Factors({x: 5, y: 7, z: 10, t: 10})

    # div/divmod give the cancelled (numerator, denominator) of a/b;
    # quo and rem return each half separately.
    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a/b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})

    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})

    # A float exponent (plain or wrapped in Factors) raises ValueError.
    pytest.raises(ValueError, lambda: a.pow(3.1))
    pytest.raises(ValueError, lambda: a.pow(Factors(3.1)))

    assert a.pow(0) == Factors()

    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == Factors({x: 5, y: 4, z: 7, t: 10})

    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})

    # normal cancels the factors common to both operands.
    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2)*x).as_expr() == sqrt(2)*x

    # Powers of I and -1 combine with each other.
    assert Factors(-I)*I == Factors()
    assert Factors({Integer(-1): Integer(3)})*Factors({Integer(-1): Integer(1), I: Integer(5)}) == \
        Factors(I)

    # Distinct bases are not cancelled against each other (even 8 vs 2**...).
    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    # (-1)**(3/2) is normalized into separate I and -1 factors.
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == I*cbrt(-1)

    # Negative numbers split into -1 and a positive magnitude, except when
    # raised to a symbolic power.
    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == Factors({Integer(3): 1, Integer(2): Integer(-1)})
    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I
    # Non-commutative factors are kept whole.
    A = symbols('A', commutative=False)
    assert Factors(2*A**2) == Factors({Integer(2): 1, A**2: 1})
    assert Factors(I) == Factors({I: 1})
    # Edge cases with Integer(0) and self-division.
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    pytest.raises(ZeroDivisionError, lambda: Factors(x).div(Integer(0)))
    assert Factors(x).mul(Integer(2)) == Factors(2*x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1/x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2*sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2})/Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())
    # div/normal/gcd with symbolic exponents on a shared base.
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only: only the numeric parts of the exponents
    # cancel; the symbolic parts (x, y) stay where they are.
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    # as_expr rebuilds powers of I and -1.
    assert Factors({I: I}).as_expr() == (-1)**(I/2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)
149
150
def test_Term():
    """Exercise construction, arithmetic, gcd/lcm and validation of ``Term``."""
    # NOTE: ``a`` and ``b`` shadow the diofant.abc symbols within this test.
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5/t**3)

    # A Term is (coefficient, numerator Factors, denominator Factors).
    assert a == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert b == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))

    assert a.as_expr() == 4*x*y**2/z/t**3
    assert b.as_expr() == 2*x**3*y**5/t**3

    # inv inverts the coefficient and swaps numerator and denominator.
    assert a.inv() == \
        Term(Rational(1, 4), Factors({z: 1, t: 3}), Factors({x: 1, y: 2}))
    assert b.inv() == Term(Rational(1, 2), Factors({t: 3}), Factors({x: 3, y: 5}))

    assert a.mul(b) == a*b == \
        Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert a.quo(b) == a/b == Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))

    assert a.pow(3) == a**3 == \
        Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))
    assert b.pow(3) == b**3 == Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))

    # Negative powers invert the term.
    assert a.pow(-3) == a**(-3) == \
        Term(Rational(1, 64), Factors({z: 3, t: 9}), Factors({x: 3, y: 6}))
    assert b.pow(-3) == b**(-3) == \
        Term(Rational(1, 8), Factors({t: 9}), Factors({x: 9, y: 15}))

    assert a.gcd(b) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert a.lcm(b) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5*t**7)

    # Factors present in both numerator and denominator cancel on mul.
    assert a.mul(b) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    # Numeric content of sum bases is moved into the coefficient.
    assert Term((2*x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))

    # Non-commutative input is rejected.
    A = Symbol('A', commutative=False)
    pytest.raises(NonCommutativeExpression, lambda: Term(A))

    # numer/denom keywords fill the corresponding Factors; the other is empty.
    f1, f2 = Factors({x: 2}), Factors()
    assert Term(2, numer=f1) == Term(2, f1, f2)
    assert Term(2, denom=f1) == Term(2, f2, f1)

    # Arithmetic with plain numbers, or a float exponent, raises TypeError.
    pytest.raises(TypeError, lambda: a*2)
    pytest.raises(TypeError, lambda: a/3)
    pytest.raises(TypeError, lambda: a**3.1)
200
201
def test_gcd_terms():
    """Check ``gcd_terms``/``_gcd_terms`` on sums, containers and edge cases."""
    expr = 2*(x + 1)*(x + 4)/(5*x**2 + 5) + (2*x + 2)*(x + 5)/(x**2 + 1)/5 + \
        (2*x + 2)*(x + 6)/(5*x**2 + 5)

    parts = (Rational(6, 5)*((1 + x)/(1 + x**2)), 5 + x, 1)
    assert _gcd_terms(expr) == parts
    assert _gcd_terms(Add.make_args(expr)) == parts

    expected = Rational(6, 5)*((1 + x)*(5 + x)/(1 + x**2))
    assert gcd_terms(expr) == expected

    terms = Add.make_args(expr)
    # plain Python sequences of terms are treated as the terms of an Add
    for container in (list, tuple, set):
        assert gcd_terms(container(terms)) == expected
    # a Basic sequence, however, is processed as a container
    assert gcd_terms(Tuple(*terms)) != expected
    assert gcd_terms(Basic(Tuple(1, 3*y + 3*x*y), Tuple(1, 3))) == \
        Basic((1, 3*y*(x + 1)), (1, 3))
    # dictionary keys are left untouched so that no entry can be lost
    assert gcd_terms(Dict((x*(1 + y), 2), (x + x*y, y + x*y))) == \
        Dict({x*(y + 1): 2, x + x*y: y*(1 + x)})

    assert gcd_terms((2*x + 2)**3 + (2*x + 2)**2) == 4*(x + 1)**2*(2*x + 3)

    # trivial inputs come back unchanged
    for trivial in (0, 1, x):
        assert gcd_terms(trivial) == trivial
    assert gcd_terms(2 + 2*x) == Mul(2, 1 + x, evaluate=False)

    inner = x*(2*x + 4*y)
    factored = 2*x*(x + 2*y)
    assert gcd_terms(inner) == factored
    assert gcd_terms(sin(inner)) == sin(factored)

    # issue sympy/sympy#6139-like
    alpha, alpha1, alpha2, alpha3 = symbols('alpha:4')
    poly = alpha**2 - alpha*x**2 + alpha + x**3 - x*(alpha + 1)
    rep = {alpha: (1 + sqrt(5))/2 + alpha1*x + alpha2*x**2 + alpha3*x**3}
    ser = (poly/(x - alpha)).subs(rep).series(x, 0, 1)
    assert simplify(collect(ser, x)) == -sqrt(5)/2 - Rational(3, 2) + O(x)

    # issue sympy/sympy#5917
    assert _gcd_terms([Integer(0), Integer(0)]) == (0, 0, 1)
    assert _gcd_terms([2*x + 4]) == (2, x + 2, 1)

    frac = x/(x + 1/x)
    assert gcd_terms(frac, fraction=False) == frac
249
250
def test_factor_terms():
    """Check ``factor_terms`` on sums, powers, containers and its options."""
    A = Symbol('A', commutative=False)
    assert factor_terms(9*(x + x*y + 1) + (3*x + 3)**(2 + 2*x)) == \
        9*x*y + 9*x + _keep_coeff(Integer(3), x + 1)**_keep_coeff(Integer(2), x + 1) + 9
    assert factor_terms(9*(x + x*y + 1) + 3**(2 + 2*x)) == \
        _keep_coeff(Integer(9), 3**(2*x) + x*y + x + 1)
    assert factor_terms(3**(2 + 2*x) + a*3**(2 + 2*x)) == \
        9*3**(2*x)*(a + 1)
    assert factor_terms(x + x*A) == \
        x*(1 + A)
    assert factor_terms(sin(x + x*A)) == \
        sin(x*(1 + A))
    assert factor_terms((3*x + 3)**((2 + 2*x)/3)) == \
        _keep_coeff(Integer(3), x + 1)**_keep_coeff(Rational(2, 3), x + 1)
    assert factor_terms(x + (x*y + x)**(3*x + 3)) == \
        x + (x*(y + 1))**_keep_coeff(Integer(3), x + 1)
    assert factor_terms(a*(x + x*y) + b*(x*2 + y*x*2)) == \
        x*(a + 2*b)*(y + 1)
    i = Integral(x, (x, 0, oo))
    assert factor_terms(i) == i

    # check radical extraction
    eq = sqrt(2) + sqrt(10)
    assert factor_terms(eq) == eq
    assert factor_terms(eq, radical=True) == sqrt(2)*(1 + sqrt(5))
    eq = root(-6, 3) + root(6, 3)
    assert factor_terms(eq, radical=True) == cbrt(6)*(1 + cbrt(-1))

    # plain Python containers are processed element-wise
    eq = [x + x*y]
    ans = [x*(y + 1)]
    for c in [list, tuple, set]:
        assert factor_terms(c(eq)) == c(ans)
    assert factor_terms(Tuple(x + x*y)) == Tuple(x*(y + 1))
    assert factor_terms(Interval(0, 1)) == Interval(0, 1)

    # ``clear`` controls whether a rational factor is cleared from the radicand
    e = 1/sqrt(a/2 + 1)
    assert factor_terms(e, clear=False) == 1/sqrt(a/2 + 1)
    assert factor_terms(e, clear=True) == sqrt(2)/sqrt(a + 2)

    eq = x/(x + 1/x) + 1/(x**2 + 1)
    assert factor_terms(eq, fraction=False) == eq
    assert factor_terms(eq, fraction=True) == 1

    assert factor_terms((1/(x**3 + x**2) + 2/x**2)*y) == \
        y*(2 + 1/(x + 1))/x**2

    # gcd_terms does not extract the sign; if this ever stops being True,
    # the sign processing in factor_terms is no longer necessary
    assert gcd_terms(-x - y) == -x - y
    assert factor_terms(-x - y) == Mul(-1, x + y, evaluate=False)

    # if not True, then the "special" processing in factor_terms (sign
    # extraction inside function arguments) is not necessary
    assert gcd_terms(exp(Mul(-1, x + 1))) == exp(-x - 1)
    e = exp(-x - 2) + x
    assert factor_terms(e) == exp(Mul(-1, x + 2, evaluate=False)) + x
    # sign=False disables the sign extraction
    assert factor_terms(e, sign=False) == e
    assert factor_terms(exp(-4*x - 2) - x) == -x + exp(Mul(-2, 2*x + 1, evaluate=False))
306
307
def test_xreplace():
    """An unevaluated Mul is returned unchanged by no-op ``xreplace`` calls."""
    unevaluated = Mul(2, 1 + x, evaluate=False)
    for mapping in ({}, {y: x}):
        assert unevaluated.xreplace(mapping) == unevaluated
312
313
def test_factor_nc():
    """Check ``factor_nc`` on expressions with non-commutative symbols."""
    x, y = symbols('x,y')
    k = symbols('k', integer=True)
    n, m, o = symbols('n,m,o', commutative=False)

    # mul and multinomial expansion is needed
    assert _mexpand(x*(1 + y)**2) == x + x*2*y + x*y**2

    def check_roundtrip(expr):
        # Expand expr into an Add, factor it, and verify that the result is
        # no longer an Add but expands back to the same expression.
        expanded = _mexpand(expr)
        assert expanded.is_Add
        factored = factor_nc(expanded)
        assert not factored.is_Add
        assert _mexpand(factored) == expanded

    s = Sum(x, (x, 1, 2))
    roundtrip_cases = [
        x*(1 + y),
        n*(x + 1),
        n*(x + m),
        (x + m)*n,
        n*m*(x*o + n*o*m)*n,
        x*(1 + s),
        x*(1 + s)*s,
        x*(1 + sin(s)),
        (1 + n)**2,
        (x + n)*(x + m)*(x + y),
        x*(n*m + 1),
        x*(n*m + x),
        x*(x*n*m + 1),
        x*n*(x*m + 1),
        x*(m*n + x*n*m),
        n*(1 - m)*n**2,
        (n + m)**2,
        (n - m)*(n + m)**2,
        (n + m)**2*(n - m),
        (m - n)*(n + m)**2*(n - m),
    ]
    for case in roundtrip_cases:
        check_roundtrip(case)

    assert factor_nc(n*(n + n*m)) == n**2*(1 + m)
    assert factor_nc(m*(m*n + n*m*n**2)) == m*(m + n*m*n)*n
    # a commutator is returned unchanged
    commutator = m*sin(n) - sin(n)*m
    assert factor_nc(commutator) == commutator

    product = (sin(n) + x)*(cos(n) + x)
    assert factor_nc(product.expand()) == product

    # issue sympy/sympy#6534
    assert (2*n + 2*m).factor() == 2*(n + m)

    # issue sympy/sympy#6701
    assert factor_nc(n**k + n**(k + 1)) == n**k*(1 + n)
    assert factor_nc((m*n)**k + (m*n)**(k + 1)) == (1 + m*n)*(m*n)**k

    # issue sympy/sympy#6918
    assert factor_nc(-n*(2*x**2 + 2*x)) == -2*n*x*(x + 1)

    assert factor_nc(1 + Mul(Expr(), Expr(), evaluate=False)) == 1 + Expr()**2
372
373
def test_sympyissue_6360():
    """Regression test for sympy issue 6360 (``sub_pre`` + ``factor_terms``)."""
    a, b = symbols('a b')
    s = a + b
    expr = s + s**2*(-2*a - 2*b)
    expected = a + b - 2*(a + b)**3
    assert factor_terms(sub_pre(expr)) == expected
379
380
def test_sympyissue_7903():
    """Regression test for sympy issue 7903: ``simplify`` must complete and
    return a truthy result."""
    a = Symbol('a', extended_real=True)
    expr = exp(I*cos(a)) + exp(-I*sin(a))
    simplified = expr.simplify()
    assert simplified
385