1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import numpy as np
15import cirq
16
17
def test_equal_up_to_global_phase_primitives():
    """Exercise ``cirq.equal_up_to_global_phase`` on scalar int/float/complex inputs."""
    phase_factor = np.exp(1j * 3.3)
    root_two = np.sqrt(2)

    # Rows are (expect_equal, first, second, atol) and are checked in order.
    cases = (
        (True, 1.0 + 1j, 1.0 + 1j, 1e-9),
        (False, 2.0, 1.0 + 1j, 1e-9),
        (True, 1.0 + 1j, 1.0 - 1j, 1e-9),
        (True, phase_factor, 1.0 + 0.0j, 1e-9),
        (True, phase_factor, 1.0j, 1e-9),
        (True, phase_factor, 1, 1e-9),
        (False, phase_factor, 0, 1e-9),
        (True, 1j, 1 + 1e-10, 1e-9),
        (False, 1j, 1 + 1e-10, 1e-11),
        # atol is applied to magnitude of complex vector, not components.
        (True, 1.0 + 0.1j, 1.0, 0.01),
        (False, 1.0 + 0.1j, 1.0, 0.001),
        (True, 1.0 + 1j, root_two + 1e-8, 1e-7),
        (False, 1.0 + 1j, root_two + 1e-7, 1e-8),
        (True, 1.0 + 1e-10j, 1.0, 1e-15),
    )
    for expect_equal, lhs, rhs, tol in cases:
        # bool() keeps the original truthiness semantics of `assert` / `assert not`.
        assert bool(cirq.equal_up_to_global_phase(lhs, rhs, atol=tol)) is expect_equal
34
35
def test_equal_up_to_global_numeric_iterables():
    """Exercise ``cirq.equal_up_to_global_phase`` on lists and tuples of numbers."""
    # Rows are (expect_equal, first, second, atol) and are checked in order.
    cases = [
        (True, [], [], 1e-9),
        (True, [[]], [[]], 1e-9),
        (True, [1j, 1], [1j, 1], 1e-9),
        (True, [1j, 1j], [1 + 0.1j, 1 + 0.1j], 0.01),
        (False, [1j, 1j], [1 + 0.1j, 1 - 0.1j], 0.01),
        (False, [1j, 1j], [1 + 0.1j, 1 + 0.1j], 1e-3),
        (False, [1j, -1j], [1, 1], 0.0),
        (False, [1j, 1], [1, 1j], 0.0),
        (False, [1j, 1], [1j, 1, 0], 0.0),
        (True, (1j, 1j), (1, 1 + 1e-4), 1e-3),
        (False, (1j, 1j), (1, 1 + 1e-4), 1e-5),
        (False, (1j, 1), (1, 1j), 1e-09),
    ]
    for expect_equal, first, second, tol in cases:
        outcome = cirq.equal_up_to_global_phase(first, second, atol=tol)
        assert bool(outcome) is expect_equal
49
50
def test_equal_up_to_global_numpy_array():
    """Exercise ``cirq.equal_up_to_global_phase`` on numpy arrays with the default atol."""
    # A default-dtype complex array compared against a complex64 array of ones.
    for lhs_values, expect_equal in (([1j, 1j], True), ([1j, -1j], False)):
        lhs = np.asarray(lhs_values)
        rhs = np.asarray([1, 1], dtype=np.complex64)
        assert bool(cirq.equal_up_to_global_phase(lhs, rhs)) is expect_equal

    # Empty arrays, one- and two-dimensional; each operand is a freshly built array.
    for empty_rows in ([], [[]]):
        assert cirq.equal_up_to_global_phase(np.asarray(empty_rows), np.asarray(empty_rows))
60
61
def test_equal_up_to_global_mixed_array_types():
    """Exercise ``cirq.equal_up_to_global_phase`` across list, tuple and ndarray operands."""
    base = [1j, 1, -1j, -1]
    rotated = [-1, 1j, 1, -1j]  # `base` with every entry multiplied by 1j.
    unrelated = [-1, 1, -1, 1]
    assert cirq.equal_up_to_global_phase(base, tuple(rotated))
    assert not cirq.equal_up_to_global_phase(base, tuple(unrelated))

    # complex256 is only tested when this numpy build exposes it.
    complex_dtypes = [np.complex64, np.complex128]
    extended = getattr(np, 'complex256', None)
    if extended is not None:
        complex_dtypes.append(extended)

    for dtype in complex_dtypes:
        # First against tuples, then against the plain lists.
        for same_up_to_phase, different in (
            (tuple(rotated), tuple(unrelated)),
            (rotated, unrelated),
        ):
            assert cirq.equal_up_to_global_phase(np.asarray(base, dtype=dtype), same_up_to_phase)
            assert not cirq.equal_up_to_global_phase(np.asarray(base, dtype=dtype), different)

    # Object arrays and mixed array/scalar comparisons.
    never_equal = (
        ([1j], 1j),
        (np.asarray([1], dtype=np.complex128), np.exp(1j)),
        ([1j, 1j], [1j, "1j"]),
        ([1j], "Non-numeric iterable"),
    )
    for lhs, rhs in never_equal:
        assert not cirq.equal_up_to_global_phase(lhs, rhs)
    assert not cirq.equal_up_to_global_phase([], [[]], atol=0.0)
84
85
class A:
    """Dummy container implementing ``_equal_up_to_global_phase_`` for homogeneous comparison.

    The wrapped value is kept inside a one-element list, so reading it back
    goes through a nontrivial getter (``self.val[0]``).
    """

    def __init__(self, val):
        self.val = [val]

    def _equal_up_to_global_phase_(self, other, atol):
        """Compare wrapped values when ``other`` is an ``A``; otherwise decline."""
        if isinstance(other, A):
            mine, theirs = self.val[0], other.val[0]
            return cirq.equal_up_to_global_phase(mine, theirs, atol=atol)
        return NotImplemented
96
97
class B:
    """Dummy container implementing ``_equal_up_to_global_phase_`` for heterogeneous comparison.

    The wrapped value is compared directly against ``other`` rather than
    against another container.
    """

    def __init__(self, val):
        self.val = [val]

    def _equal_up_to_global_phase_(self, other, atol):
        """Compare the wrapped value with ``other`` if it is an instance of ``other``'s type."""
        wrapped = self.val[0]
        if isinstance(wrapped, type(other)):
            return cirq.equal_up_to_global_phase(wrapped, other, atol=atol)
        return NotImplemented
108
109
def test_equal_up_to_global_phase_eq_supported():
    """Exercise objects that provide their own ``_equal_up_to_global_phase_`` method."""
    # Rows are (expect_equal, first, second, atol) and are checked in order.
    cases = (
        (True, A(0.1 + 0j), A(0.1j), 1e-2),
        (False, A(0.0 + 0j), A(0.1j), 0.0),
        (False, A(0.0 + 0j), 0.1j, 0.0),
        (True, B(0.0j), 1e-8j, 1e-8),
        (True, 1e-8j, B(0.0j), 1e-8),
        (False, 1e-8j, B(0.0 + 0j), 1e-10),
        # cast types
        (True, A(0.1), A(0.1j), 1e-2),
        (False, 1e-8j, B(0.0), 1e-10),
    )
    for expect_equal, first, second, tol in cases:
        outcome = cirq.equal_up_to_global_phase(first, second, atol=tol)
        assert bool(outcome) is expect_equal
120