# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `cirq.approx_eq`."""

from fractions import Fraction
from decimal import Decimal
from numbers import Number
import numpy as np
import pytest
import sympy
import cirq


def test_approx_eq_primitives():
    """Built-in ints, floats, complexes and strings compare as expected."""
    assert not cirq.approx_eq(1, 2, atol=1e-01)
    assert cirq.approx_eq(1.0, 1.0 + 1e-10, atol=1e-09)
    assert not cirq.approx_eq(1.0, 1.0 + 1e-10, atol=1e-11)
    assert cirq.approx_eq(0.0, 1e-10, atol=1e-09)
    assert not cirq.approx_eq(0.0, 1e-10, atol=1e-11)
    assert cirq.approx_eq(complex(1, 1), complex(1.1, 1.2), atol=0.3)
    assert not cirq.approx_eq(complex(1, 1), complex(1.1, 1.2), atol=0.1)
    # Strings are compared exactly; atol has no effect on them.
    assert cirq.approx_eq('ab', 'ab', atol=1e-3)
    assert not cirq.approx_eq('ab', 'ac', atol=1e-3)
    assert not cirq.approx_eq('1', '2', atol=999)
    assert not cirq.approx_eq('test', 1, atol=1e-3)
    assert not cirq.approx_eq('1', 1, atol=1e-3)


def test_approx_eq_mixed_primitives():
    """Mixed built-in numeric types (complex/int/float) are comparable."""
    assert cirq.approx_eq(complex(1, 1e-10), 1, atol=1e-09)
    assert not cirq.approx_eq(complex(1, 1e-4), 1, atol=1e-09)
    assert cirq.approx_eq(complex(1, 1e-10), 1.0, atol=1e-09)
    assert not cirq.approx_eq(complex(1, 1e-8), 1.0, atol=1e-09)
    assert cirq.approx_eq(1, 1.0 + 1e-10, atol=1e-9)
    assert not cirq.approx_eq(1, 1.0 + 1e-10, atol=1e-11)


def test_numpy_dtype_compatibility():
    """approx_eq handles a range of numpy integer, float and complex scalars."""
    i_a, i_b, i_c = 0, 1, 2
    # np.int0 / np.uint0 were deprecated aliases of np.intp / np.uintp
    # (NumPy 1.24) and were removed in NumPy 2.0, so only canonical names
    # are used here.
    i_types = [np.intc, np.intp, np.int8, np.int16, np.int32, np.int64]
    for i_type in i_types:
        assert cirq.approx_eq(i_type(i_a), i_type(i_b), atol=1)
        assert not cirq.approx_eq(i_type(i_a), i_type(i_c), atol=1)
    u_types = [np.uint, np.uintp, np.uint8, np.uint16, np.uint32, np.uint64]
    for u_type in u_types:
        assert cirq.approx_eq(u_type(i_a), u_type(i_b), atol=1)
        assert not cirq.approx_eq(u_type(i_a), u_type(i_c), atol=1)

    f_a, f_b, f_c = 0, 1e-8, 1
    f_types = [np.float16, np.float32, np.float64]
    # Extended-precision types are platform dependent.
    if hasattr(np, 'float128'):
        f_types.append(np.float128)
    for f_type in f_types:
        assert cirq.approx_eq(f_type(f_a), f_type(f_b), atol=1e-8)
        assert not cirq.approx_eq(f_type(f_a), f_type(f_c), atol=1e-8)

    c_a, c_b, c_c = 0, 1e-8j, 1j
    c_types = [np.complex64, np.complex128]
    if hasattr(np, 'complex256'):
        c_types.append(np.complex256)
    for c_type in c_types:
        assert cirq.approx_eq(c_type(c_a), c_type(c_b), atol=1e-8)
        assert not cirq.approx_eq(c_type(c_a), c_type(c_c), atol=1e-8)


def test_fractions_compatibility():
    """fractions.Fraction values are compared numerically."""
    assert cirq.approx_eq(Fraction(0), Fraction(1, int(1e10)), atol=1e-9)
    assert not cirq.approx_eq(Fraction(0), Fraction(1, int(1e7)), atol=1e-9)


def test_decimal_compatibility():
    """decimal.Decimal values, including NaN/Infinity, are supported."""
    assert cirq.approx_eq(Decimal('0'), Decimal('0.0000000001'), atol=1e-9)
    assert not cirq.approx_eq(Decimal('0'), Decimal('0.00000001'), atol=1e-9)
    assert not cirq.approx_eq(Decimal('NaN'), Decimal('-Infinity'), atol=1e-9)


def test_approx_eq_mixed_types():
    """Numpy scalars, built-ins, Fraction and Decimal can be mixed freely."""
    assert cirq.approx_eq(np.float32(1), 1.0 + 1e-10, atol=1e-9)
    assert cirq.approx_eq(np.float64(1), np.complex64(1 + 1e-8j), atol=1e-4)
    assert cirq.approx_eq(np.uint8(1), np.complex64(1 + 1e-8j), atol=1e-4)
    if hasattr(np, 'complex256'):
        assert cirq.approx_eq(np.complex256(1), complex(1, 1e-8), atol=1e-4)
    assert cirq.approx_eq(np.int32(1), 1, atol=1e-9)
    assert cirq.approx_eq(complex(0.5, 0), Fraction(1, 2), atol=0.0)
    assert cirq.approx_eq(0.5 + 1e-4j, Fraction(1, 2), atol=1e-4)
    assert cirq.approx_eq(0, Fraction(1, 100000000), atol=1e-8)
    assert cirq.approx_eq(np.uint16(1), Decimal('1'), atol=0.0)
    assert cirq.approx_eq(np.float64(1.0), Decimal('1.00000001'), atol=1e-8)
    assert not cirq.approx_eq(np.complex64(1e-5j), Decimal('0.001'), atol=1e-4)


def test_approx_eq_special_numerics():
    """NaN never matches anything; infinities only match themselves."""
    assert not cirq.approx_eq(float('nan'), 0, atol=0.0)
    assert not cirq.approx_eq(float('nan'), float('nan'), atol=0.0)
    assert not cirq.approx_eq(float('inf'), float('-inf'), atol=0.0)
    assert not cirq.approx_eq(float('inf'), 5, atol=0.0)
    assert not cirq.approx_eq(float('inf'), 0, atol=0.0)
    assert cirq.approx_eq(float('inf'), float('inf'), atol=0.0)


class X(Number):
    """Subtype of Number that can fallback to __eq__"""

    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return NotImplemented
        return self.val == other.val


class Y(Number):
    """Subtype of Number that cannot fallback to __eq__"""

    def __init__(self):
        pass


def test_approx_eq_number_uses__eq__():
    """Number subtypes without numeric operations fall back to __eq__."""
    # Positive case uses X (a Number subclass), so the Number code path is
    # exercised rather than the generic non-Number fallback.
    assert cirq.approx_eq(X(0), X(0), atol=0.0)
    assert not cirq.approx_eq(X(0), X(1), atol=0.0)
    assert not cirq.approx_eq(X(0), 0, atol=0.0)
    assert not cirq.approx_eq(Y(), 1, atol=0.0)


def test_approx_eq_tuple():
    """Tuples compare element-wise and must have equal lengths."""
    assert cirq.approx_eq((1, 1), (1, 1), atol=0.0)
    assert not cirq.approx_eq((1, 1), (1, 1, 1), atol=0.0)
    assert not cirq.approx_eq((1, 1), (1,), atol=0.0)
    assert cirq.approx_eq((1.1, 1.2, 1.3), (1, 1, 1), atol=0.4)
    assert not cirq.approx_eq((1.1, 1.2, 1.3), (1, 1, 1), atol=0.2)


def test_approx_eq_list():
    """Lists compare element-wise and must have equal lengths."""
    assert cirq.approx_eq([], [], atol=0.0)
    assert not cirq.approx_eq([], [[]], atol=0.0)
    assert cirq.approx_eq([1, 1], [1, 1], atol=0.0)
    assert not cirq.approx_eq([1, 1], [1, 1, 1], atol=0.0)
    assert not cirq.approx_eq([1, 1], [1], atol=0.0)
    assert cirq.approx_eq([1.1, 1.2, 1.3], [1, 1, 1], atol=0.4)
    assert not cirq.approx_eq([1.1, 1.2, 1.3], [1, 1, 1], atol=0.2)


def test_approx_eq_symbol():
    """Sympy expressions compare when their difference is a constant."""
    q = cirq.GridQubit(0, 0)
    s = sympy.Symbol("s")
    t = sympy.Symbol("t")

    assert not cirq.approx_eq(t + 1.51 + s, t + 1.50 + s, atol=0.005)
    assert cirq.approx_eq(t + 1.51 + s, t + 1.50 + s, atol=0.020)

    with pytest.raises(
        AttributeError,
        match="Insufficient information to decide whether expressions are "
        "approximately equal .* vs .*",
    ):
        cirq.approx_eq(t, 0.0, atol=0.005)

    symbol_1 = cirq.Circuit(cirq.rz(1.515 + s)(q))
    symbol_2 = cirq.Circuit(cirq.rz(1.510 + s)(q))
    assert cirq.approx_eq(symbol_1, symbol_2, atol=0.2)

    symbol_3 = cirq.Circuit(cirq.rz(1.510 + t)(q))
    with pytest.raises(
        AttributeError,
        match="Insufficient information to decide whether expressions are "
        "approximately equal .* vs .*",
    ):
        cirq.approx_eq(symbol_1, symbol_3, atol=0.2)


def test_approx_eq_default():
    """Without atol, differences of 1e-9 pass and 1e-7 fail."""
    assert cirq.approx_eq(1.0, 1.0 + 1e-9)
    assert cirq.approx_eq(1.0, 1.0 - 1e-9)
    assert not cirq.approx_eq(1.0, 1.0 + 1e-7)
    assert not cirq.approx_eq(1.0, 1.0 - 1e-7)


def test_approx_eq_iterables():
    """Different iterable kinds (tuple, list, generator) compare element-wise."""

    def gen_1_1():
        yield 1
        yield 1

    assert cirq.approx_eq((1, 1), [1, 1], atol=0.0)
    assert cirq.approx_eq((1, 1), gen_1_1(), atol=0.0)
    assert cirq.approx_eq(gen_1_1(), [1, 1], atol=0.0)


class A:
    """Implements _approx_eq_ against other A instances by comparing .val."""

    def __init__(self, val):
        self.val = val

    def _approx_eq_(self, other, atol):
        if not isinstance(self, type(other)):
            return NotImplemented
        return cirq.approx_eq(self.val, other.val, atol=atol)


class B:
    """Implements _approx_eq_ against bare values whose type matches .val."""

    def __init__(self, val):
        self.val = val

    def _approx_eq_(self, other, atol):
        if not isinstance(self.val, type(other)):
            return NotImplemented
        return cirq.approx_eq(self.val, other, atol=atol)


def test_approx_eq_supported():
    """A _approx_eq_ protocol method is honoured on either argument."""
    assert cirq.approx_eq(A(0.0), A(0.1), atol=0.1)
    assert not cirq.approx_eq(A(0.0), A(0.1), atol=0.0)
    assert cirq.approx_eq(B(0.0), 0.1, atol=0.1)
    assert cirq.approx_eq(0.1, B(0.0), atol=0.1)


class C:
    """Non-Number type that only defines exact equality via __eq__."""

    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        if not isinstance(self, type(other)):
            return NotImplemented
        return self.val == other.val


def test_approx_eq_uses__eq__():
    """Objects without _approx_eq_ fall back to __eq__, also inside lists."""
    assert cirq.approx_eq(C(0), C(0), atol=0.0)
    assert not cirq.approx_eq(C(1), C(2), atol=0.0)
    assert cirq.approx_eq([C(0)], [C(0)], atol=0.0)
    assert not cirq.approx_eq([C(1)], [C(2)], atol=0.0)
    assert cirq.approx_eq(complex(0, 0), 0, atol=0.0)
    assert cirq.approx_eq(0, complex(0, 0), atol=0.0)


def test_approx_eq_types_mismatch():
    """Values of incompatible types are never approximately equal."""
    assert not cirq.approx_eq(0, A(0), atol=0.0)
    assert not cirq.approx_eq(A(0), 0, atol=0.0)
    assert not cirq.approx_eq(B(0), A(0), atol=0.0)
    assert not cirq.approx_eq(A(0), B(0), atol=0.0)
    assert not cirq.approx_eq(C(0), A(0), atol=0.0)
    assert not cirq.approx_eq(A(0), C(0), atol=0.0)
    assert not cirq.approx_eq(0, [0], atol=1.0)
    assert not cirq.approx_eq([0], 0, atol=0.0)