"""Tests for tools for manipulating large commutative expressions."""

import pytest

from diofant import (Add, Basic, Dict, Expr, Float, I, Integer, Integral,
                     Interval, Mul, O, Rational, Sum, Symbol, Tuple, cbrt,
                     collect, cos, exp, oo, root, simplify, sin, sqrt, symbols)
from diofant.abc import a, b, t, x, y, z
from diofant.core.coreerrors import NonCommutativeExpression
from diofant.core.exprtools import (Factors, Term, _gcd_terms, decompose_power,
                                    factor_nc, factor_terms, gcd_terms)
from diofant.core.function import _mexpand
from diofant.core.mul import _keep_coeff
from diofant.simplify.cse_opts import sub_pre


__all__ = ()


def test_decompose_power():
    """Check decompose_power on plain, integer and symbolic exponents."""
    assert decompose_power(x) == (x, 1)
    assert decompose_power(x**2) == (x, 2)
    assert decompose_power(x**(2*y)) == (x**y, 2)
    assert decompose_power(x**(2*y/3)) == (x**(y/3), 2)


def test_Factors():
    """Check construction, arithmetic, gcd/lcm and normal/div of Factors."""
    assert Factors() == Factors({}) == Factors(Integer(1))
    assert Factors(Integer(1)) == Factors(Factors(Integer(1)))
    assert Factors().as_expr() == 1
    assert Factors({x: 2, y: 3, sin(x): 4}).as_expr() == x**2*y**3*sin(x)**4
    assert Factors(+oo) == Factors({oo: 1})
    assert Factors(-oo) == Factors({oo: 1, -1: 1})

    f1 = Factors({oo: 1})
    f2 = Factors({oo: 1})
    assert hash(f1) == hash(f2)

    # NOTE: these locals intentionally shadow the module-level symbols a, b
    # inside this test only.
    a = Factors({x: 5, y: 3, z: 7})
    b = Factors({y: 4, z: 3, t: 10})

    assert a.mul(b) == a*b == Factors({x: 5, y: 7, z: 10, t: 10})

    assert a.div(b) == divmod(a, b) == \
        (Factors({x: 5, z: 4}), Factors({y: 1, t: 10}))
    assert a.quo(b) == a/b == Factors({x: 5, z: 4})
    assert a.rem(b) == a % b == Factors({y: 1, t: 10})

    assert a.pow(3) == a**3 == Factors({x: 15, y: 9, z: 21})
    assert b.pow(3) == b**3 == Factors({y: 12, z: 9, t: 30})

    pytest.raises(ValueError, lambda: a.pow(3.1))
    pytest.raises(ValueError, lambda: a.pow(Factors(3.1)))

    assert a.pow(0) == Factors()

    assert a.gcd(b) == Factors({y: 3, z: 3})
    assert a.lcm(b) == a.lcm(b.as_expr()) == Factors({x: 5, y: 4, z: 7, t: 10})

    a = Factors({x: 4, y: 7, t: 7})
    b = Factors({z: 1, t: 3})

    assert a.normal(b) == (Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Factors(sqrt(2)*x).as_expr() == sqrt(2)*x

    assert Factors(-I)*I == Factors()
    assert Factors({Integer(-1): Integer(3)})*Factors({Integer(-1): Integer(1), I: Integer(5)}) == \
        Factors(I)

    assert Factors(Integer(2)**x).div(Integer(3)**x) == \
        (Factors({Integer(2): x}), Factors({Integer(3): x}))
    assert Factors(2**(2*x + 2)).div(Integer(8)) == \
        (Factors({Integer(2): 2*x + 2}), Factors({Integer(8): Integer(1)}))

    # coverage
    # /!\ things break if this is not True
    assert Factors({Integer(-1): Rational(3, 2)}) == Factors({I: 1, -1: 1})
    assert Factors({I: Integer(1), Integer(-1): Rational(1, 3)}).as_expr() == I*cbrt(-1)

    assert Factors(-1.) == Factors({Integer(-1): Integer(1), Float(1.): 1})
    assert Factors(-2.) == Factors({Integer(-1): Integer(1), Float(2.): 1})
    assert Factors((-2.)**x) == Factors({Float(-2.): x})
    assert Factors(Integer(-2)) == Factors({Integer(-1): Integer(1), Integer(2): 1})
    assert Factors(Rational(1, 2)) == Factors({Integer(2): -1})
    assert Factors(Rational(3, 2)) == Factors({Integer(3): 1, Integer(2): Integer(-1)})
    assert Factors({I: Integer(1)}) == Factors(I)
    assert Factors({-1.0: 2, I: 1}) == Factors({Float(1.0): 1, I: 1})
    assert Factors({-1: -Rational(3, 2)}).as_expr() == I
    A = symbols('A', commutative=False)
    assert Factors(2*A**2) == Factors({Integer(2): 1, A**2: 1})
    assert Factors(I) == Factors({I: 1})
    assert Factors(x).normal(Integer(2)) == (Factors(x), Factors(Integer(2)))
    assert Factors(x).normal(Integer(0)) == (Factors(), Factors(Integer(0)))
    pytest.raises(ZeroDivisionError, lambda: Factors(x).div(Integer(0)))
    assert Factors(x).mul(Integer(2)) == Factors(2*x)
    assert Factors(x).mul(Integer(0)).is_zero
    assert Factors(x).mul(1/x).is_one
    assert Factors(x**sqrt(8)).as_expr() == x**(2*sqrt(2))
    assert Factors(x)**Factors(Integer(2)) == Factors(x**2)
    assert Factors(x).gcd(Integer(0)) == Factors(x)
    assert Factors(x).lcm(Integer(0)).is_zero
    assert Factors(Integer(0)).div(x) == (Factors(Integer(0)), Factors())
    assert Factors(x).div(x) == (Factors(), Factors())
    assert Factors({x: .2})/Factors({x: .2}) == Factors()
    assert Factors(x) != Factors()
    assert Factors(x) == x
    assert Factors(Integer(0)).normal(x) == (Factors(Integer(0)), Factors())
    n, d = x**(2 + y), x**2
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(x**y), Factors())
    assert f.gcd(d) == Factors()
    d = x**y
    assert f.div(d) == f.normal(d) == (Factors(x**2), Factors())
    assert f.gcd(d) == Factors(d)
    n = d = 2**x
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors(), Factors())
    assert f.gcd(d) == Factors(d)
    n, d = 2**x, 2**y
    f = Factors(n)
    assert f.div(d) == f.normal(d) == (Factors({Integer(2): x}), Factors({Integer(2): y}))
    assert f.gcd(d) == Factors()

    assert f.div(f) == (Factors(), Factors())

    # extraction of constant only
    n = x**(x + 3)
    assert Factors(n).normal(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).normal(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).normal(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).normal(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).normal(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors(n).div(x**-3) == (Factors({x: x + 6}), Factors({}))
    assert Factors(n).div(x**3) == (Factors({x: x}), Factors({}))
    assert Factors(n).div(x**4) == (Factors({x: x}), Factors({x: 1}))
    assert Factors(n).div(x**(y - 3)) == \
        (Factors({x: x + 6}), Factors({x: y}))
    assert Factors(n).div(x**(y + 3)) == (Factors({x: x}), Factors({x: y}))
    assert Factors(n).div(x**(y + 4)) == \
        (Factors({x: x}), Factors({x: y + 1}))

    assert Factors({I: I}).as_expr() == (-1)**(I/2)
    assert Factors({-1: Rational(4, 3)}).as_expr() == -cbrt(-1)


def test_Term():
    """Check construction, inversion, arithmetic and gcd/lcm of Term."""
    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5/t**3)

    assert a == Term(4, Factors({x: 1, y: 2}), Factors({z: 1, t: 3}))
    assert b == Term(2, Factors({x: 3, y: 5}), Factors({t: 3}))

    assert a.as_expr() == 4*x*y**2/z/t**3
    assert b.as_expr() == 2*x**3*y**5/t**3

    assert a.inv() == \
        Term(Rational(1, 4), Factors({z: 1, t: 3}), Factors({x: 1, y: 2}))
    assert b.inv() == Term(Rational(1, 2), Factors({t: 3}), Factors({x: 3, y: 5}))

    assert a.mul(b) == a*b == \
        Term(8, Factors({x: 4, y: 7}), Factors({z: 1, t: 6}))
    assert a.quo(b) == a/b == Term(2, Factors({}), Factors({x: 2, y: 3, z: 1}))

    assert a.pow(3) == a**3 == \
        Term(64, Factors({x: 3, y: 6}), Factors({z: 3, t: 9}))
    assert b.pow(3) == b**3 == Term(8, Factors({x: 9, y: 15}), Factors({t: 9}))

    assert a.pow(-3) == a**(-3) == \
        Term(Rational(1, 64), Factors({z: 3, t: 9}), Factors({x: 3, y: 6}))
    assert b.pow(-3) == b**(-3) == \
        Term(Rational(1, 8), Factors({t: 9}), Factors({x: 9, y: 15}))

    assert a.gcd(b) == Term(2, Factors({x: 1, y: 2}), Factors({t: 3}))
    assert a.lcm(b) == Term(4, Factors({x: 3, y: 5}), Factors({z: 1, t: 3}))

    a = Term(4*x*y**2/z/t**3)
    b = Term(2*x**3*y**5*t**7)

    assert a.mul(b) == Term(8, Factors({x: 4, y: 7, t: 4}), Factors({z: 1}))

    assert Term((2*x + 2)**3) == Term(8, Factors({x + 1: 3}), Factors({}))
    assert Term((2*x + 2)*(3*x + 6)**2) == \
        Term(18, Factors({x + 1: 1, x + 2: 2}), Factors({}))

    A = Symbol('A', commutative=False)
    pytest.raises(NonCommutativeExpression, lambda: Term(A))

    f1, f2 = Factors({x: 2}), Factors()
    assert Term(2, numer=f1) == Term(2, f1, f2)
    assert Term(2, denom=f1) == Term(2, f2, f1)

    pytest.raises(TypeError, lambda: a*2)
    pytest.raises(TypeError, lambda: a/3)
    pytest.raises(TypeError, lambda: a**3.1)


def test_gcd_terms():
    """Check gcd_terms/_gcd_terms on expressions, sequences and containers."""
    f = 2*(x + 1)*(x + 4)/(5*x**2 + 5) + (2*x + 2)*(x + 5)/(x**2 + 1)/5 + \
        (2*x + 2)*(x + 6)/(5*x**2 + 5)

    assert _gcd_terms(f) == (Rational(6, 5)*((1 + x)/(1 + x**2)), 5 + x, 1)
    assert _gcd_terms(Add.make_args(f)) == \
        (Rational(6, 5)*((1 + x)/(1 + x**2)), 5 + x, 1)

    newf = Rational(6, 5)*((1 + x)*(5 + x)/(1 + x**2))
    assert gcd_terms(f) == newf
    args = Add.make_args(f)
    # non-Basic sequences of terms treated as terms of Add
    assert gcd_terms(list(args)) == newf
    assert gcd_terms(tuple(args)) == newf
    assert gcd_terms(set(args)) == newf
    # but a Basic sequence is treated as a container
    assert gcd_terms(Tuple(*args)) != newf
    assert gcd_terms(Basic(Tuple(1, 3*y + 3*x*y), Tuple(1, 3))) == \
        Basic((1, 3*y*(x + 1)), (1, 3))
    # but we shouldn't change keys of a dictionary or some may be lost
    assert gcd_terms(Dict((x*(1 + y), 2), (x + x*y, y + x*y))) == \
        Dict({x*(y + 1): 2, x + x*y: y*(1 + x)})

    assert gcd_terms((2*x + 2)**3 + (2*x + 2)**2) == 4*(x + 1)**2*(2*x + 3)

    assert gcd_terms(0) == 0
    assert gcd_terms(1) == 1
    assert gcd_terms(x) == x
    assert gcd_terms(2 + 2*x) == Mul(2, 1 + x, evaluate=False)
    arg = x*(2*x + 4*y)
    garg = 2*x*(x + 2*y)
    assert gcd_terms(arg) == garg
    assert gcd_terms(sin(arg)) == sin(garg)

    # issue sympy/sympy#6139-like
    alpha, alpha1, alpha2, alpha3 = symbols('alpha:4')
    a = alpha**2 - alpha*x**2 + alpha + x**3 - x*(alpha + 1)
    rep = {alpha: (1 + sqrt(5))/2 + alpha1*x + alpha2*x**2 + alpha3*x**3}
    s = (a/(x - alpha)).subs(rep).series(x, 0, 1)
    assert simplify(collect(s, x)) == -sqrt(5)/2 - Rational(3, 2) + O(x)

    # issue sympy/sympy#5917
    assert _gcd_terms([Integer(0), Integer(0)]) == (0, 0, 1)
    assert _gcd_terms([2*x + 4]) == (2, x + 2, 1)

    eq = x/(x + 1/x)
    assert gcd_terms(eq, fraction=False) == eq


def test_factor_terms():
    """Check factor_terms, including its radical, clear, fraction and sign options."""
    A = Symbol('A', commutative=False)
    assert factor_terms(9*(x + x*y + 1) + (3*x + 3)**(2 + 2*x)) == \
        9*x*y + 9*x + _keep_coeff(Integer(3), x + 1)**_keep_coeff(Integer(2), x + 1) + 9
    assert factor_terms(9*(x + x*y + 1) + 3**(2 + 2*x)) == \
        _keep_coeff(Integer(9), 3**(2*x) + x*y + x + 1)
    assert factor_terms(3**(2 + 2*x) + a*3**(2 + 2*x)) == \
        9*3**(2*x)*(a + 1)
    assert factor_terms(x + x*A) == \
        x*(1 + A)
    assert factor_terms(sin(x + x*A)) == \
        sin(x*(1 + A))
    assert factor_terms((3*x + 3)**((2 + 2*x)/3)) == \
        _keep_coeff(Integer(3), x + 1)**_keep_coeff(Rational(2, 3), x + 1)
    assert factor_terms(x + (x*y + x)**(3*x + 3)) == \
        x + (x*(y + 1))**_keep_coeff(Integer(3), x + 1)
    assert factor_terms(a*(x + x*y) + b*(x*2 + y*x*2)) == \
        x*(a + 2*b)*(y + 1)
    i = Integral(x, (x, 0, oo))
    assert factor_terms(i) == i

    # check radical extraction
    eq = sqrt(2) + sqrt(10)
    assert factor_terms(eq) == eq
    assert factor_terms(eq, radical=True) == sqrt(2)*(1 + sqrt(5))
    eq = root(-6, 3) + root(6, 3)
    assert factor_terms(eq, radical=True) == cbrt(6)*(1 + cbrt(-1))

    eq = [x + x*y]
    ans = [x*(y + 1)]
    for c in [list, tuple, set]:
        assert factor_terms(c(eq)) == c(ans)
    assert factor_terms(Tuple(x + x*y)) == Tuple(x*(y + 1))
    assert factor_terms(Interval(0, 1)) == Interval(0, 1)
    e = 1/sqrt(a/2 + 1)
    assert factor_terms(e, clear=False) == 1/sqrt(a/2 + 1)
    assert factor_terms(e, clear=True) == sqrt(2)/sqrt(a + 2)

    eq = x/(x + 1/x) + 1/(x**2 + 1)
    assert factor_terms(eq, fraction=False) == eq
    assert factor_terms(eq, fraction=True) == 1

    assert factor_terms((1/(x**3 + x**2) + 2/x**2)*y) == \
        y*(2 + 1/(x + 1))/x**2

    # if not True, then processing for this in factor_terms is not necessary
    assert gcd_terms(-x - y) == -x - y
    assert factor_terms(-x - y) == Mul(-1, x + y, evaluate=False)

    # if not True, then "special" processing in factor_terms is not necessary
    assert gcd_terms(exp(Mul(-1, x + 1))) == exp(-x - 1)
    e = exp(-x - 2) + x
    assert factor_terms(e) == exp(Mul(-1, x + 2, evaluate=False)) + x
    assert factor_terms(e, sign=False) == e
    assert factor_terms(exp(-4*x - 2) - x) == -x + exp(Mul(-2, 2*x + 1, evaluate=False))


def test_xreplace():
    """Check that xreplace leaves an unevaluated Mul unchanged when nothing matches."""
    e = Mul(2, 1 + x, evaluate=False)
    assert e.xreplace({}) == e
    assert e.xreplace({y: x}) == e


def test_factor_nc():
    """Check factor_nc on expressions mixing commutative and non-commutative symbols."""
    x, y = symbols('x,y')
    k = symbols('k', integer=True)
    n, m, o = symbols('n,m,o', commutative=False)

    # mul and multinomial expansion is needed
    e = x*(1 + y)**2
    assert _mexpand(e) == x + x*2*y + x*y**2

    def factor_nc_test(e):
        """Expand e, factor it back, and check the round trip."""
        ex = _mexpand(e)
        assert ex.is_Add
        f = factor_nc(ex)
        assert not f.is_Add and _mexpand(f) == ex

    factor_nc_test(x*(1 + y))
    factor_nc_test(n*(x + 1))
    factor_nc_test(n*(x + m))
    factor_nc_test((x + m)*n)
    factor_nc_test(n*m*(x*o + n*o*m)*n)
    s = Sum(x, (x, 1, 2))
    factor_nc_test(x*(1 + s))
    factor_nc_test(x*(1 + s)*s)
    factor_nc_test(x*(1 + sin(s)))
    factor_nc_test((1 + n)**2)

    factor_nc_test((x + n)*(x + m)*(x + y))
    factor_nc_test(x*(n*m + 1))
    factor_nc_test(x*(n*m + x))
    factor_nc_test(x*(x*n*m + 1))
    factor_nc_test(x*n*(x*m + 1))
    factor_nc_test(x*(m*n + x*n*m))
    factor_nc_test(n*(1 - m)*n**2)

    factor_nc_test((n + m)**2)
    factor_nc_test((n - m)*(n + m)**2)
    factor_nc_test((n + m)**2*(n - m))
    factor_nc_test((m - n)*(n + m)**2*(n - m))

    assert factor_nc(n*(n + n*m)) == n**2*(1 + m)
    assert factor_nc(m*(m*n + n*m*n**2)) == m*(m + n*m*n)*n
    eq = m*sin(n) - sin(n)*m
    assert factor_nc(eq) == eq

    eq = (sin(n) + x)*(cos(n) + x)
    assert factor_nc(eq.expand()) == eq

    # issue sympy/sympy#6534
    assert (2*n + 2*m).factor() == 2*(n + m)

    # issue sympy/sympy#6701
    assert factor_nc(n**k + n**(k + 1)) == n**k*(1 + n)
    assert factor_nc((m*n)**k + (m*n)**(k + 1)) == (1 + m*n)*(m*n)**k

    # issue sympy/sympy#6918
    assert factor_nc(-n*(2*x**2 + 2*x)) == -2*n*x*(x + 1)

    assert factor_nc(1 + Mul(Expr(), Expr(), evaluate=False)) == 1 + Expr()**2


def test_sympyissue_6360():
    """Regression test: factor_terms applied to the output of sub_pre."""
    a, b = symbols('a b')
    apb = a + b
    eq = apb + apb**2*(-2*a - 2*b)
    assert factor_terms(sub_pre(eq)) == a + b - 2*(a + b)**3


def test_sympyissue_7903():
    """Regression test: simplify on a sum of complex exponentials gives a truthy result."""
    a = symbols(r'a', extended_real=True)
    t = exp(I*cos(a)) + exp(-I*sin(a))
    assert t.simplify()