"""Tests for ``diofant.core.decorators.call_highest_priority``."""

from diofant import Expr, Symbol
from diofant.core.decorators import call_highest_priority


__all__ = ()


class Higher(Expr):
    """Non-commutative Expr whose arithmetic operators all return ``result``.

    Every operator is wrapped with ``call_highest_priority`` naming its
    reflected partner.  The tests below expect that when this object meets
    an operand of larger ``_op_priority``, that operand's method decides the
    result.
    """

    _op_priority = 20.0
    result = 'high'

    is_commutative = False

    # Multiplication pair.
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return self.result

    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return self.result

    # Addition pair.
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return self.result

    @call_highest_priority('__add__')
    def __radd__(self, other):
        return self.result

    # Subtraction pair.
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return self.result

    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return self.result

    # Power pair.
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        return self.result

    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        return self.result

    # True-division pair.
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        return self.result

    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        return self.result


class Lower(Higher):
    """Same operators as ``Higher`` but with a smaller priority and result."""

    _op_priority = 5.0
    result = 'low'


class Lower2(Higher):
    """Low-priority variant whose ``__mul__`` names a nonexistent partner.

    ``test_mul`` expects ``Lower2() * Higher()`` to still yield ``'low'``,
    i.e. the decorator falls back to this method when the named partner
    (``'typo'``) cannot be found on the other operand.
    """

    _op_priority = 5.0
    result = 'low'

    @call_highest_priority('typo')
    def __mul__(self, other):
        return self.result


class Higher2:
    """Plain object carrying ``result`` but no ``_op_priority`` attribute."""

    result = 'high'


def _make_operands():
    """Build a fresh ``(symbol, high, low)`` triple in that order."""
    return Symbol('x'), Higher(), Lower()


def test_mul():
    """Multiplication dispatches to the higher-priority operand."""
    x, hi, lo = _make_operands()
    lo_typo = Lower2()
    plain = Higher2()

    assert lo*hi == hi*lo == 'high'
    assert x*hi == hi*x == 'high'
    assert lo*x == x*lo != 'low'

    # A misspelled partner name or a priority-less operand falls back to
    # the left operand's own implementation.
    assert lo_typo*hi == 'low'
    assert lo_typo*plain == 'low'


def test_add():
    """Addition dispatches to the higher-priority operand."""
    x, hi, lo = _make_operands()

    assert lo + hi == hi + lo == 'high'
    assert x + hi == hi + x == 'high'
    assert lo + x == x + lo != 'low'


def test_sub():
    """Subtraction dispatches to the higher-priority operand."""
    x, hi, lo = _make_operands()

    assert lo - hi == hi - lo == 'high'
    assert x - hi == hi - x == 'high'
    assert lo - x == -(x - lo) != 'low'


def test_pow():
    """Exponentiation dispatches to the higher-priority operand."""
    x, hi, lo = _make_operands()

    assert lo**hi == hi**lo == 'high'
    assert x**hi == hi**x == 'high'
    assert lo**x != 'low'
    assert x**lo != 'low'


def test_div():
    """True division dispatches to the higher-priority operand."""
    x, hi, lo = _make_operands()

    assert lo/hi == hi/lo == 'high'
    assert x/hi == hi/x == 'high'
    assert lo/x != 'low'
    assert x/lo != 'low'