# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ``cirq.equal_up_to_global_phase``."""

import numpy as np
import cirq


def test_equal_up_to_global_phase_primitives():
    """Compare plain Python / numpy scalars (int, float, complex)."""
    assert cirq.equal_up_to_global_phase(1.0 + 1j, 1.0 + 1j, atol=1e-9)
    assert not cirq.equal_up_to_global_phase(2.0, 1.0 + 1j, atol=1e-9)
    assert cirq.equal_up_to_global_phase(1.0 + 1j, 1.0 - 1j, atol=1e-9)
    assert cirq.equal_up_to_global_phase(np.exp(1j * 3.3), 1.0 + 0.0j, atol=1e-9)
    assert cirq.equal_up_to_global_phase(np.exp(1j * 3.3), 1.0j, atol=1e-9)
    assert cirq.equal_up_to_global_phase(np.exp(1j * 3.3), 1, atol=1e-9)
    assert not cirq.equal_up_to_global_phase(np.exp(1j * 3.3), 0, atol=1e-9)
    assert cirq.equal_up_to_global_phase(1j, 1 + 1e-10, atol=1e-9)
    assert not cirq.equal_up_to_global_phase(1j, 1 + 1e-10, atol=1e-11)
    # atol is applied to magnitude of complex vector, not components.
    assert cirq.equal_up_to_global_phase(1.0 + 0.1j, 1.0, atol=0.01)
    assert not cirq.equal_up_to_global_phase(1.0 + 0.1j, 1.0, atol=0.001)
    assert cirq.equal_up_to_global_phase(1.0 + 1j, np.sqrt(2) + 1e-8, atol=1e-7)
    assert not cirq.equal_up_to_global_phase(1.0 + 1j, np.sqrt(2) + 1e-7, atol=1e-8)
    assert cirq.equal_up_to_global_phase(1.0 + 1e-10j, 1.0, atol=1e-15)


def test_equal_up_to_global_numeric_iterables():
    """Compare lists and tuples of numbers, including empty and mismatched ones."""
    assert cirq.equal_up_to_global_phase([], [], atol=1e-9)
    assert cirq.equal_up_to_global_phase([[]], [[]], atol=1e-9)
    assert cirq.equal_up_to_global_phase([1j, 1], [1j, 1], atol=1e-9)
    assert cirq.equal_up_to_global_phase([1j, 1j], [1 + 0.1j, 1 + 0.1j], atol=0.01)
    assert not cirq.equal_up_to_global_phase([1j, 1j], [1 + 0.1j, 1 - 0.1j], atol=0.01)
    assert not cirq.equal_up_to_global_phase([1j, 1j], [1 + 0.1j, 1 + 0.1j], atol=1e-3)
    assert not cirq.equal_up_to_global_phase([1j, -1j], [1, 1], atol=0.0)
    assert not cirq.equal_up_to_global_phase([1j, 1], [1, 1j], atol=0.0)
    # Sequences of different lengths never compare equal.
    assert not cirq.equal_up_to_global_phase([1j, 1], [1j, 1, 0], atol=0.0)
    assert cirq.equal_up_to_global_phase((1j, 1j), (1, 1 + 1e-4), atol=1e-3)
    assert not cirq.equal_up_to_global_phase((1j, 1j), (1, 1 + 1e-4), atol=1e-5)
    assert not cirq.equal_up_to_global_phase((1j, 1), (1, 1j), atol=1e-09)


def test_equal_up_to_global_numpy_array():
    """Compare numpy arrays, relying on the default ``atol``."""
    assert cirq.equal_up_to_global_phase(
        np.asarray([1j, 1j]), np.asarray([1, 1], dtype=np.complex64)
    )
    assert not cirq.equal_up_to_global_phase(
        np.asarray([1j, -1j]), np.asarray([1, 1], dtype=np.complex64)
    )
    assert cirq.equal_up_to_global_phase(np.asarray([]), np.asarray([]))
    assert cirq.equal_up_to_global_phase(np.asarray([[]]), np.asarray([[]]))


def test_equal_up_to_global_mixed_array_types():
    """Compare across container kinds (list, tuple, ndarray) and bad inputs."""
    a = [1j, 1, -1j, -1]
    b = [-1, 1j, 1, -1j]
    c = [-1, 1, -1, 1]
    assert cirq.equal_up_to_global_phase(a, tuple(b))
    assert not cirq.equal_up_to_global_phase(a, tuple(c))

    c_types = [np.complex64, np.complex128]
    # np.complex256 is not defined on every platform, so only use it if present.
    if hasattr(np, 'complex256'):
        c_types.append(np.complex256)
    for c_type in c_types:
        assert cirq.equal_up_to_global_phase(np.asarray(a, dtype=c_type), tuple(b))
        assert not cirq.equal_up_to_global_phase(np.asarray(a, dtype=c_type), tuple(c))
        assert cirq.equal_up_to_global_phase(np.asarray(a, dtype=c_type), b)
        assert not cirq.equal_up_to_global_phase(np.asarray(a, dtype=c_type), c)

    # Object arrays and mixed array/scalar comparisons.
    assert not cirq.equal_up_to_global_phase([1j], 1j)
    assert not cirq.equal_up_to_global_phase(np.asarray([1], dtype=np.complex128), np.exp(1j))
    assert not cirq.equal_up_to_global_phase([1j, 1j], [1j, "1j"])
    assert not cirq.equal_up_to_global_phase([1j], "Non-numeric iterable")
    assert not cirq.equal_up_to_global_phase([], [[]], atol=0.0)


# Dummy container class implementing _equal_up_to_global_phase_
# for homogeneous comparison, with nontrivial getter.
class A:
    """Wraps a value in a one-element list; only comparable to another ``A``."""

    def __init__(self, val):
        self.val = [val]

    def _equal_up_to_global_phase_(self, other, atol):
        """Compare wrapped values; return NotImplemented unless ``other`` is an ``A``."""
        if not isinstance(other, A):
            return NotImplemented
        return cirq.equal_up_to_global_phase(self.val[0], other.val[0], atol=atol)


# Dummy container class implementing _equal_up_to_global_phase_
# for heterogeneous comparison.
class B:
    """Wraps a value in a one-element list; comparable to raw values of its type."""

    def __init__(self, val):
        self.val = [val]

    def _equal_up_to_global_phase_(self, other, atol):
        """Compare the wrapped value to ``other`` directly.

        Returns NotImplemented when the wrapped value is not an instance of
        ``type(other)``.
        """
        if not isinstance(self.val[0], type(other)):
            return NotImplemented
        return cirq.equal_up_to_global_phase(self.val[0], other, atol=atol)


def test_equal_up_to_global_phase_eq_supported():
    """Compare objects that implement ``_equal_up_to_global_phase_``."""
    assert cirq.equal_up_to_global_phase(A(0.1 + 0j), A(0.1j), atol=1e-2)
    assert not cirq.equal_up_to_global_phase(A(0.0 + 0j), A(0.1j), atol=0.0)
    assert not cirq.equal_up_to_global_phase(A(0.0 + 0j), 0.1j, atol=0.0)
    assert cirq.equal_up_to_global_phase(B(0.0j), 1e-8j, atol=1e-8)
    assert cirq.equal_up_to_global_phase(1e-8j, B(0.0j), atol=1e-8)
    assert not cirq.equal_up_to_global_phase(1e-8j, B(0.0 + 0j), atol=1e-10)
    # cast types
    assert cirq.equal_up_to_global_phase(A(0.1), A(0.1j), atol=1e-2)
    assert not cirq.equal_up_to_global_phase(1e-8j, B(0.0), atol=1e-10)