1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from fractions import Fraction
16from decimal import Decimal
17from numbers import Number
18import numpy as np
19import pytest
20import sympy
21import cirq
22
23
def test_approx_eq_primitives():
    """approx_eq on ints, floats, complex numbers and strings."""
    # Each row: (expected truthiness, lhs, rhs, atol).
    cases = (
        (False, 1, 2, 1e-01),
        (True, 1.0, 1.0 + 1e-10, 1e-09),
        (False, 1.0, 1.0 + 1e-10, 1e-11),
        (True, 0.0, 1e-10, 1e-09),
        (False, 0.0, 1e-10, 1e-11),
        (True, complex(1, 1), complex(1.1, 1.2), 0.3),
        (False, complex(1, 1), complex(1.1, 1.2), 0.1),
        # Strings compare exactly, however large the tolerance.
        (True, 'ab', 'ab', 1e-3),
        (False, 'ab', 'ac', 1e-3),
        (False, '1', '2', 999),
        # A string is not approximately equal to a number.
        (False, 'test', 1, 1e-3),
        (False, '1', 1, 1e-3),
    )
    for expected, lhs, rhs, tol in cases:
        assert bool(cirq.approx_eq(lhs, rhs, atol=tol)) is expected
37
38
def test_approx_eq_mixed_primitives():
    """approx_eq when complex, int and float operands are mixed."""
    # Each row: (expected truthiness, lhs, rhs, atol).
    cases = (
        (True, complex(1, 1e-10), 1, 1e-09),
        (False, complex(1, 1e-4), 1, 1e-09),
        (True, complex(1, 1e-10), 1.0, 1e-09),
        (False, complex(1, 1e-8), 1.0, 1e-09),
        (True, 1, 1.0 + 1e-10, 1e-9),
        (False, 1, 1.0 + 1e-10, 1e-11),
    )
    for expected, lhs, rhs, tol in cases:
        assert bool(cirq.approx_eq(lhs, rhs, atol=tol)) is expected
46
47
def test_numpy_dtype_compatibility():
    """approx_eq accepts scalars of NumPy's integer, float and complex dtypes."""
    i_a, i_b, i_c = 0, 1, 2
    # np.int0 / np.uint0 were aliases of np.intp / np.uintp; they were
    # deprecated in NumPy 1.24 and removed in NumPy 2.0, so only the
    # canonical names are listed here.
    i_types = [np.intc, np.intp, np.int8, np.int16, np.int32, np.int64]
    for i_type in i_types:
        assert cirq.approx_eq(i_type(i_a), i_type(i_b), atol=1)
        assert not cirq.approx_eq(i_type(i_a), i_type(i_c), atol=1)
    u_types = [np.uint, np.uintp, np.uint8, np.uint16, np.uint32, np.uint64]
    for u_type in u_types:
        assert cirq.approx_eq(u_type(i_a), u_type(i_b), atol=1)
        assert not cirq.approx_eq(u_type(i_a), u_type(i_c), atol=1)

    f_a, f_b, f_c = 0, 1e-8, 1
    f_types = [np.float16, np.float32, np.float64]
    # Extended-precision types only exist on some platforms.
    if hasattr(np, 'float128'):
        f_types.append(np.float128)
    for f_type in f_types:
        assert cirq.approx_eq(f_type(f_a), f_type(f_b), atol=1e-8)
        assert not cirq.approx_eq(f_type(f_a), f_type(f_c), atol=1e-8)

    c_a, c_b, c_c = 0, 1e-8j, 1j
    c_types = [np.complex64, np.complex128]
    if hasattr(np, 'complex256'):
        c_types.append(np.complex256)
    for c_type in c_types:
        assert cirq.approx_eq(c_type(c_a), c_type(c_b), atol=1e-8)
        assert not cirq.approx_eq(c_type(c_a), c_type(c_c), atol=1e-8)
74
75
def test_fractions_compatibility():
    """Fraction operands are measured against atol like other numbers."""
    zero = Fraction(0)
    tiny = Fraction(1, int(1e10))  # 1e-10, inside the 1e-9 tolerance
    small = Fraction(1, int(1e7))  # 1e-7, outside the 1e-9 tolerance
    assert cirq.approx_eq(zero, tiny, atol=1e-9)
    assert not cirq.approx_eq(zero, small, atol=1e-9)
79
80
def test_decimal_compatibility():
    """Decimal operands, including special values, work with approx_eq."""
    zero = Decimal('0')
    assert cirq.approx_eq(zero, Decimal('0.0000000001'), atol=1e-9)
    assert not cirq.approx_eq(zero, Decimal('0.00000001'), atol=1e-9)
    # NaN is not close to -Infinity.
    not_a_number, minus_inf = Decimal('NaN'), Decimal('-Infinity')
    assert not cirq.approx_eq(not_a_number, minus_inf, atol=1e-9)
85
86
def test_approx_eq_mixed_types():
    """approx_eq across NumPy scalars, builtins, Fraction and Decimal."""
    # Each row: (lhs, rhs, atol) that must compare approximately equal.
    close_pairs = [
        (np.float32(1), 1.0 + 1e-10, 1e-9),
        (np.float64(1), np.complex64(1 + 1e-8j), 1e-4),
        (np.uint8(1), np.complex64(1 + 1e-8j), 1e-4),
    ]
    # complex256 is only available on some platforms.
    if hasattr(np, 'complex256'):
        close_pairs.append((np.complex256(1), complex(1, 1e-8), 1e-4))
    close_pairs += [
        (np.int32(1), 1, 1e-9),
        (complex(0.5, 0), Fraction(1, 2), 0.0),
        (0.5 + 1e-4j, Fraction(1, 2), 1e-4),
        (0, Fraction(1, 100000000), 1e-8),
        (np.uint16(1), Decimal('1'), 0.0),
        (np.float64(1.0), Decimal('1.00000001'), 1e-8),
    ]
    for lhs, rhs, tol in close_pairs:
        assert cirq.approx_eq(lhs, rhs, atol=tol)

    assert not cirq.approx_eq(np.complex64(1e-5j), Decimal('0.001'), atol=1e-4)
100
101
def test_approx_eq_special_numerics():
    """NaN and infinities compared with zero tolerance."""
    # Each row: (expected truthiness, lhs, rhs). Every value is a freshly
    # constructed float so no two operands share an object.
    cases = (
        (False, float('nan'), 0),
        (False, float('nan'), float('nan')),
        (False, float('inf'), float('-inf')),
        (False, float('inf'), 5),
        (False, float('inf'), 0),
        (True, float('inf'), float('inf')),
    )
    for expected, lhs, rhs in cases:
        assert bool(cirq.approx_eq(lhs, rhs, atol=0.0)) is expected
109
110
class X(Number):
    """Subtype of Number that can fall back to a value-based __eq__."""

    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        # Compare payloads only when self is an instance of other's type;
        # otherwise defer to Python's reflected-equality machinery.
        if isinstance(self, type(other)):
            return self.val == other.val
        return NotImplemented
121
122
class Y(Number):
    """Subtype of Number with no __eq__ of its own (identity equality only)."""

    def __init__(self):
        """Take no arguments and store no state."""
128
129
def test_approx_eq_number_uses__eq__():
    """Number subtypes without numeric behaviour are judged by their __eq__."""
    # X defines a value-based __eq__: equal payloads must compare equal.
    # (This previously asserted on C, a non-Number class already covered by
    # test_approx_eq_uses__eq__, leaving X's positive path untested.)
    assert cirq.approx_eq(X(0), X(0), atol=0.0)
    assert not cirq.approx_eq(X(0), X(1), atol=0.0)
    # X.__eq__ returns NotImplemented for non-X operands.
    assert not cirq.approx_eq(X(0), 0, atol=0.0)
    # Y has no __eq__ of its own.
    assert not cirq.approx_eq(Y(), 1, atol=0.0)
135
136
def test_approx_eq_tuple():
    """Tuples: lengths must match and each item must be within atol."""
    pair = (1, 1)
    assert cirq.approx_eq(pair, (1, 1), atol=0.0)
    assert not cirq.approx_eq(pair, (1, 1, 1), atol=0.0)
    assert not cirq.approx_eq(pair, (1,), atol=0.0)

    drifted, ones = (1.1, 1.2, 1.3), (1, 1, 1)
    # The largest item deviation is ~0.3: inside 0.4, outside 0.2.
    assert cirq.approx_eq(drifted, ones, atol=0.4)
    assert not cirq.approx_eq(drifted, ones, atol=0.2)
143
144
def test_approx_eq_list():
    """Lists: lengths and nesting must match and items must be within atol."""
    assert cirq.approx_eq([], [], atol=0.0)
    assert not cirq.approx_eq([], [[]], atol=0.0)

    for other, expected in (([1, 1], True), ([1, 1, 1], False), ([1], False)):
        assert bool(cirq.approx_eq([1, 1], other, atol=0.0)) is expected

    for tol, expected in ((0.4, True), (0.2, False)):
        assert bool(cirq.approx_eq([1.1, 1.2, 1.3], [1, 1, 1], atol=tol)) is expected
159
160
def test_approx_eq_symbol():
    """approx_eq on sympy expressions and on circuits with symbolic angles."""
    q = cirq.GridQubit(0, 0)
    s, t = sympy.Symbol("s"), sympy.Symbol("t")
    undecidable = (
        "Insufficient information to decide whether expressions are "
        "approximately equal .* vs .*"
    )

    # Same free symbols; the constant terms differ by 0.01.
    lhs_expr, rhs_expr = t + 1.51 + s, t + 1.50 + s
    assert not cirq.approx_eq(lhs_expr, rhs_expr, atol=0.005)
    assert cirq.approx_eq(lhs_expr, rhs_expr, atol=0.020)

    # A bare symbol against a number cannot be decided.
    with pytest.raises(AttributeError, match=undecidable):
        cirq.approx_eq(t, 0.0, atol=0.005)

    # Circuits whose rotation angles share the free symbol s.
    rz_s_wide = cirq.Circuit(cirq.rz(1.515 + s)(q))
    rz_s_narrow = cirq.Circuit(cirq.rz(1.510 + s)(q))
    assert cirq.approx_eq(rz_s_wide, rz_s_narrow, atol=0.2)

    # Different free symbols cannot be decided either.
    rz_t = cirq.Circuit(cirq.rz(1.510 + t)(q))
    with pytest.raises(AttributeError, match=undecidable):
        cirq.approx_eq(rz_s_wide, rz_t, atol=0.2)
187
188
def test_approx_eq_default():
    """With the default atol, a 1e-9 offset is tolerated but 1e-7 is not."""
    for delta in (1e-9, -1e-9):
        assert cirq.approx_eq(1.0, 1.0 + delta)
    for delta in (1e-7, -1e-7):
        assert not cirq.approx_eq(1.0, 1.0 + delta)
194
195
def test_approx_eq_iterables():
    """Tuples, lists and generators with the same items compare equal."""

    def ones():
        # Fresh generator yielding 1, 1 on every call.
        yield from (1, 1)

    assert cirq.approx_eq((1, 1), [1, 1], atol=0.0)
    assert cirq.approx_eq((1, 1), ones(), atol=0.0)
    assert cirq.approx_eq(ones(), [1, 1], atol=0.0)
204
205
206class A:
207    def __init__(self, val):
208        self.val = val
209
210    def _approx_eq_(self, other, atol):
211        if not isinstance(self, type(other)):
212            return NotImplemented
213        return cirq.approx_eq(self.val, other.val, atol=atol)
214
215
class B:
    """Wrapper whose _approx_eq_ compares its payload against a bare value."""

    def __init__(self, val):
        self.val = val

    def _approx_eq_(self, other, atol):
        # Only handle bare values whose type the payload is an instance of.
        if isinstance(self.val, type(other)):
            return cirq.approx_eq(self.val, other, atol=atol)
        return NotImplemented
224
225
def test_approx_eq_supported():
    """Objects implementing _approx_eq_ are honoured on either side."""
    for tol, expected in ((0.1, True), (0.0, False)):
        assert bool(cirq.approx_eq(A(0.0), A(0.1), atol=tol)) is expected
    # B handles bare values, and must work as either operand.
    assert cirq.approx_eq(B(0.0), 0.1, atol=0.1)
    assert cirq.approx_eq(0.1, B(0.0), atol=0.1)
231
232
class C:
    """Plain (non-Number) class with value-based equality and no _approx_eq_."""

    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        # Defer to Python's reflected comparison for unrelated operand types.
        if isinstance(self, type(other)):
            return self.val == other.val
        return NotImplemented
241
242
def test_approx_eq_uses__eq__():
    """Objects without _approx_eq_ are compared through their __eq__."""
    assert cirq.approx_eq(C(0), C(0), atol=0.0)
    assert not cirq.approx_eq(C(1), C(2), atol=0.0)
    # The same holds for such objects nested inside containers.
    assert cirq.approx_eq([C(0)], [C(0)], atol=0.0)
    assert not cirq.approx_eq([C(1)], [C(2)], atol=0.0)
    # complex zero and int zero match in either operand order.
    for lhs, rhs in ((complex(0, 0), 0), (0, complex(0, 0))):
        assert cirq.approx_eq(lhs, rhs, atol=0.0)
250
251
def test_approx_eq_types_mismatch():
    """Operands of unrelated types are not approximately equal."""
    # Each row: (lhs, rhs, atol).
    mismatched = (
        (0, A(0), 0.0),
        (A(0), 0, 0.0),
        (B(0), A(0), 0.0),
        (A(0), B(0), 0.0),
        (C(0), A(0), 0.0),
        (A(0), C(0), 0.0),
        # A scalar never matches a container, even with a loose tolerance.
        (0, [0], 1.0),
        ([0], 0, 0.0),
    )
    for lhs, rhs, tol in mismatched:
        assert not cirq.approx_eq(lhs, rhs, atol=tol)
261