# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A utility class for testing ordering methods.

To test an ordering method, create an OrderTester and add several
equivalence groups or items to it. The order tester will check that
the items within each group are all equal to each other, and that every
newly added item or group is strictly ascending with regard to the
previously added items or groups.

It will also check that a==b implies hash(a)==hash(b).
"""

import itertools
from typing import Any, List, Tuple

from cirq.testing.equals_tester import EqualsTester


# (name, function) pairs covering all six rich comparison operators. The name
# is only used to make assertion failure messages readable.
_NAMED_COMPARISON_OPERATORS = [
    ('<', lambda a, b: a < b),
    ('>', lambda a, b: a > b),
    ('==', lambda a, b: a == b),
    ('!=', lambda a, b: a != b),
    ('<=', lambda a, b: a <= b),
    ('>=', lambda a, b: a >= b),
]


class OrderTester:
    """Tests ordering against user-provided disjoint ordered groups or items."""

    def __init__(self):
        # Equivalence groups added so far, in the ascending order they were added.
        self._groups: List[Tuple[Any, ...]] = []
        # Delegate that checks equality / hash consistency of each group.
        self._eq_tester = EqualsTester()

    def _verify_ordering_one_sided(self, a: Any, b: Any, sign: int):
        """Checks that (a vs b) == (0 vs sign) for every comparison operator.

        Args:
            a: The left-hand operand.
            b: The right-hand operand.
            sign: The expected relation, expressed as the sign of (b - a):
                +1 if a should be less than b, 0 if they should be equal,
                -1 if a should be greater than b.

        Raises:
            AssertionError: Some comparison operator applied to (a, b)
                disagrees with the same operator applied to (0, sign).
        """
        for cmp_name, cmp_func in _NAMED_COMPARISON_OPERATORS:
            expected = cmp_func(0, sign)
            actual = cmp_func(a, b)
            assert expected == actual, (
                "Ordering constraint violated. Expected X={} to {} Y={}, "
                "but X {} Y returned {}".format(
                    a, ['be more than', 'equal', 'be less than'][sign + 1], b, cmp_name, actual
                )
            )

    def _verify_ordering(self, a: Any, b: Any, sign: int):
        """Checks that (a vs b) == (0 vs sign) and (b vs a) == (sign vs 0).

        Both operand orders are checked so that reflected comparison methods
        are exercised as well.
        """
        self._verify_ordering_one_sided(a, b, sign)
        self._verify_ordering_one_sided(b, a, -sign)

    def _verify_not_implemented_vs_unknown(self, item: Any):
        """Checks that `item` defers comparisons against unknown types.

        The item is compared against sentinel objects that claim to be smaller
        than, equal to, and larger than everything. These results only come
        out consistent if the item's comparison methods return NotImplemented
        for the sentinel, letting Python fall back to the sentinel's methods.

        Raises:
            AssertionError: The item does not defer to the sentinels. The
                original assertion is chained as the cause.
        """
        try:
            self._verify_ordering(_SmallerThanEverythingElse(), item, +1)
            self._verify_ordering(_EqualToEverything(), item, 0)
            self._verify_ordering(_LargerThanEverythingElse(), item, -1)
        except AssertionError as ex:
            raise AssertionError(
                "Objects should return NotImplemented when compared to an "
                "unknown value, i.e. comparison methods should start with\n"
                "\n"
                "    if not isinstance(other, type(self)):\n"
                "        return NotImplemented\n"
                "\n"
                "That rule is being violated by this value: {!r}".format(item)
            ) from ex

    def add_ascending(self, *items: Any):
        """Tries to add a sequence of ascending items to the order tester.

        This method asserts that the items are all strictly ascending, both
        with regard to each other and with regard to the elements that were
        added during previous calls. Some of the previously added elements
        might be equivalence groups, whose members are supposed to be equal
        to each other.

        Each item is added as its own single-element equivalence group.

        Args:
            *items: The sequence of strictly ascending items.

        Raises:
            AssertionError: Items are not ascending, either with regard to
                each other or with regard to the elements which have been
                added before.
        """
        for item in items:
            self.add_ascending_equivalence_group(item)

    def add_ascending_equivalence_group(self, *group_items: Any):
        """Tries to add an ascending equivalence group to the order tester.

        Asserts that the group items are equal to each other, but strictly
        ascending with regard to the groups that were already added. On
        success the group is remembered for checks against later groups.

        Args:
            *group_items: The items making up the equivalence group. At least
                one item is required.

        Raises:
            AssertionError: The group is empty, the group elements aren't
                equal to each other, or items in another group overlap with
                the new group.
        """
        if not group_items:
            # An empty group verifies nothing; report the mistake explicitly
            # instead of recording a meaningless group.
            raise AssertionError('An equivalence group must contain at least one item.')

        for item in group_items:
            self._verify_not_implemented_vs_unknown(item)

        # Every item must equal every item in the group, including itself.
        # _verify_ordering already checks both operand orders, so each
        # unordered pair only needs to be visited once.
        for item1, item2 in itertools.combinations_with_replacement(group_items, 2):
            self._verify_ordering(item1, item2, 0)

        # Every item of every earlier group must be strictly less than every
        # item of the new group.
        for lesser_group in self._groups:
            for lesser_item in lesser_group:
                for larger_item in group_items:
                    self._verify_ordering(lesser_item, larger_item, +1)

        # Use equals tester to check hash function consistency.
        self._eq_tester.add_equality_group(*group_items)

        self._groups.append(group_items)


class _EqualToEverything:
    """Sentinel whose comparison methods report equality with any object."""

    def __eq__(self, other) -> bool:
        return True

    def __ne__(self, other) -> bool:
        return False

    def __lt__(self, other) -> bool:
        return False

    def __le__(self, other) -> bool:
        return True

    def __gt__(self, other) -> bool:
        return False

    def __ge__(self, other) -> bool:
        return True

    def __repr__(self) -> str:
        return '_EqualToEverything'


class _SmallerThanEverythingElse:
    """Sentinel that compares as strictly less than any non-instance object.

    Instances compare equal only to other instances of this class.
    """

    def __eq__(self, other) -> bool:
        return isinstance(other, _SmallerThanEverythingElse)

    def __ne__(self, other) -> bool:
        return not isinstance(other, _SmallerThanEverythingElse)

    def __lt__(self, other) -> bool:
        return not isinstance(other, _SmallerThanEverythingElse)

    def __le__(self, other) -> bool:
        return True

    def __gt__(self, other) -> bool:
        return False

    def __ge__(self, other) -> bool:
        return isinstance(other, _SmallerThanEverythingElse)

    def __repr__(self) -> str:
        return 'SmallerThanEverythingElse'


class _LargerThanEverythingElse:
    """Sentinel that compares as strictly greater than any non-instance object.

    Instances compare equal only to other instances of this class.
    """

    def __eq__(self, other) -> bool:
        return isinstance(other, _LargerThanEverythingElse)

    def __ne__(self, other) -> bool:
        return not isinstance(other, _LargerThanEverythingElse)

    def __lt__(self, other) -> bool:
        return False

    def __le__(self, other) -> bool:
        return isinstance(other, _LargerThanEverythingElse)

    def __gt__(self, other) -> bool:
        return not isinstance(other, _LargerThanEverythingElse)

    def __ge__(self, other) -> bool:
        return True

    def __repr__(self) -> str:
        return 'LargerThanEverythingElse()'