1import pytest
2
3from diofant import (Dummy, Float, GreaterThan, I, Integer, LessThan, Rational,
4                     StrictGreaterThan, StrictLessThan, Symbol, Wild, beta, pi,
5                     sstr, symbols, sympify)
6
7
8__all__ = ()
9
10
def test_Symbol():
    """Check Symbol/Dummy equality and rejection of invalid Symbol names."""
    sym_a = Symbol('a')
    first_x, second_x = Symbol('x'), Symbol('x')
    first_dummy, second_dummy = Dummy('x'), Dummy('x')

    # Differently named symbols differ; equally named symbols are equal.
    for x_symbol in (first_x, second_x):
        assert sym_a != x_symbol
    assert first_x == second_x

    # A Dummy equals neither a same-named Symbol nor another same-named Dummy.
    assert first_x != first_dummy
    assert first_dummy != second_dummy

    assert Symbol('x') == Symbol('x')
    assert Dummy('x') != Dummy('x')

    # symbols() honours ``cls`` for a single name and for a comma list.
    lone = symbols('d', cls=Dummy)
    assert isinstance(lone, Dummy)
    left, right = symbols('c,d', cls=Dummy)
    assert isinstance(left, Dummy)
    assert isinstance(right, Dummy)

    # A missing or non-string name is a TypeError.
    with pytest.raises(TypeError):
        Symbol()
    with pytest.raises(TypeError):
        Symbol(1)
33
34
def test_Dummy():
    """Check Dummy inequality and its dependence on ``Dummy._count``.

    Two default-constructed Dummy objects differ, while resetting the
    class-level counter to the same value before each construction yields
    objects that compare equal.
    """
    assert Dummy() != Dummy()

    # ``Dummy._count`` is class-level state shared by the whole test session;
    # put back the value found here so the reset does not leak into other tests.
    saved_count = Dummy._count
    try:
        Dummy._count = 0
        d1 = Dummy()
        Dummy._count = 0
        assert d1 == Dummy()
    finally:
        Dummy._count = saved_count
41
42
def test_as_dummy():
    """as_dummy() differs from its source symbol and from a second call."""
    plain = Symbol('x')
    plain_dummy = plain.as_dummy()
    assert plain_dummy != plain
    assert plain_dummy != plain.as_dummy()

    # Non-commutativity must carry over to the dummy.
    noncommutative = Symbol('x', commutative=False)
    noncommutative_dummy = noncommutative.as_dummy()
    assert noncommutative_dummy != noncommutative
    assert noncommutative_dummy.is_commutative is False
53
54
def test_lt_gt():
    """Comparison operators build the expected relational objects.

    Each pair holds the result of a Python comparison and the relational
    it has to equal.
    """
    x, y = Symbol('x'), Symbol('y')
    zero = Integer(0)
    e = x**2 + 4*x + 1

    cases = [
        # non-strict, plain symbols
        (x >= y, GreaterThan(x, y)),
        (x >= 0, GreaterThan(x, 0)),
        (x <= y, LessThan(x, y)),
        (x <= 0, LessThan(x, 0)),

        (0 <= x, GreaterThan(x, 0)),
        (0 >= x, LessThan(x, 0)),
        (zero >= x, GreaterThan(0, x)),
        (zero <= x, LessThan(0, x)),

        # strict, plain symbols
        (x > y, StrictGreaterThan(x, y)),
        (x > 0, StrictGreaterThan(x, 0)),
        (x < y, StrictLessThan(x, y)),
        (x < 0, StrictLessThan(x, 0)),

        (0 < x, StrictGreaterThan(x, 0)),
        (0 > x, StrictLessThan(x, 0)),
        (zero > x, StrictGreaterThan(0, x)),
        (zero < x, StrictLessThan(0, x)),

        # compound expression against zero
        (e >= 0, GreaterThan(e, 0)),
        (0 <= e, GreaterThan(e, 0)),
        (e > 0, StrictGreaterThan(e, 0)),
        (0 < e, StrictGreaterThan(e, 0)),

        (e <= 0, LessThan(e, 0)),
        (0 >= e, LessThan(e, 0)),
        (e < 0, StrictLessThan(e, 0)),
        (0 > e, StrictLessThan(e, 0)),

        (zero >= e, GreaterThan(0, e)),
        (zero <= e, LessThan(0, e)),
        (zero < e, StrictLessThan(0, e)),
        (zero > e, StrictGreaterThan(0, e)),
    ]

    for built, expected in cases:
        assert built == expected
93
94
def test_no_len():
    """A Symbol has no length: calling ``len`` on it raises TypeError."""
    symbol = Symbol('x')
    with pytest.raises(TypeError):
        len(symbol)
99
100
def test_ineq_unequal():
    """Every entry of a large collection of inequalities differs from the rest."""
    x, y, z = symbols('x,y,z')

    # Numeric constant on the left-hand side.
    number_first = (
        Integer(-1) >= x, Integer(-1) >= y, Integer(-1) >= z,
        Integer(-1) > x, Integer(-1) > y, Integer(-1) > z,
        Integer(-1) <= x, Integer(-1) <= y, Integer(-1) <= z,
        Integer(-1) < x, Integer(-1) < y, Integer(-1) < z,
        Integer(0) >= x, Integer(0) >= y, Integer(0) >= z,
        Integer(0) > x, Integer(0) > y, Integer(0) > z,
        Integer(0) <= x, Integer(0) <= y, Integer(0) <= z,
        Integer(0) < x, Integer(0) < y, Integer(0) < z,
        Rational(3, 7) >= x, Rational(3, 7) >= y, Rational(3, 7) >= z,
        Rational(3, 7) > x, Rational(3, 7) > y, Rational(3, 7) > z,
        Rational(3, 7) <= x, Rational(3, 7) <= y, Rational(3, 7) <= z,
        Rational(3, 7) < x, Rational(3, 7) < y, Rational(3, 7) < z,
        Float(1.5) >= x, Float(1.5) >= y, Float(1.5) >= z,
        Float(1.5) > x, Float(1.5) > y, Float(1.5) > z,
        Float(1.5) <= x, Float(1.5) <= y, Float(1.5) <= z,
        Float(1.5) < x, Float(1.5) < y, Float(1.5) < z,
        Integer(2) >= x, Integer(2) >= y, Integer(2) >= z,
        Integer(2) > x, Integer(2) > y, Integer(2) > z,
        Integer(2) <= x, Integer(2) <= y, Integer(2) <= z,
        Integer(2) < x, Integer(2) < y, Integer(2) < z,
    )

    # Numeric constant on the right-hand side.
    number_last = (
        x >= -1, y >= -1, z >= -1,
        x > -1, y > -1, z > -1,
        x <= -1, y <= -1, z <= -1,
        x < -1, y < -1, z < -1,
        x >= 0, y >= 0, z >= 0,
        x > 0, y > 0, z > 0,
        x <= 0, y <= 0, z <= 0,
        x < 0, y < 0, z < 0,
        x >= 1.5, y >= 1.5, z >= 1.5,
        x > 1.5, y > 1.5, z > 1.5,
        x <= 1.5, y <= 1.5, z <= 1.5,
        x < 1.5, y < 1.5, z < 1.5,
        x >= 2, y >= 2, z >= 2,
        x > 2, y > 2, z > 2,
        x <= 2, y <= 2, z <= 2,
        x < 2, y < 2, z < 2,
    )

    # One bare symbol against another.
    symbol_pairs = (
        x >= y, x >= z, y >= x, y >= z, z >= x, z >= y,
        x > y, x > z, y > x, y > z, z > x, z > y,
        x <= y, x <= z, y <= x, y <= z, z <= x, z <= y,
        x < y, x < z, y < x, y < z, z < x, z < y,
    )

    # Compound expressions on both sides.
    compound = (
        x - pi >= y + z, y - pi >= x + z, z - pi >= x + y,
        x - pi > y + z, y - pi > x + z, z - pi > x + y,
        x - pi <= y + z, y - pi <= x + z, z - pi <= x + y,
        x - pi < y + z, y - pi < x + z, z - pi < x + y,
    )

    candidates = (number_first + number_last + symbol_pairs + compound
                  + (True, False))

    # Compare each entry with every entry that follows it.
    for following, lhs in enumerate(candidates[:-1], start=1):
        for rhs in candidates[following:]:
            assert lhs != rhs
157
158
def test_Wild_properties():
    """Wildcards limited by ``properties``/``exclude`` match only listed inputs.

    Each candidate expression is matched against each wildcard; a match is
    expected exactly for the candidates listed for that wildcard.
    """
    x = Symbol('x')
    y = Symbol('y')
    p = Symbol('p', positive=True)
    k = Symbol('k', integer=True)
    n = Symbol('n', integer=True, positive=True)

    candidates = [x, y, p, k, -k, n, -n, sympify(-3), sympify(3),
                  pi, Rational(3, 2), I]

    S = Wild('S', properties=[lambda expr: expr.is_Symbol])
    R = Wild('R', properties=[lambda expr: expr.is_extended_real])
    Y = Wild('Y', exclude=[x, p, k, n])
    P = Wild('P', properties=[lambda expr: expr.is_positive])
    K = Wild('K', properties=[lambda expr: expr.is_integer])
    N = Wild('N', properties=[lambda expr: expr.is_positive,
                              lambda expr: expr.is_integer])

    # Wildcard -> candidates it is expected to match.
    accepted_by = {
        S: (x, y, p, k, n),
        R: (p, k, -k, n, -n, -3, 3, pi, Rational(3, 2)),
        Y: (y, -3, 3, pi, Rational(3, 2), I),
        P: (p, n, 3, pi, Rational(3, 2)),
        K: (k, -k, n, -n, -3, 3),
        N: (n, 3)}

    for wild, accepted in accepted_by.items():
        for candidate in candidates:
            result = candidate.match(wild)
            if candidate not in accepted:
                assert result is None
            else:
                assert result[wild] in accepted
206
207
def test_symbols():
    """Exercise symbols(): separators, containers, ranges and malformed input."""
    x, y, z = (Symbol(name) for name in 'xyz')

    # Surrounding whitespace is ignored; a trailing comma gives a tuple.
    for spec in ('x', 'x ', ' x '):
        assert symbols(spec) == x
    for spec in ('x,', 'x, ', 'x ,'):
        assert symbols(spec) == (x,)

    assert symbols('x , y') == (x, y)

    for spec in ('x,y,z', 'x y z', 'x,y,z,', 'x y z '):
        assert symbols(spec) == (x, y, z)

    xyz, abc = Symbol('xyz'), Symbol('abc')

    assert symbols('xyz') == xyz
    assert symbols('xyz,') == (xyz,)
    assert symbols('xyz,abc') == (xyz, abc)

    # Container arguments: each element is expanded on its own and the
    # container type is kept.
    assert symbols(('xyz',)) == (xyz,)
    assert symbols(('xyz,',)) == ((xyz,),)
    assert symbols(('x,y,z,',)) == ((x, y, z),)
    assert symbols(('xyz', 'abc')) == (xyz, abc)
    assert symbols(('xyz,abc',)) == ((xyz, abc),)
    assert symbols(('xyz,abc', 'x,y,z')) == ((xyz, abc), (x, y, z))

    assert symbols(('x', 'y', 'z')) == (x, y, z)
    assert symbols(['x', 'y', 'z']) == [x, y, z]
    assert symbols({'x', 'y', 'z'}) == {x, y, z}

    # Empty names are rejected.
    for bad in ('', ',', 'x,,y,,z', ('x', '', 'y', '', 'z')):
        with pytest.raises(ValueError):
            symbols(bad)

    # Assumptions are forwarded to every created symbol.
    real_x, real_y = symbols('x,y', extended_real=True)
    assert real_x.is_extended_real and real_y.is_extended_real

    x0, x1, x2 = (Symbol(f'x{i}') for i in range(3))
    y0, y1 = (Symbol(f'y{i}') for i in range(2))

    # Numeric ranges.
    numeric_ranges = (
        ('x0:0', ()),
        ('x0:1', (x0,)),
        ('x0:2', (x0, x1)),
        ('x0:3', (x0, x1, x2)),

        ('x:0', ()),
        ('x:1', (x0,)),
        ('x:2', (x0, x1)),
        ('x:3', (x0, x1, x2)),

        ('x1:1', ()),
        ('x1:2', (x1,)),
        ('x1:3', (x1, x2)),

        ('x1:3,x,y,z', (x1, x2, x, y, z)),

        ('x:3,y:2', (x0, x1, x2, y0, y1)),
    )
    for spec, expected in numeric_ranges:
        assert symbols(spec) == expected
    assert symbols(('x:3', 'y:2')) == ((x0, x1, x2), (y0, y1))

    # Letter ranges.
    a, b, c, d = (Symbol(letter) for letter in 'abcd')

    assert symbols('x:z') == (x, y, z)
    assert symbols('a:d,x:z') == (a, b, c, d, x, y, z)
    assert symbols(('a:d', 'x:z')) == ((a, b, c, d), (x, y, z))

    aa, ab, ac, ad = (Symbol('a' + letter) for letter in 'abcd')

    assert symbols('aa:d') == (aa, ab, ac, ad)
    assert symbols('aa:d,x:z') == (aa, ab, ac, ad, x, y, z)
    assert symbols(('aa:d', 'x:z')) == ((aa, ab, ac, ad), (x, y, z))

    # issue sympy/sympy#6675
    nul = chr(0)
    printed_forms = (
        ('a0:4', '(a0, a1, a2, a3)'),
        ('a2:4,b1:3', '(a2, a3, b1, b2)'),
        ('a1(2:4)', '(a12, a13)'),
        ('a0:2.0:2', '(a0.0, a0.1, a1.0, a1.1)'),
        ('aa:cz', '(aaz, abz, acz)'),
        ('aa:c0:2', '(aa0, aa1, ab0, ab1, ac0, ac1)'),
        ('aa:ba:b', '(aaa, aab, aba, abb)'),
        ('a:3b', '(a0b, a1b, a2b)'),
        ('a-1:3b', '(a-1b, a-2b)'),

        (r'a:2\,:2' + nul,
         f'(a0,0{nul}, a0,1{nul}, a1,0{nul}, a1,1{nul})'),
        ('x(:a:3)', '(x(a0), x(a1), x(a2))'),
        ('x(:c):1', '(xa0, xb0, xc0)'),
        ('x((:a)):3', '(x(a)0, x(a)1, x(a)2)'),
        ('x(:a:3', '(x(a0, x(a1, x(a2)'),
        (':2', '(0, 1)'),
        (':b', '(a, b)'),
        (':b:2', '(a0, a1, b0, b1)'),
        (':2:2', '(00, 01, 10, 11)'),
        (':b:b', '(aa, ab, ba, bb)'),
    )
    for spec, text in printed_forms:
        assert sstr(symbols(spec)) == text

    # Ranges with a missing endpoint are rejected.
    for bad in (':', 'a:', '::', 'a::', ':a:', '::a'):
        with pytest.raises(ValueError):
            symbols(bad)
330
331
def test_sympyissue_9057_1():
    """Calling the imported ``beta`` with two arguments must not raise."""
    beta(2, 3)
334
335
def test_sympyissue_9057_2():
    """A Symbol named 'beta' is not callable: calling it raises TypeError."""
    # The local name deliberately shadows the imported ``beta`` function.
    beta = Symbol('beta')
    for call_args in ((2,), (2.5,), (2, 3)):
        with pytest.raises(TypeError):
            beta(*call_args)
341