1from diofant import Expr, Symbol
2from diofant.core.decorators import call_highest_priority
3
4
# Nothing is exported: this module only contains tests and their helper classes.
__all__ = ()
6
7
class Higher(Expr):
    """Non-commutative expression whose arithmetic yields ``self.result``.

    Every binary operator below is wrapped by ``call_highest_priority``
    (see ``diofant.core.decorators``), which lets an operand with a larger
    ``_op_priority`` handle the operation through the named counterpart
    method; otherwise the operator simply returns ``self.result``.
    Subclasses override ``_op_priority`` and ``result`` to build rivals.
    """

    is_commutative = False

    _op_priority = 20.0
    result = 'high'

    # Forward operators: the decorator names the reflected method that a
    # higher-priority ``other`` would be asked to run instead.

    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return self.result

    @call_highest_priority('__radd__')
    def __add__(self, other):
        return self.result

    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return self.result

    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        return self.result

    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        return self.result

    # Reflected operators: the decorator names the forward method that a
    # higher-priority ``other`` would be asked to run instead.

    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return self.result

    @call_highest_priority('__add__')
    def __radd__(self, other):
        return self.result

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return self.result

    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        return self.result

    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        return self.result
54
55
class Lower(Higher):
    """Low-priority rival of :class:`Higher`.

    Inherits all of Higher's decorated operators; whenever this operand is
    the one that handles an operation, the outcome is ``'low'``.
    """

    result = 'low'
    _op_priority = 5.0
60
61
class Lower2(Higher):
    """Low-priority operand whose ``__mul__`` defers to a bogus method name.

    ``call_highest_priority`` is handed ``'typo'``, an attribute the other
    operands used in these tests do not define, so ``__mul__`` is expected to
    fall back to returning ``self.result`` even against a higher-priority
    operand.
    """

    result = 'low'
    _op_priority = 5.0

    @call_highest_priority('typo')
    def __mul__(self, other):
        return self.result
69
70
class Higher2:
    """Plain object carrying ``result`` but no ``_op_priority`` at all.

    Despite its name it has no priority, so it exercises the path where the
    other operand lacks ``_op_priority`` entirely.
    """

    result = 'high'
73
74
def test_mul():
    """``*`` must be handled by the operand with the higher priority."""
    sym = Symbol('x')
    high, low = Higher(), Lower()
    low_typo, unranked = Lower2(), Higher2()

    # Higher wins against Lower and against a Symbol, on either side.
    assert low*high == 'high'
    assert high*low == 'high'
    assert sym*high == 'high'
    assert high*sym == 'high'

    # Against a Symbol, both orders must agree and Lower must not win.
    left = low*sym
    right = sym*low
    assert left == right
    assert right != 'low'

    # Fallbacks: the counterpart method name is bogus ('typo'), or the other
    # operand has no ``_op_priority`` at all.
    assert low_typo*high == 'low'
    assert low_typo*unranked == 'low'
87
88
def test_add():
    """``+`` must be handled by the operand with the higher priority."""
    sym = Symbol('x')
    high, low = Higher(), Lower()

    # Higher wins against Lower and against a Symbol, on either side.
    assert low + high == 'high'
    assert high + low == 'high'
    assert sym + high == 'high'
    assert high + sym == 'high'

    # Against a Symbol, both orders must agree and Lower must not win.
    left = low + sym
    right = sym + low
    assert left == right
    assert right != 'low'
96
97
def test_sub():
    """``-`` must be handled by the operand with the higher priority."""
    sym = Symbol('x')
    high, low = Higher(), Lower()

    # Higher wins against Lower and against a Symbol, on either side.
    assert low - high == 'high'
    assert high - low == 'high'
    assert sym - high == 'high'
    assert high - sym == 'high'

    # Against a Symbol, a - b must equal -(b - a) and Lower must not win.
    diff = low - sym
    mirrored = -(sym - low)
    assert diff == mirrored
    assert mirrored != 'low'
105
106
def test_pow():
    """``**`` must be handled by the operand with the higher priority."""
    sym = Symbol('x')
    high, low = Higher(), Lower()

    # Higher wins against Lower and against a Symbol, on either side.
    assert low**high == 'high'
    assert high**low == 'high'
    assert sym**high == 'high'
    assert high**sym == 'high'

    # Against a Symbol, Lower must not be the operand that handles the power.
    assert low**sym != 'low'
    assert sym**low != 'low'
115
116
def test_div():
    """``/`` must be handled by the operand with the higher priority."""
    sym = Symbol('x')
    high, low = Higher(), Lower()

    # Higher wins against Lower and against a Symbol, on either side.
    assert low/high == 'high'
    assert high/low == 'high'
    assert sym/high == 'high'
    assert high/sym == 'high'

    # Against a Symbol, Lower must not be the operand that handles the division.
    assert low/sym != 'low'
    assert sym/low != 'low'
125