"""Tests for :func:`expand` and the ``expand_*`` helper functions."""

import pytest

from diofant import (I, Integer, Mul, O, Pow, Rational, Symbol, cbrt, cos, exp,
                     expand, expand_multinomial, expand_power_base, log, pi,
                     sin, sqrt)
from diofant.abc import x, y, z
from diofant.simplify.radsimp import expand_numer
from diofant.utilities.randtest import verify_numerically


__all__ = ()


def test_expand_no_log():
    """``log=False`` leaves logarithms of powers unexpanded."""
    assert ((1 + log(x**4))**2).expand(log=False) == \
        1 + 2*log(x**4) + log(x**4)**2
    assert ((1 + log(x**4))*(1 + log(x**3))).expand(log=False) == \
        1 + log(x**4) + log(x**3) + log(x**4)*log(x**3)


def test_expand_no_multinomial():
    """``multinomial=False`` distributes products but keeps powers of sums."""
    assert ((1 + x)*(1 + (1 + x)**4)).expand(multinomial=False) == \
        1 + x + (1 + x)**4 + x*(1 + x)**4


def test_expand_negative_integer_powers():
    """Negative powers of a sum expand the denominator unless multinomial=False."""
    expr = (x + y)**(-2)
    assert expr.expand() == 1 / (2*x*y + x**2 + y**2)
    assert expr.expand(multinomial=False) == (x + y)**(-2)
    expr = (x + y)**(-3)
    assert expr.expand() == 1 / (3*x*x*y + 3*x*y*y + x**3 + y**3)
    assert expr.expand(multinomial=False) == (x + y)**(-3)
    expr = (x + y)**2 * (x + y)**(-4)
    assert expr.expand() == 1 / (2*x*y + x**2 + y**2)
    assert expr.expand(multinomial=False) == (x + y)**(-2)


def test_expand_non_commutative():
    """Expansion keeps the order of non-commutative factors intact."""
    A = Symbol('A', commutative=False)
    B = Symbol('B', commutative=False)
    C = Symbol('C', commutative=False)
    a = Symbol('a')
    b = Symbol('b')
    i = Symbol('i', integer=True)
    n = Symbol('n', negative=True, finite=True)
    m = Symbol('m', negative=True, finite=True)
    p = Symbol('p', polar=True)
    # Same name as ``p`` but different assumptions: power bases with a
    # non-polar factor must not be split the way polar ones are.
    np = Symbol('p', polar=False)

    assert (C*(A + B)).expand() == C*A + C*B
    assert (C*(A + B)).expand() != A*C + B*C
    assert ((A + B)**2).expand() == A**2 + A*B + B*A + B**2
    assert ((A + B)**3).expand() == (A**2*B + B**2*A + A*B**2 + B*A**2 +
                                     A**3 + B**3 + A*B*A + B*A*B)
    # issue sympy/sympy#6219
    assert ((a*A*B*A**-1)**2).expand() == a**2*A*B**2/A
    # Note that (a*A*B*A**-1)**2 is automatically converted to
    # a**2*(A*B*A**-1)**2
    assert ((a*A*B*A**-1)**2).expand(deep=False) == a**2*(A*B*A**-1)**2
    assert ((a*A*B*A**-1)**2).expand() == a**2*(A*B**2*A**-1)
    assert ((a*A*B*A**-1)**2).expand(force=True) == a**2*A*B**2*A**(-1)
    assert ((a*A*B)**2).expand() == a**2*A*B*A*B
    assert ((a*A)**2).expand() == a**2*A**2
    assert ((a*A*B)**i).expand() == a**i*(A*B)**i
    assert ((a*A*(B*(A*B/A)**2))**i).expand() == a**i*(A*B*A*B**2/A)**i
    # issue sympy/sympy#6558
    assert (A*B*(A*B)**-1).expand() == A*B*(A*B)**-1
    assert ((a*A)**i).expand() == a**i*A**i
    assert ((a*A*B*A**-1)**3).expand() == a**3*A*B**3/A
    assert ((a*A*B*A*B/A)**3).expand() == \
        a**3*A*B*(A*B**2)*(A*B**2)*A*B*A**(-1)
    assert ((a*A*B*A*B/A)**-3).expand() == \
        a**-3*(A*B*(A*B**2)*(A*B**2)*A*B*A**(-1))**-1
    assert ((a*b*A*B*A**-1)**i).expand() == a**i*b**i*(A*B/A)**i
    assert ((a*(a*b)**i)**i).expand() == a**i*a**(i**2)*b**(i**2)
    e = Pow(Mul(a, 1/a, A, B, evaluate=False), Integer(2), evaluate=False)
    assert e.expand() == A*B*A*B
    assert sqrt(a*(A*b)**i).expand() == sqrt(a*b**i*A**i)
    assert (sqrt(-a)**a).expand() == sqrt(-a)**a
    assert expand((-2*n)**(i/3)) == 2**(i/3)*(-n)**(i/3)
    assert expand((-2*n*m)**(i/a)) == (-2)**(i/a)*(-n)**(i/a)*(-m)**(i/a)
    assert expand((-2*a*p)**b) == 2**b*p**b*(-a)**b
    assert expand((-2*a*np)**b) == 2**b*(-a*np)**b
    assert expand(sqrt(A*B)) == sqrt(A*B)
    assert expand(sqrt(-2*a*b)) == sqrt(2)*sqrt(-a*b)


def test_expand_radicals():
    """Powers of sqrt/cbrt of a sum split off and expand the integer part."""
    a = sqrt(x + y)

    assert (a**1).expand() == a
    assert (a**3).expand() == x*a + y*a
    assert (a**5).expand() == x**2*a + 2*x*y*a + y**2*a

    assert (1/a**1).expand() == 1/a
    assert (1/a**3).expand() == 1/(x*a + y*a)
    assert (1/a**5).expand() == 1/(x**2*a + 2*x*y*a + y**2*a)

    a = cbrt(x + y)

    assert (a**1).expand() == a
    assert (a**2).expand() == a**2
    assert (a**4).expand() == x*a + y*a
    assert (a**5).expand() == x*a**2 + y*a**2
    assert (a**7).expand() == x**2*a + 2*x*y*a + y**2*a


def test_expand_modulus():
    """``modulus=`` reduces coefficients; a zero or symbolic modulus raises."""
    assert ((x + y)**11).expand(modulus=11) == x**11 + y**11
    assert ((x + sqrt(2)*y)**11).expand(modulus=11) == x**11 + 10*sqrt(2)*y**11
    assert (x + y/2).expand(modulus=1) == y/2

    with pytest.raises(ValueError):
        ((x + y)**11).expand(modulus=0)
    with pytest.raises(ValueError):
        ((x + y)**11).expand(modulus=x)


def test_sympyissue_5743():
    """Products with sqrt(x + y) combine the radical powers (sympy#5743)."""
    assert (x*sqrt(x + y)*(1 + sqrt(x + y))).expand() == \
        x**2 + x*y + x*sqrt(x + y)
    assert (x*sqrt(x + y)*(1 + x*sqrt(x + y))).expand() == \
        x**3 + x**2*y + x*sqrt(x + y)


def test_expand_frac():
    """``frac``/``numer``/``denom`` hints expand only the requested part."""
    assert expand((x + y)*y/x/(x + 1), frac=True) == \
        (x*y + y**2)/(x**2 + x)
    assert expand((x + y)*y/x/(x + 1), numer=True) == \
        (x*y + y**2)/(x*(x + 1))
    assert expand((x + y)*y/x/(x + 1), denom=True) == \
        y*(x + y)/(x**2 + x)
    eq = (x + 1)**2/y
    assert expand_numer(eq, multinomial=False) == eq


def test_sympyissue_6121():
    """``expand(complex=True)`` must terminate (sympy/sympy#6121)."""
    eq = -I*exp(-3*I*pi/4)/(4*pi**Rational(3, 2)*sqrt(x))
    assert eq.expand(complex=True)  # does not give oo recursion


def test_expand_power_base():
    """``expand_power_base`` splits (a*b)**e only when safe or forced."""
    assert expand_power_base((x*y*z)**4) == x**4*y**4*z**4
    assert expand_power_base((x*y*z)**x).is_Pow
    assert expand_power_base((x*y*z)**x, force=True) == x**x*y**x*z**x
    assert expand_power_base((x*(y*z)**2)**3) == x**3*y**6*z**6

    assert expand_power_base((sin((x*y)**2)*y)**z).is_Pow
    assert expand_power_base(
        (sin((x*y)**2)*y)**z, force=True) == sin((x*y)**2)**z*y**z
    assert expand_power_base(
        (sin((x*y)**2)*y)**z, deep=True) == (sin(x**2*y**2)*y)**z

    assert expand_power_base(exp(x)**2) == exp(2*x)
    assert expand_power_base((exp(x)*exp(y))**2) == exp(2*x)*exp(2*y)

    assert expand_power_base(
        (exp((x*y)**z)*exp(y))**2) == exp(2*(x*y)**z)*exp(2*y)
    assert expand_power_base(
        (exp((x*y)**z)*exp(y))**2,
        deep=True, force=True) == exp(2*x**z*y**z)*exp(2*y)

    assert expand_power_base((exp(x)*exp(y))**z).is_Pow
    assert expand_power_base(
        (exp(x)*exp(y))**z, force=True) == exp(x)**z*exp(y)**z


def test_expand_arit():
    """Expansion of products, powers and series, plus power-base hints."""
    a = Symbol('a')
    b = Symbol('b', positive=True)
    c = Symbol('c')

    p = Integer(5)
    e = (a + b)*c
    assert e == c*(a + b)
    assert (e.expand() - a*c - b*c) == 0
    e = (a + b)*(a + b)
    assert e == (a + b)**2
    assert e.expand() == 2*a*b + a**2 + b**2
    e = (a + b)*(a + b)**2
    assert e == (a + b)**3
    assert e.expand() == 3*b*a**2 + 3*a*b**2 + a**3 + b**3
    e = (a + b)*(a + c)*(b + c)
    assert e == (a + c)*(a + b)*(b + c)
    assert e.expand() == (2*a*b*c + b*a**2 + c*a**2 + b*c**2 + a*c**2 +
                          c*b**2 + a*b**2)
    e = (a + 1)**p
    assert e == (1 + a)**5
    assert e.expand() == 1 + 5*a + 10*a**2 + 10*a**3 + 5*a**4 + a**5
    e = (a + b + c)*(a + c + p)
    assert e == (5 + a + c)*(a + b + c)
    assert e.expand() == 5*a + 5*b + 5*c + 2*a*c + b*c + a*b + a**2 + c**2
    # ``x`` is assigned in this function, so it is local throughout and
    # must be bound before first use.
    x = Symbol('x')
    s = exp(x*x) - 1
    e = s.series(x)/x**2
    assert e.expand() == 1 + x**2/2 + O(x**4)

    e = (x*(y + z))**(x*(y + z))*(x + y)
    assert e.expand(power_exp=False, power_base=False) == (
        x*(x*y + x*z)**(x*y + x*z) + y*(x*y + x*z)**(x*y + x*z))
    assert e.expand(power_exp=False, power_base=False, deep=False) == (
        x*(x*(y + z))**(x*(y + z)) + y*(x*(y + z))**(x*(y + z)))
    e = (x*(y + z))**z
    assert e.expand(power_base=True, mul=True, deep=True) in [
        x**z*(y + z)**z, (x*y + x*z)**z]
    assert ((2*y)**z).expand() == 2**z*y**z
    p = Symbol('p', positive=True)
    assert sqrt(-x).expand().is_Pow
    assert sqrt(-x).expand(force=True) == I*sqrt(x)
    assert ((2*y*p)**z).expand() == 2**z*p**z*y**z
    assert ((2*y*p*x)**z).expand() == 2**z*p**z*(x*y)**z
    assert ((2*y*p*x)**z).expand(force=True) == 2**z*p**z*x**z*y**z
    assert ((2*y*p*-pi)**z).expand() == 2**z*pi**z*p**z*(-y)**z
    assert ((2*y*p*-pi*x)**z).expand() == 2**z*pi**z*p**z*(-x*y)**z
    n = Symbol('n', negative=True, finite=True)
    m = Symbol('m', negative=True, finite=True)
    assert ((-2*x*y*n)**z).expand() == 2**z*(-n)**z*(x*y)**z
    assert ((-2*x*y*n*m)**z).expand() == 2**z*(-m)**z*(-n)**z*(-x*y)**z
    # issue sympy/sympy#5482
    assert sqrt(-2*x*n) == sqrt(2)*sqrt(-n)*sqrt(x)
    # issue sympy/sympy#5605 (2)
    assert (cos(x + y)**2).expand(trig=True) in [
        (-sin(x)*sin(y) + cos(x)*cos(y))**2,
        sin(x)**2*sin(y)**2 - 2*sin(x)*sin(y)*cos(x)*cos(y) + cos(x)**2*cos(y)**2
    ]

    # Check that this isn't too slow: expand prod(x - i) for i = 1..20 and
    # spot-check its x**15 term.
    W = 1
    for i in range(1, 21):
        W = W * (x - i)
    W = W.expand()
    assert W.has(-1672280820*x**15)


def test_power_expand():
    """Test Pow.expand()."""
    a = Symbol('a')
    b = Symbol('b')
    p = (a + b)**2
    assert p.expand() == a**2 + b**2 + 2*a*b

    p = (1 + 2*(1 + a))**2
    assert p.expand() == 9 + 4*(a**2) + 12*a

    p = 2**(a + b)
    assert p.expand() == 2**a*2**b

    A = Symbol('A', commutative=False)
    B = Symbol('B', commutative=False)
    assert (2**(A + B)).expand() == 2**(A + B)
    assert (A**(a + b)).expand() != A**(a + b)


def test_expand_multinomial():
    """``expand_multinomial`` carries an Order term through the expansion."""
    assert expand_multinomial((x + 1 + O(z))**2) == 1 + 2*x + x**2 + O(z)
    assert expand_multinomial((x + 1 + O(z))**3) == \
        1 + 3*x + 3*x**2 + x**3 + O(z)


def test_sympyissues_5919_6830():
    """Regression tests for sympy/sympy#5919 and sympy/sympy#6830."""
    # issue sympy/sympy#5919
    n = -1 + 1/x
    z = n/x/(-n)**2 - 1/n/x
    assert expand(z) == 1/(x**2 - 2*x + 1) - 1/(x - 2 + 1/x) - 1/(-x + 1)

    # issue sympy/sympy#6830
    p = (1 + x)**2
    assert expand_multinomial((1 + x*p)**2) == (
        x**2*(x**4 + 4*x**3 + 6*x**2 + 4*x + 1) + 2*x*(x**2 + 2*x + 1) + 1)
    assert expand_multinomial((1 + (y + x)*p)**2) == (
        2*((x + y)*(x**2 + 2*x + 1)) + (x**2 + 2*x*y + y**2) *
        (x**4 + 4*x**3 + 6*x**2 + 4*x + 1) + 1)
    A = Symbol('A', commutative=False)
    p = (1 + A)**2
    assert expand_multinomial((1 + x*p)**2) == (
        x**2*(1 + 4*A + 6*A**2 + 4*A**3 + A**4) + 2*x*(1 + 2*A + A**2) + 1)
    assert expand_multinomial((1 + (y + x)*p)**2) == (
        (x + y)*(1 + 2*A + A**2)*2 + (x**2 + 2*x*y + y**2) *
        (1 + 4*A + 6*A**2 + 4*A**3 + A**4) + 1)
    assert expand_multinomial((1 + (y + x)*p)**3) == (
        (x + y)*(1 + 2*A + A**2)*3 +
        (x**2 + 2*x*y + y**2)*(1 + 4*A + 6*A**2 + 4*A**3 + A**4)*3 +
        (x**3 + 3*x**2*y + 3*x*y**2 + y**3) *
        (1 + 6*A + 15*A**2 + 20*A**3 + 15*A**4 + 6*A**5 + A**6) + 1)
    # unevaluate powers
    eq = (Pow((x + 1)*((A + 1)**2), 2, evaluate=False))
    # - in this case the base is not an Add so no further
    # expansion is done
    assert expand_multinomial(eq) == \
        (x**2 + 2*x + 1)*(1 + 4*A + 6*A**2 + 4*A**3 + A**4)
    # - but here, the expanded base *is* an Add so it gets expanded
    eq = (Pow(((A + 1)**2), 2, evaluate=False))
    assert expand_multinomial(eq) == 1 + 4*A + 6*A**2 + 4*A**3 + A**4

    # coverage
    def ok(a, b, n):
        """Check numerically that (a + I*b)**n equals its multinomial expansion."""
        e = (a + I*b)**n
        return verify_numerically(e, expand_multinomial(e))

    for a in [2, Rational(1, 2)]:
        for b in [3, Rational(1, 3)]:
            for n in range(2, 6):
                assert ok(a, b, n)

    e = (sin(x) + y)**3
    assert (expand_multinomial(e.subs({y: O(x**4)})) ==
            expand_multinomial(e).subs({y: O(x**4)}) == sin(x)**3 + O(x**6))

    assert expand_multinomial(3**(x + y + 3)) == 27*3**(x + y)


def test_expand_log():
    """Repeated log expansion reduces a logarithmic identity to zero."""
    t = Symbol('t', positive=True)
    # after first expansion, -2*log(2) + log(4); then 0 after second
    assert expand(log(t**2) - log(t**2/4) - 2*log(2)) == 0