# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the ``cirq.value_equality`` class decorator."""

import pytest

import cirq


@cirq.value_equality
class BasicC:
    """Default value-equality fixture whose subclasses share its identity."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


@cirq.value_equality
class BasicD:
    """Independent default fixture, unrelated to BasicC."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


@cirq.value_equality(manual_cls=True)
class MasqueradePositiveD:
    """Fixture that chooses its own comparison class via manual_cls."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x

    def _value_equality_values_cls_(self):
        # Positive values report BasicD as their class so they can equal
        # BasicD instances; other values keep this class as their identity.
        if self.x > 0:
            return BasicD
        return MasqueradePositiveD


class BasicCa(BasicC):
    """First subclass of BasicC; expected to compare like BasicC."""


class BasicCb(BasicC):
    """Second subclass of BasicC; expected to compare like BasicC."""


def test_value_equality_basic():
    # Dictionary keys built from BasicC and its subclasses are interchangeable.
    table = {BasicC(1): 4, BasicCa(2): 5}
    assert table[BasicCa(1)] == table[BasicC(1)]
    assert table[BasicC(1)] == 4
    assert table[BasicCa(2)] == 5

    # Subclasses join their parent's equality group; other types/values do not.
    tester = cirq.testing.EqualsTester()
    for group in (
        (BasicC(1), BasicC(1), BasicCa(1), BasicCb(1)),
        (BasicD(1),),
        (BasicC(2),),
        (BasicCa(3),),
    ):
        tester.add_equality_group(*group)


def test_value_equality_manual():
    tester = cirq.testing.EqualsTester()
    # Positive masqueraders are equal to the BasicD with the same value...
    tester.add_equality_group(MasqueradePositiveD(3), BasicD(3))
    tester.add_equality_group(MasqueradePositiveD(4), MasqueradePositiveD(4), BasicD(4))
    # ...while non-positive ones stay separate from BasicD.
    tester.add_equality_group(MasqueradePositiveD(-1), MasqueradePositiveD(-1))
    tester.add_equality_group(BasicD(-1))


@cirq.value_equality(unhashable=True)
class UnhashableC:
    """Value-equality fixture declared unhashable."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


@cirq.value_equality(unhashable=True)
class UnhashableD:
    """Second unhashable fixture, unrelated to UnhashableC."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


class UnhashableCa(UnhashableC):
    """First subclass of UnhashableC."""


class UnhashableCb(UnhashableC):
    """Second subclass of UnhashableC."""


def test_value_equality_unhashable():
    # Unhashable instances are rejected as dictionary keys.
    with pytest.raises(TypeError, match='unhashable'):
        _ = {UnhashableC(1): 4}

    # Equality still behaves like the hashable variant.
    tester = cirq.testing.EqualsTester()
    for group in (
        (UnhashableC(1), UnhashableC(1), UnhashableCa(1), UnhashableCb(1)),
        (UnhashableC(2),),
        (UnhashableD(1),),
    ):
        tester.add_equality_group(*group)


@cirq.value_equality(distinct_child_types=True)
class DistinctC:
    """Fixture whose subclasses each form their own equality class."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


@cirq.value_equality(distinct_child_types=True)
class DistinctD:
    """Second distinct-child-types fixture, unrelated to DistinctC."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


class DistinctCa(DistinctC):
    """First subclass of DistinctC."""


class DistinctCb(DistinctC):
    """Second subclass of DistinctC."""


def test_value_equality_distinct_child_types():
    # Each child type is its own dictionary key even when values coincide.
    table = {DistinctC(1): 4, DistinctCa(1): 5, DistinctCb(1): 6}
    for key, expected in ((DistinctC(1), 4), (DistinctCa(1), 5), (DistinctCb(1), 6)):
        assert table[key] == expected

    # Equality is likewise partitioned by exact type.
    tester = cirq.testing.EqualsTester()
    for group in (
        (DistinctC(1), DistinctC(1)),
        (DistinctCa(1), DistinctCa(1)),
        (DistinctCb(1), DistinctCb(1)),
        (DistinctC(2),),
        (DistinctD(1),),
    ):
        tester.add_equality_group(*group)


@cirq.value_equality(approximate=True)
class ApproxE:
    """Fixture that opts into approximate comparison."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


def test_value_equality_approximate():
    near = ApproxE(0.2)
    assert cirq.approx_eq(ApproxE(0.0), ApproxE(0.0), atol=0.1)
    # A gap of 0.2 is inside a 0.3 tolerance but outside a 0.1 tolerance.
    assert cirq.approx_eq(ApproxE(0.0), near, atol=0.3)
    assert not cirq.approx_eq(ApproxE(0.0), near, atol=0.1)


@cirq.value_equality(approximate=True)
class PeriodicF:
    """Approximate fixture with a custom approximate-value projection."""

    def __init__(self, x, n):
        self.x = x
        self.n = n

    def _value_equality_values_(self):
        return self.x

    def _value_equality_approximate_values_(self):
        # Approximate comparisons only see x reduced modulo n.
        return self.x % self.n


def test_value_equality_approximate_specialized():
    one, five, six = PeriodicF(1, 4), PeriodicF(5, 4), PeriodicF(6, 4)
    # Exact equality uses the raw x values...
    assert one != five
    # ...while approximate equality uses x % n.
    assert cirq.approx_eq(one, five, atol=0.1)
    assert not cirq.approx_eq(one, six, atol=0.1)


def test_value_equality_approximate_not_supported():
    # BasicC was not decorated with approximate=True.
    assert not cirq.approx_eq(BasicC(0.0), BasicC(0.1), atol=0.2)


class ApproxEa(ApproxE):
    """First subclass of ApproxE."""


class ApproxEb(ApproxE):
    """Second subclass of ApproxE."""


@cirq.value_equality(distinct_child_types=True, approximate=True)
class ApproxG:
    """Approximate fixture whose subclasses are distinct types."""

    def __init__(self, x):
        self.x = x

    def _value_equality_values_(self):
        return self.x


class ApproxGa(ApproxG):
    """First subclass of ApproxG."""


class ApproxGb(ApproxG):
    """Second subclass of ApproxG."""


def test_value_equality_approximate_typing():
    # Different decorated classes are not approximately equal here.
    assert not cirq.approx_eq(ApproxE(0.0), PeriodicF(0.0, 1.0), atol=0.1)
    # Without distinct_child_types, sibling subclasses compare approximately.
    assert cirq.approx_eq(ApproxEa(0.0), ApproxEb(0.0), atol=0.1)
    # With distinct_child_types, only matching exact types compare approximately.
    assert cirq.approx_eq(ApproxG(0.0), ApproxG(0.0), atol=0.1)
    assert not cirq.approx_eq(ApproxGa(0.0), ApproxGb(0.0), atol=0.1)
    assert not cirq.approx_eq(ApproxG(0.0), ApproxGb(0.0), atol=0.1)


def test_value_equality_forgot_method():
    with pytest.raises(TypeError, match='_value_equality_values_'):

        @cirq.value_equality
        class _:
            """Deliberately omits _value_equality_values_."""


def test_bad_manual_cls_incompatible_args():
    with pytest.raises(ValueError, match='incompatible'):

        @cirq.value_equality(manual_cls=True, distinct_child_types=True)
        class _:
            """Combines two options that the decorator rejects together."""


def test_bad_manual_cls_forgot_method():
    with pytest.raises(TypeError, match='_value_equality_values_cls_'):

        @cirq.value_equality(manual_cls=True)
        class _:
            """Requests manual_cls but omits _value_equality_values_cls_."""

            def _value_equality_values_(self):
                return None