1from diofant import Add, Basic, Integer
2from diofant.abc import x
3from diofant.core.strategies import (arguments, flatten, glom, null_safe,
4                                     operator, rm_id, sort, term, unpack)
5
6
# Test module: an empty ``__all__`` means ``from <this module> import *`` exports no names.
__all__ = ()
8
9
def test_rm_id():
    """Check that ``rm_id`` drops identity arguments from a node."""
    def is_zero(arg):
        return arg == 0

    strip_zeros = rm_id(is_zero)
    cases = (
        # A lone zero among other args is removed.
        (Basic(0, 1), Basic(1)),
        # If every arg is an identity, one is kept rather than emptying the node.
        (Basic(0, 0), Basic(0)),
        # Nothing to remove: the expression is unchanged.
        (Basic(2, 1), Basic(2, 1)),
    )
    for given, wanted in cases:
        assert strip_zeros(given) == wanted
15
16
def test_glom():
    """Check that ``glom`` merges like terms by key and sums their counts."""
    def term_key(expr):
        _, rest = expr.as_coeff_Mul()
        return rest

    def term_count(expr):
        coeff, _ = expr.as_coeff_Mul()
        return coeff

    def rebuild(total, arg):
        return total * arg

    combine = glom(term_key, term_count, rebuild)
    expected = Add(3*x, 5)
    target = set(expected.args)

    # x, -x and 3*x collapse to 3*x; the numbers 2 and 3 collapse to 5.
    merged = combine(Add(x, -x, 3*x, 2, 3, evaluate=False))
    assert set(merged.args) == target

    # Re-applying the rule to already-combined terms yields the same args.
    remerged = combine(Add(*expected.args, evaluate=False))
    assert set(remerged.args) == target
35
36
def test_flatten():
    """Check that a nested node of the same type is spliced into its parent."""
    nested = Basic(1, 2, Basic(3, 4))
    assert flatten(nested) == Basic(1, 2, 3, 4)
39
40
def test_unpack():
    """Check that ``unpack`` unwraps single-argument nodes and leaves others alone."""
    single, pair = Basic(2), Basic(2, 3)
    assert unpack(single) == 2
    assert unpack(pair) == pair
44
45
def test_sort():
    """Check that ``sort`` reorders a node's arguments by the given key."""
    by_str = sort(str)
    assert by_str(Basic(3, 1, 2)) == Basic(1, 2, 3)
48
49
def test_term():
    """Check the ``arguments`` / ``operator`` / ``term`` decomposition helpers."""
    # Atoms, whether plain Python ints or diofant Integers, have no arguments.
    for atom in (2, Integer(2)):
        assert arguments(atom) == ()

    two_plus_x = 2 + x
    assert arguments(two_plus_x) == (2, x)
    assert operator(two_plus_x) == Add
    # An atom acts as its own operator ...
    assert operator(Integer(2)) == Integer(2)

    # ... and ``term`` rebuilds an expression from an operator and its args.
    assert term(Add, (2, x)) == two_plus_x
    assert term(Integer(2), ()) == Integer(2)
58
59
def test_null_safe():
    """Check that ``null_safe`` falls back to the input when the rule returns None."""
    def only_one(expr):
        if expr != 1:
            return None
        return 2

    wrapped = null_safe(only_one)

    # Where the rule applies, the wrapper passes its result through unchanged.
    assert only_one(1) == wrapped(1)

    # Where it does not, the bare rule gives None but the wrapper echoes the input.
    assert only_one(3) is None
    assert wrapped(3) == 3
69