1"""This tests diofant/core/basic.py with (ideally) no reference to subclasses
2of Basic or Atom.
3"""
4
5import collections
6
7import pytest
8
9from diofant import (Atom, Basic, Function, I, Integral, Lambda, cos,
10                     default_sort_key, exp, gamma, preorder_traversal, sin)
11from diofant.abc import w, x, y, z
12from diofant.core.singleton import S
13from diofant.core.singleton import SingletonWithManagedProperties as Singleton
14
15
16__all__ = ()
17
18
# Shared Basic instances used as fixtures by the tests in this module.
b1 = Basic()  # no arguments
b2 = Basic(b1)  # wraps b1
b3 = Basic(b2)  # wraps b2
b21 = Basic(b2, b1)  # two arguments: b2 followed by b1
23
24
def test_structure():
    """Check ``args``, rebuilding through ``func`` and truthiness."""
    assert b21.args == (b2, b1)

    # An object must be reconstructible from its own func and args.
    rebuilt = b21.func(*b21.args)
    assert rebuilt == b21

    assert bool(b1)
29
30
def test_equality():
    """Each listed object must equal itself and differ from all the others."""
    candidates = [b1, b2, b3, b21, Basic(b1, b1, b1), Basic]
    for row, lhs in enumerate(candidates):
        for col, rhs in enumerate(candidates):
            same_slot = row == col
            assert (lhs == rhs) == same_slot
            assert (lhs != rhs) == (not same_slot)

    # Comparison with unrelated Python objects: both operators are checked.
    for foreign in ([], 0):
        assert Basic() != foreign
        assert not Basic() == foreign
42
43
def test_matches_basic():
    """``match`` gives ``{}`` for the same entry and ``None`` otherwise."""
    trees = [Basic(b1, b1, b2), Basic(b1, b2, b1), Basic(b2, b1, b1),
             Basic(b1, b2), Basic(b2, b1), b2, b1]
    for outer_pos, outer in enumerate(trees):
        for inner_pos, inner in enumerate(trees):
            outcome = inner.match(outer)
            if outer_pos != inner_pos:
                assert outcome is None
            else:
                assert outcome == {}
    assert b1.match(b1) == {}
54
55
def test_has():
    """Exercise ``has`` with objects, a class, several patterns and none."""
    # Queries that b21 must satisfy: an object, several objects, a class.
    for query in ((b1,), (b3, b1), (Basic,)):
        assert b21.has(*query)

    assert not b1.has(b21, b3)
    # With no patterns at all the result must be falsy.
    assert not b21.has()
62
63
def test_subs():
    """Check ``subs`` with dicts, pair lists, other mappings and bad input."""
    assert b21.subs({b2: b1}) == Basic(b1, b1)
    assert b21.subs({b2: b21}) == Basic(b21, b1)
    assert b3.subs({b2: b1}) == b2

    # Result expected by every b1 <-> b2 substitution checked below.
    all_b2 = Basic(b2, b2)

    assert b21.subs([(b2, b1), (b1, b2)]) == all_b2

    assert b21.subs({b1: b2, b2: b1}) == all_b2

    with pytest.raises(ValueError):
        b21.subs('bad arg')
    with pytest.raises(ValueError):
        b21.subs(b1, b2, b3)

    chained = collections.ChainMap({b1: b2}, {b2: b1})
    assert b21.subs(chained) == all_b2
    ordered = collections.OrderedDict([(b2, b1), (b1, b2)])
    assert b21.subs(ordered) == all_b2
78
79
def test_rewrite():
    """Check ``rewrite`` with no arguments and with a (source, target) pair."""
    # Without arguments the expression must come back unchanged.
    assert sin(1).rewrite() == sin(1)

    # Only cos is rewritten here; sin(x) must stay as it is.
    trig_sum = sin(x) + cos(x)
    rewritten = trig_sum.rewrite(cos, exp)
    assert rewritten == exp(I*x)/2 + sin(x) + exp(-I*x)/2

    # Only sin is rewritten here; cos(y)/gamma(z) must stay as it is.
    mixed = sin(x) + cos(y)/gamma(z)
    rewritten = mixed.rewrite(sin, exp)
    assert rewritten == -I*(exp(I*x) - exp(-I*x))/2 + cos(y)/gamma(z)
88
89
def test_atoms():
    """``atoms()`` of the nested Basic fixture must be the empty set."""
    found = b21.atoms()
    assert found == set()
92
93
def test_free_symbols_empty():
    """``free_symbols`` of the nested Basic fixture must be the empty set."""
    symbols = b21.free_symbols
    assert symbols == set()
96
97
def test_doit():
    """``doit`` must return an equal object, with or without ``deep=False``."""
    for options in ({}, {'deep': False}):
        assert b21.doit(**options) == b21
101
102
def test_S():
    """The singleton registry ``S`` must have the repr ``'S'``."""
    text = repr(S)
    assert text == 'S'
105
106
def test_xreplace():
    """Check ``xreplace`` on nested objects, Atoms and invalid arguments."""
    assert b21.xreplace({b2: b1}) == Basic(b1, b1)
    assert b21.xreplace({b2: b21}) == Basic(b21, b1)
    assert b3.xreplace({b2: b1}) == b2

    # Swapping both keys in one rule must exchange the two arguments.
    swap_rule = {b1: b2, b2: b1}
    assert Basic(b1, b2).xreplace(swap_rule) == Basic(b2, b1)

    # An Atom stays untouched when only its argument is a key, but is
    # replaced when the Atom itself is a key.
    assert Atom(b1).xreplace({b1: b2}) == Atom(b1)
    assert Atom(b1).xreplace({Atom(b1): b2}) == b2

    # A missing rule and a non-mapping rule must both raise TypeError.
    with pytest.raises(TypeError):
        b1.xreplace()
    with pytest.raises(TypeError):
        b1.xreplace([b1, b2])
116
117
def test_Singleton():
    """Check how often classes using the Singleton metaclass are instantiated.

    ``__new__`` of the test class counts its calls.  The assertions verify
    that defining a class does not call it, that the first call creates the
    instance, that later calls (and the ``S`` attribute) return that same
    object, and that a subclass gets a separate instance of its own.
    """
    # Number of ``__new__`` calls so far.  Kept in this function's scope
    # (and updated via ``nonlocal``) so the test does not write to module
    # globals.
    instantiated = 0

    class MyNewSingleton(Basic, metaclass=Singleton):
        def __new__(cls):
            nonlocal instantiated
            instantiated += 1
            return Basic.__new__(cls)

    # Defining the class alone must not have created an instance.
    assert instantiated == 0
    MyNewSingleton()  # force instantiation
    assert instantiated == 1
    assert MyNewSingleton() is not Basic()
    assert MyNewSingleton() is MyNewSingleton()
    assert S.MyNewSingleton is MyNewSingleton()
    # None of the repeated calls above may have gone through __new__ again.
    assert instantiated == 1

    class MySingletonSub(MyNewSingleton):
        pass
    assert instantiated == 1
    MySingletonSub()
    # The subclass reuses the counting __new__, hence one more call.
    assert instantiated == 2
    assert MySingletonSub() is not MyNewSingleton()
    assert MySingletonSub() is MySingletonSub()
143
144
def test_preorder_traversal():
    """Check traversal order, ``skip()`` and the ``keys`` argument."""
    tree = Basic(b21, b3)
    visited = list(preorder_traversal(tree))
    assert visited == [tree, b21, b2, b1, b1, b3, b2, b1]

    # Plain nested tuples of strings are traversed as well.
    visited = list(preorder_traversal(('abc', ('d', 'ef'))))
    assert visited == [('abc', ('d', 'ef')), 'abc', ('d', 'ef'), 'd', 'ef']

    # Calling skip() at every b2 must leave b2's children out of the walk.
    seen = []
    walker = preorder_traversal(tree)
    for node in walker:
        seen.append(node)
        if node == b2:
            walker.skip()
    assert seen == [tree, b21, b2, b1, b3, b2]

    total = z + w*(x + y)
    visited = list(preorder_traversal([total], keys=default_sort_key))
    assert visited == [[w*(x + y) + z], w*(x + y) + z, z,
                       w*(x + y), w, x + y, x, y]

    visited = list(preorder_traversal((x + y)*z, keys=True))
    assert visited == [z*(x + y), z, x + y, x, y]
165
166
def test_sorted_args():
    """``_sorted_args`` equals ``args`` on Basic and is absent on a symbol."""
    assert b21._sorted_args == b21.args

    with pytest.raises(AttributeError):
        x._sorted_args
170
171
def test_call():
    """Calling an expression raises; ``rcall`` only affects callables."""
    # See the long history of this in issues sympy/sympy#5026 and
    # sympy/sympy#5105.

    with pytest.raises(TypeError):
        sin(x)({x: 1, sin(x): 2})
    with pytest.raises(TypeError):
        sin(x)(1)

    # Without callables inside the expression, rcall changes nothing.
    for plain in (sin(x), 1 + sin(x)):
        assert plain.rcall(1) == plain

    # With a callable present, rcall applies it to the argument.
    doubler = Lambda(x, 2*x)
    assert (doubler + x).rcall(y) == 2*y + x
    assert (x**doubler).rcall(2) == x**4
    # TODO UndefinedFunction does not subclass Expr
    # f = Function('f')
    # assert (2*f)(x) == 2*f(x)
189
190
def test_literal_evalf_is_number_is_zero_is_comparable():
    """Pin ``is_number``, ``is_zero``, ``evalf`` and ``is_comparable``."""
    undefined = Function('f')

    # The following should not be changed without a lot of discussion:
    # `foo.is_number` should be equivalent to `not foo.free_symbols`.
    # It should not attempt anything fancy; see is_zero, is_constant
    # and equals for more rigorous tests.
    assert undefined(1).is_number is True

    zero_integral = Integral(0, (x, x, x))
    # Expressions that are symbolically 0 can be difficult to prove, so
    # if there is some easy way to know that something is 0, it should
    # appear in the is_zero property for that object; if is_zero is true,
    # evalf should always be able to compute that zero.
    assert zero_integral.evalf() == 0
    assert zero_integral.is_zero
    assert zero_integral.is_number is False
    assert zero_integral.evalf(2, strict=False) == 0

    # issue sympy/sympy#10272
    residue = sin(1)**2 + cos(1)**2 - 1
    assert residue.is_comparable is not True
    low_precision = residue.evalf(2, strict=False)
    assert low_precision.is_comparable is not True
214