1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import pytest
15
16import cirq
17
18
@cirq.value_equality
class BasicC:
    """Value-equality test class whose identity is the single attribute ``x``."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # The decorator compares instances using only this value.
        return self.x
26
27
@cirq.value_equality
class BasicD:
    """Second value-equality class, structurally identical to ``BasicC``.

    Used to check that equal values held by different decorated classes
    do not compare equal.
    """

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # The decorator compares instances using only this value.
        return self.x
35
36
@cirq.value_equality(manual_cls=True)
class MasqueradePositiveD:
    """Class that reports ``BasicD`` as its equality class when ``x > 0``.

    For non-positive ``x`` it reports itself, so it only matches other
    ``MasqueradePositiveD`` instances.
    """

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x

    def _value_equality_values_cls_(self):
        # manual_cls=True lets the instance choose which class it compares as.
        if self.x > 0:
            return BasicD
        return MasqueradePositiveD
47
48
class BasicCa(BasicC):
    """Plain subclass of ``BasicC``; expected to compare equal to its parent."""
51
52
class BasicCb(BasicC):
    """Another plain subclass of ``BasicC``; expected to equal its siblings."""
55
56
def test_value_equality_basic():
    """Default value equality is shared between a class and its subclasses."""
    # Dictionary lookup treats parent and subclass instances as the same key.
    lookup = {BasicC(1): 4, BasicCa(2): 5}
    assert lookup[BasicCa(1)] == 4
    assert lookup[BasicC(1)] == 4
    assert lookup[BasicCa(2)] == 5

    # Same value and same class family are equal; a different class or value is not.
    tester = cirq.testing.EqualsTester()
    tester.add_equality_group(BasicC(1), BasicC(1), BasicCa(1), BasicCb(1))
    tester.add_equality_group(BasicD(1))
    tester.add_equality_group(BasicC(2))
    tester.add_equality_group(BasicCa(3))
70
71
def test_value_equality_manual():
    """A manual_cls class can masquerade as another class for equality."""
    tester = cirq.testing.EqualsTester()
    # Positive values report BasicD as their class, so they match BasicD.
    tester.add_equality_group(MasqueradePositiveD(3), BasicD(3))
    tester.add_equality_group(MasqueradePositiveD(4), MasqueradePositiveD(4), BasicD(4))
    # Non-positive values keep their own class and no longer match BasicD.
    tester.add_equality_group(MasqueradePositiveD(-1), MasqueradePositiveD(-1))
    tester.add_equality_group(BasicD(-1))
78
79
@cirq.value_equality(unhashable=True)
class UnhashableC:
    """Value-equality class declared unhashable, keyed on ``x``."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Only this value takes part in equality comparisons.
        return self.x
87
88
@cirq.value_equality(unhashable=True)
class UnhashableD:
    """Unhashable counterpart of ``UnhashableC`` that is a distinct class."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Only this value takes part in equality comparisons.
        return self.x
96
97
class UnhashableCa(UnhashableC):
    """Plain subclass of ``UnhashableC``; expected to equal its parent by value."""
100
101
class UnhashableCb(UnhashableC):
    """Second plain subclass of ``UnhashableC``; expected to equal its siblings."""
104
105
def test_value_equality_unhashable():
    """``unhashable=True`` keeps value equality but forbids hashing."""
    # Not possible to use as a dictionary key.
    with pytest.raises(TypeError, match='unhashable'):
        _ = {UnhashableC(1): 4}

    # Hashing directly fails too, and subclasses inherit the unhashability
    # because they do not define their own __hash__.
    with pytest.raises(TypeError, match='unhashable'):
        hash(UnhashableC(1))
    with pytest.raises(TypeError, match='unhashable'):
        hash(UnhashableCa(1))

    # Equality works as expected.
    eq = cirq.testing.EqualsTester()
    eq.add_equality_group(UnhashableC(1), UnhashableC(1), UnhashableCa(1), UnhashableCb(1))
    eq.add_equality_group(UnhashableC(2))
    eq.add_equality_group(UnhashableD(1))
116
117
@cirq.value_equality(distinct_child_types=True)
class DistinctC:
    """Value-equality class whose subclasses compare as separate types."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Equality uses this value together with the concrete subclass.
        return self.x
125
126
@cirq.value_equality(distinct_child_types=True)
class DistinctD:
    """Unrelated twin of ``DistinctC``; must never equal a ``DistinctC``."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Only this value (plus the class) takes part in equality comparisons.
        return self.x
134
135
class DistinctCa(DistinctC):
    """Subclass of ``DistinctC``; expected to differ from its parent and siblings."""
138
139
class DistinctCb(DistinctC):
    """Second subclass of ``DistinctC``; expected to be its own equality type."""
142
143
def test_value_equality_distinct_child_types():
    """With distinct_child_types, every subclass is a separate equality type."""
    # Equal values under different child types are different dictionary keys.
    table = {DistinctC(1): 4, DistinctCa(1): 5, DistinctCb(1): 6}
    for key, expected in [(DistinctC(1), 4), (DistinctCa(1), 5), (DistinctCb(1), 6)]:
        assert table[key] == expected

    # Equality holds only within a single concrete class and value.
    tester = cirq.testing.EqualsTester()
    tester.add_equality_group(DistinctC(1), DistinctC(1))
    tester.add_equality_group(DistinctCa(1), DistinctCa(1))
    tester.add_equality_group(DistinctCb(1), DistinctCb(1))
    tester.add_equality_group(DistinctC(2))
    tester.add_equality_group(DistinctD(1))
158
159
@cirq.value_equality(approximate=True)
class ApproxE:
    """Value-equality class that also supports ``cirq.approx_eq`` on ``x``."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Used for both exact and (by default) approximate comparisons.
        return self.x
167
168
def test_value_equality_approximate():
    """approximate=True makes approx_eq honor the given tolerance."""
    origin = ApproxE(0.0)
    nearby = ApproxE(0.2)
    assert cirq.approx_eq(origin, ApproxE(0.0), atol=0.1)
    # A 0.2 gap is inside atol=0.3 but outside atol=0.1.
    assert cirq.approx_eq(origin, nearby, atol=0.3)
    assert not cirq.approx_eq(origin, nearby, atol=0.1)
173
174
@cirq.value_equality(approximate=True)
class PeriodicF:
    """Class whose approximate comparison works on ``x`` modulo ``n``.

    Exact equality uses the raw ``x``. Approximate equality uses ``x % n``.
    """

    def __init__(self, x, n):
        self.x, self.n = x, n

    def _value_equality_values_(self):
        return self.x

    def _value_equality_approximate_values_(self):
        # Values a whole period apart reduce to the same residue.
        return self.x % self.n
186
187
def test_value_equality_approximate_specialized():
    """A custom approximate-values hook changes approx_eq but not ==."""
    one_mod_four = PeriodicF(1, 4)
    # Exact equality sees the raw values 1 and 5 as different...
    assert one_mod_four != PeriodicF(5, 4)
    # ...but both reduce to 1 modulo 4, so they are approximately equal.
    assert cirq.approx_eq(one_mod_four, PeriodicF(5, 4), atol=0.1)
    # 6 reduces to 2, which is outside the tolerance of 1.
    assert not cirq.approx_eq(one_mod_four, PeriodicF(6, 4), atol=0.1)
192
193
def test_value_equality_approximate_not_supported():
    """Without ``approximate=True``, approx_eq does not apply a tolerance."""
    # Values within atol of each other are still considered different...
    assert not cirq.approx_eq(BasicC(0.0), BasicC(0.1), atol=0.2)
    # ...while exactly equal values still compare equal, so the negative
    # result above is not simply approx_eq always returning False.
    assert cirq.approx_eq(BasicC(0.0), BasicC(0.0), atol=0.2)
196
197
class ApproxEa(ApproxE):
    """Plain subclass of ``ApproxE``; approx-equal to its siblings by default."""
200
201
class ApproxEb(ApproxE):
    """Second plain subclass of ``ApproxE``; paired with ``ApproxEa`` in tests."""
204
205
@cirq.value_equality(distinct_child_types=True, approximate=True)
class ApproxG:
    """Approximate-comparable class whose subclasses are distinct types."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        # Compared together with the concrete subclass because of distinct_child_types.
        return self.x
213
214
class ApproxGa(ApproxG):
    """Subclass of ``ApproxG``; expected to be its own type for approx_eq."""
217
218
class ApproxGb(ApproxG):
    """Second subclass of ``ApproxG``; never approx-equal to parent or sibling."""
221
222
def test_value_equality_approximate_typing():
    """approx_eq respects the same type rules as exact value equality."""
    atol = 0.1
    # Different decorated classes are not approximately equal.
    assert not cirq.approx_eq(ApproxE(0.0), PeriodicF(0.0, 1.0), atol=atol)
    # By default, subclasses share their parent's equality type...
    assert cirq.approx_eq(ApproxEa(0.0), ApproxEb(0.0), atol=atol)
    # ...but with distinct_child_types each subclass stands alone.
    assert cirq.approx_eq(ApproxG(0.0), ApproxG(0.0), atol=atol)
    assert not cirq.approx_eq(ApproxGa(0.0), ApproxGb(0.0), atol=atol)
    assert not cirq.approx_eq(ApproxG(0.0), ApproxGb(0.0), atol=atol)
229
230
def test_value_equality_forgot_method():
    """Decorating a class without ``_value_equality_values_`` raises TypeError."""
    with pytest.raises(TypeError, match='_value_equality_values_'):

        @cirq.value_equality
        class _:
            """Deliberately defines no _value_equality_values_ method."""
237
238
def test_bad_manual_cls_incompatible_args():
    """``manual_cls`` and ``distinct_child_types`` cannot be combined."""
    with pytest.raises(ValueError, match='incompatible'):

        @cirq.value_equality(manual_cls=True, distinct_child_types=True)
        class _:
            """Empty class; the conflicting decorator arguments are what fail."""
245
246
def test_bad_manual_cls_forgot_method():
    """``manual_cls=True`` requires a ``_value_equality_values_cls_`` method."""
    with pytest.raises(TypeError, match='_value_equality_values_cls_'):

        @cirq.value_equality(manual_cls=True)
        class _:
            def _value_equality_values_(self):
                # Present so that only the *_cls_ hook is missing.
                return None
254