"""This tests diofant/core/basic.py with (ideally) no reference to subclasses
of Basic or Atom.
"""

import collections

import pytest

from diofant import (Atom, Basic, Function, I, Integral, Lambda, cos,
                     default_sort_key, exp, gamma, preorder_traversal, sin)
from diofant.abc import w, x, y, z
from diofant.core.singleton import S
from diofant.core.singleton import SingletonWithManagedProperties as Singleton


__all__ = ()


# Shared fixtures: a small family of nested, argument-only Basic instances.
b1 = Basic()
b2 = Basic(b1)
b3 = Basic(b2)
b21 = Basic(b2, b1)


def test_structure():
    """Check ``args``, rebuilding via ``func(*args)`` and truthiness."""
    assert b21.args == (b2, b1)
    assert b21.func(*b21.args) == b21
    assert bool(b1)


def test_equality():
    """Check that each sample object is equal only to itself."""
    # The Basic class itself is included on purpose alongside instances.
    instances = [b1, b2, b3, b21, Basic(b1, b1, b1), Basic]
    for i, b_i in enumerate(instances):
        for j, b_j in enumerate(instances):
            assert (b_i == b_j) == (i == j)
            assert (b_i != b_j) == (i != j)

    # Both the == and != code paths are exercised against non-Basic objects.
    assert Basic() != []
    assert not Basic() == []
    assert Basic() != 0
    assert not Basic() == 0


def test_matches_basic():
    """Check that ``match`` gives ``{}`` for identical objects, else None."""
    instances = [Basic(b1, b1, b2), Basic(b1, b2, b1), Basic(b2, b1, b1),
                 Basic(b1, b2), Basic(b2, b1), b2, b1]
    for i, b_i in enumerate(instances):
        for j, b_j in enumerate(instances):
            if i == j:
                assert b_j.match(b_i) == {}
            else:
                assert b_j.match(b_i) is None
    assert b1.match(b1) == {}


def test_has():
    """Check ``has`` with instances, several patterns, a class and no args."""
    assert b21.has(b1)
    assert b21.has(b3, b1)
    assert b21.has(Basic)
    assert not b1.has(b21, b3)
    assert not b21.has()


def test_subs():
    """Check ``subs`` with mappings and pair lists, and its argument errors."""
    assert b21.subs({b2: b1}) == Basic(b1, b1)
    assert b21.subs({b2: b21}) == Basic(b21, b1)
    assert b3.subs({b2: b1}) == b2

    assert b21.subs([(b2, b1), (b1, b2)]) == Basic(b2, b2)

    assert b21.subs({b1: b2, b2: b1}) == Basic(b2, b2)

    pytest.raises(ValueError, lambda: b21.subs('bad arg'))
    pytest.raises(ValueError, lambda: b21.subs(b1, b2, b3))

    # Other mapping types from the standard library are accepted as well.
    assert b21.subs(collections.ChainMap({b1: b2}, {b2: b1})) == Basic(b2, b2)
    assert b21.subs(
        collections.OrderedDict([(b2, b1), (b1, b2)])) == Basic(b2, b2)


def test_rewrite():
    """Check ``rewrite`` with no arguments and with a (source, target) pair."""
    assert sin(1).rewrite() == sin(1)

    f1 = sin(x) + cos(x)
    assert f1.rewrite(cos, exp) == exp(I*x)/2 + sin(x) + exp(-I*x)/2

    f2 = sin(x) + cos(y)/gamma(z)
    assert f2.rewrite(sin, exp) == -I*(exp(I*x) - exp(-I*x))/2 + cos(y)/gamma(z)


def test_atoms():
    """Check that ``atoms()`` of a pure Basic tree is the empty set."""
    assert b21.atoms() == set()


def test_free_symbols_empty():
    """Check that a pure Basic tree has no free symbols."""
    assert b21.free_symbols == set()


def test_doit():
    """Check that ``doit`` leaves a plain Basic unchanged, deep or not."""
    assert b21.doit() == b21
    assert b21.doit(deep=False) == b21


def test_S():
    """Check the repr of the singleton registry ``S``."""
    assert repr(S) == 'S'


def test_xreplace():
    """Check ``xreplace`` results and its TypeError on bad arguments."""
    assert b21.xreplace({b2: b1}) == Basic(b1, b1)
    assert b21.xreplace({b2: b21}) == Basic(b21, b1)
    assert b3.xreplace({b2: b1}) == b2
    assert Basic(b1, b2).xreplace({b1: b2, b2: b1}) == Basic(b2, b1)
    assert Atom(b1).xreplace({b1: b2}) == Atom(b1)
    assert Atom(b1).xreplace({Atom(b1): b2}) == b2
    pytest.raises(TypeError, lambda: b1.xreplace())
    pytest.raises(TypeError, lambda: b1.xreplace([b1, b2]))


def test_Singleton():
    """Check that Singleton-metaclass classes are instantiated only once."""
    # Module-level counter of how many times __new__ actually runs.
    global instantiated
    instantiated = 0

    class MyNewSingleton(Basic, metaclass=Singleton):
        def __new__(cls):
            global instantiated
            instantiated += 1
            return Basic.__new__(cls)

    assert instantiated == 0
    MyNewSingleton()  # force instantiation
    assert instantiated == 1
    assert MyNewSingleton() is not Basic()
    assert MyNewSingleton() is MyNewSingleton()
    assert S.MyNewSingleton is MyNewSingleton()
    assert instantiated == 1

    class MySingletonSub(MyNewSingleton):
        pass
    assert instantiated == 1
    MySingletonSub()
    assert instantiated == 2
    assert MySingletonSub() is not MyNewSingleton()
    assert MySingletonSub() is MySingletonSub()


def test_preorder_traversal():
    """Check traversal order, ``skip()`` and the ``keys`` argument."""
    expr = Basic(b21, b3)
    assert list(
        preorder_traversal(expr)) == [expr, b21, b2, b1, b1, b3, b2, b1]
    assert list(preorder_traversal(('abc', ('d', 'ef')))) == [
        ('abc', ('d', 'ef')), 'abc', ('d', 'ef'), 'd', 'ef']

    # This must stay an explicit loop: skip() is called on the live iterator
    # while it is being consumed.
    result = []
    pt = preorder_traversal(expr)
    for i in pt:
        result.append(i)
        if i == b2:
            pt.skip()
    assert result == [expr, b21, b2, b1, b3, b2]

    expr = z + w*(x + y)
    assert list(preorder_traversal([expr], keys=default_sort_key)) == \
        [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
    assert list(preorder_traversal((x + y)*z, keys=True)) == \
        [z*(x + y), z, x + y, x, y]


def test_sorted_args():
    """Check ``_sorted_args`` on a Basic and its absence on a symbol."""
    assert b21._sorted_args == b21.args
    pytest.raises(AttributeError, lambda: x._sorted_args)


def test_call():
    """Check that expressions are not callable and how ``rcall`` behaves."""
    # See the long history of this in issues sympy/sympy#5026 and sympy/sympy#5105.

    pytest.raises(TypeError, lambda: sin(x)({x: 1, sin(x): 2}))
    pytest.raises(TypeError, lambda: sin(x)(1))

    # No effect as there are no callables
    assert sin(x).rcall(1) == sin(x)
    assert (1 + sin(x)).rcall(1) == 1 + sin(x)

    # Effect in the presence of callables
    l = Lambda(x, 2*x)
    assert (l + x).rcall(y) == 2*y + x
    assert (x**l).rcall(2) == x**4
    # TODO UndefinedFunction does not subclass Expr
    # f = Function('f')
    # assert (2*f)(x) == 2*f(x)


def test_literal_evalf_is_number_is_zero_is_comparable():
    """Check ``is_number``, ``is_zero``, ``evalf`` and ``is_comparable``."""
    f = Function('f')

    # the following should not be changed without a lot of discussion
    # `foo.is_number` should be equivalent to `not foo.free_symbols`
    # it should not attempt anything fancy; see is_zero, is_constant
    # and equals for more rigorous tests.
    assert f(1).is_number is True
    i = Integral(0, (x, x, x))
    # expressions that are symbolically 0 can be difficult to prove
    # so in case there is some easy way to know if something is 0
    # it should appear in the is_zero property for that object;
    # if is_zero is true evalf should always be able to compute that
    # zero
    assert i.evalf() == 0
    assert i.is_zero
    assert i.is_number is False
    assert i.evalf(2, strict=False) == 0

    # issue sympy/sympy#10272
    n = sin(1)**2 + cos(1)**2 - 1
    assert n.is_comparable is not True
    assert n.evalf(2, strict=False).is_comparable is not True