1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A utility class for testing ordering methods.
16
To test an ordering method, create an OrderTester and add several
equivalence groups or items to it. The order tester will check that
the items within each group are all equal to each other, and that every
newly added item or group is strictly ascending with respect to the
previously added items or groups.
22
23It will also check that a==b implies hash(a)==hash(b).
24"""
25
26from typing import Any
27from cirq.testing.equals_tester import EqualsTester
28
29
# (symbol, function) pairs covering all six rich comparison operators, in a
# fixed order. In this file the symbol is only used to label assertion
# failure messages; the function performs the actual comparison.
_NAMED_COMPARISON_OPERATORS = [
    ('<', lambda lhs, rhs: lhs < rhs),
    ('>', lambda lhs, rhs: lhs > rhs),
    ('==', lambda lhs, rhs: lhs == rhs),
    ('!=', lambda lhs, rhs: lhs != rhs),
    ('<=', lambda lhs, rhs: lhs <= rhs),
    ('>=', lambda lhs, rhs: lhs >= rhs),
]
38
39
class OrderTester:
    """Tests ordering against user-provided disjoint ordered groups or items.

    Groups (or single items) are registered in strictly ascending order. Each
    registration asserts that the new items return ``NotImplemented`` against
    unknown types, compare equal to one another, and compare strictly greater
    than every previously registered item. Hash/equality consistency is
    delegated to an ``EqualsTester``.
    """

    def __init__(self):
        # One entry per add_ascending_equivalence_group call: the tuple of
        # items passed in, kept in insertion (i.e. ascending) order.
        self._groups = []
        self._eq_tester = EqualsTester()

    def _verify_ordering_one_sided(self, a: Any, b: Any, sign: int):
        """Asserts every comparison of ``a`` vs ``b`` matches ``0`` vs ``sign``.

        Args:
            a: Left-hand operand.
            b: Right-hand operand.
            sign: -1 if ``a`` should be greater than ``b``, 0 if they should
                be equal, +1 if ``a`` should be less than ``b``.

        Raises:
            AssertionError: Some comparison operator disagrees with the
                expected ordering.
        """
        # Indexed by sign + 1: -1 -> 'be more than', 0 -> 'equal',
        # +1 -> 'be less than'.
        relation_words = ('be more than', 'equal', 'be less than')
        for op_symbol, compare in _NAMED_COMPARISON_OPERATORS:
            want = compare(0, sign)
            got = compare(a, b)
            assert want == got, (
                "Ordering constraint violated. Expected X={} to {} Y={}, "
                "but X {} Y returned {}".format(
                    a, relation_words[sign + 1], b, op_symbol, got
                )
            )

    def _verify_ordering(self, a: Any, b: Any, sign: int):
        """Checks the ordering of ``a`` vs ``b`` from both sides.

        Asserts (a vs b) == (0 vs sign) and, mirrored, (b vs a) == (0 vs -sign),
        so both the forward and the swapped-operand comparisons are exercised.
        """
        for left, right, expected_sign in ((a, b, sign), (b, a, -sign)):
            self._verify_ordering_one_sided(left, right, expected_sign)

    def _verify_not_implemented_vs_unknown(self, item: Any):
        """Checks that ``item`` defers comparisons against unknown types.

        ``item`` is compared with three sentinel objects whose comparison
        methods always return a fixed answer (smaller than, equal to, and
        larger than everything else). If ``item`` returns ``NotImplemented``
        for foreign operands, Python falls back to the sentinel's reflected
        method and the expected ordering holds.

        Raises:
            AssertionError: ``item`` answered a comparison itself instead of
                returning ``NotImplemented``; the underlying failure is chained.
        """
        probes = (
            (_SmallerThanEverythingElse, +1),
            (_EqualToEverything, 0),
            (_LargerThanEverythingElse, -1),
        )
        try:
            for probe_type, sign in probes:
                self._verify_ordering(probe_type(), item, sign)
        except AssertionError as ex:
            raise AssertionError(
                "Objects should return NotImplemented when compared to an "
                "unknown value, i.e. comparison methods should start with\n"
                "\n"
                "    if not isinstance(other, type(self)):\n"
                "        return NotImplemented\n"
                "\n"
                "That rule is being violated by this value: {!r}".format(item)
            ) from ex

    def add_ascending(self, *items: Any):
        """Adds a sequence of strictly ascending single items.

        Each item is registered as its own one-element equivalence group, so
        the items must be strictly ascending both relative to each other and
        relative to everything added by earlier calls (which may include
        multi-item equivalence groups).

        Args:
            *items: The items, in strictly ascending order.

        Raises:
            AssertionError: The items are not strictly ascending, either among
                themselves or with respect to previously added elements.
        """
        for single in items:
            self.add_ascending_equivalence_group(single)

    def add_ascending_equivalence_group(self, *group_items: Any):
        """Adds a group of mutually equal items above all existing groups.

        Asserts that the new items return ``NotImplemented`` against unknown
        types, are all equal to each other, and are strictly greater than
        every item of every earlier group. The group is then handed to the
        ``EqualsTester`` (for hash consistency) and recorded.

        Args:
            *group_items: The items making up the equivalence group.

        Raises:
            AssertionError: An item mishandles unknown types, the group items
                aren't equal to each other, or items in an earlier group
                overlap with (are not strictly below) the new group.
        """
        for candidate in group_items:
            self._verify_not_implemented_vs_unknown(candidate)

        # Every ordered pair inside the group, including each item with itself.
        same_group_pairs = [(x, y) for x in group_items for y in group_items]
        for x, y in same_group_pairs:
            self._verify_ordering(x, y, 0)

        # Every earlier item must be strictly below every new item.
        lower_upper_pairs = [
            (low, high)
            for earlier_group in self._groups
            for low in earlier_group
            for high in group_items
        ]
        for low, high in lower_upper_pairs:
            self._verify_ordering(low, high, +1)

        # Use equals tester to check hash function consistency.
        self._eq_tester.add_equality_group(*group_items)

        self._groups.append(group_items)
132
133
class _EqualToEverything:
    """Test sentinel that compares as equal to every object.

    Its comparison methods never return ``NotImplemented``. When another
    object defers to it through a reflected operator, the result is always
    "equal". Because it defines ``__eq__`` without ``__hash__``, instances
    are unhashable.
    """

    def __eq__(self, other) -> bool:
        return True

    def __ne__(self, other) -> bool:
        return False

    def __lt__(self, other) -> bool:
        return False

    def __le__(self, other) -> bool:
        return True

    def __gt__(self, other) -> bool:
        return False

    def __ge__(self, other) -> bool:
        return True

    def __repr__(self) -> str:
        # Uses the real class name in constructor-call form so failure
        # messages identify the sentinel unambiguously.
        return '_EqualToEverything()'
155
156
class _SmallerThanEverythingElse:
    """Test sentinel that orders strictly below every object of another type.

    Instances compare equal only to other ``_SmallerThanEverythingElse``
    instances and strictly less than everything else. The comparison methods
    never return ``NotImplemented``, so reflected comparisons from objects
    that defer to it also see it as the smallest value.
    """

    def __eq__(self, other) -> bool:
        return isinstance(other, _SmallerThanEverythingElse)

    def __ne__(self, other) -> bool:
        return not isinstance(other, _SmallerThanEverythingElse)

    def __lt__(self, other) -> bool:
        return not isinstance(other, _SmallerThanEverythingElse)

    def __le__(self, other) -> bool:
        # Either equal (another instance of this sentinel) or strictly smaller.
        return True

    def __gt__(self, other) -> bool:
        return False

    def __ge__(self, other) -> bool:
        # Only another instance of this sentinel is not strictly larger.
        return isinstance(other, _SmallerThanEverythingElse)

    def __repr__(self) -> str:
        # Previously 'SmallerThanEverythingElse', which did not match the
        # class name; use the real name in constructor-call form.
        return '_SmallerThanEverythingElse()'
178
179
class _LargerThanEverythingElse:
    """Test sentinel that orders strictly above every object of another type.

    Instances compare equal only to other ``_LargerThanEverythingElse``
    instances and strictly greater than everything else. The comparison
    methods never return ``NotImplemented``, so reflected comparisons from
    objects that defer to it also see it as the largest value.
    """

    def __eq__(self, other) -> bool:
        return isinstance(other, _LargerThanEverythingElse)

    def __ne__(self, other) -> bool:
        return not isinstance(other, _LargerThanEverythingElse)

    def __lt__(self, other) -> bool:
        return False

    def __le__(self, other) -> bool:
        # Only another instance of this sentinel is not strictly smaller.
        return isinstance(other, _LargerThanEverythingElse)

    def __gt__(self, other) -> bool:
        return not isinstance(other, _LargerThanEverythingElse)

    def __ge__(self, other) -> bool:
        # Either equal (another instance of this sentinel) or strictly larger.
        return True

    def __repr__(self) -> str:
        # Previously 'LargerThanEverythingElse()', which dropped the leading
        # underscore of the class name; use the real name.
        return '_LargerThanEverythingElse()'
201