"""Tests for the generic rewrite strategies in ``diofant.core.strategies``."""

from diofant import Add, Basic, Integer
from diofant.abc import x
from diofant.core.strategies import (arguments, flatten, glom, null_safe,
                                     operator, rm_id, sort, term, unpack)


__all__ = ()


def test_rm_id():
    """``rm_id`` strips identity arguments but never empties the node."""
    drop_zeros = rm_id(lambda arg: arg == 0)

    # Identities mixed with other args: only the non-identities survive.
    assert drop_zeros(Basic(0, 1)) == Basic(1)
    # Every arg is an identity: a single one is kept rather than none.
    assert drop_zeros(Basic(0, 0)) == Basic(0)
    # No identities present: the expression is returned as-is.
    assert drop_zeros(Basic(2, 1)) == Basic(2, 1)


def test_glom():
    """``glom`` merges arguments that share a key, summing their counts."""
    rule = glom(lambda e: e.as_coeff_Mul()[1],    # group by non-numeric factor
                lambda e: e.as_coeff_Mul()[0],    # count = numeric coefficient
                lambda total, arg: total * arg)   # rebuild one arg per group

    # x - x + 3*x collapses to 3*x, and the constants 2 + 3 collapse to 5.
    messy = Add(x, -x, 3*x, 2, 3, evaluate=False)
    merged = rule(messy)

    target = Add(3*x, 5)
    # Compare as sets: the order of Add's args is not what is under test.
    assert set(merged.args) == set(target.args)

    # Input that is already merged comes back with the same arguments.
    already_merged = Add(*target.args, evaluate=False)
    assert set(rule(already_merged).args) == set(target.args)


def test_flatten():
    """A nested node of the same type is spliced into its parent."""
    nested = Basic(1, 2, Basic(3, 4))
    assert flatten(nested) == Basic(1, 2, 3, 4)


def test_unpack():
    """A one-argument node unwraps to that argument; others are untouched."""
    single, pair = Basic(2), Basic(2, 3)
    assert unpack(single) == 2
    assert unpack(pair) == pair


def test_sort():
    """Arguments are reordered using the supplied key function."""
    by_str = sort(str)
    assert by_str(Basic(3, 1, 2)) == Basic(1, 2, 3)


def test_term():
    """``arguments``/``operator``/``term`` take apart and rebuild expressions."""
    two = Integer(2)
    expr = 2 + x

    # Atoms, whether plain Python ints or diofant Integers, have no args.
    for atom in (2, two):
        assert arguments(atom) == ()
    assert arguments(expr) == (2, x)

    # A compound's operator is its class; an atom acts as its own operator.
    assert operator(expr) == Add
    assert operator(two) == two

    # term() rebuilds an expression from an operator and its arguments.
    assert term(Add, (2, x)) == expr
    assert term(two, ()) == two


def test_null_safe():
    """``null_safe`` returns the input unchanged when the rule yields None."""
    def one_to_two(expr):
        # Fires only on 1; for anything else it falls through to None.
        return 2 if expr == 1 else None

    guarded = null_safe(one_to_two)

    # Where the rule fires, the wrapper produces the same result.
    assert one_to_two(1) == guarded(1)

    # Where the rule gives None, the wrapper hands back the original input.
    assert one_to_two(3) is None
    assert guarded(3) == 3