# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for cirq's operator-space helpers.

Covers ``cirq.kron_bases``, ``cirq.hilbert_schmidt_inner_product``,
``cirq.expand_matrix_in_orthogonal_basis``, ``cirq.matrix_from_basis_coefficients``
and ``cirq.pow_pauli_combination``.
"""

import itertools

import numpy as np
import pytest
import scipy.linalg

import cirq

# Single-qubit matrices used as test inputs.
I = np.eye(2)
X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Z = np.array([[1, 0], [0, -1]])
H = np.array([[1, 1], [1, -1]]) * np.sqrt(0.5)
SQRT_X = np.array([[np.sqrt(1j), np.sqrt(-1j)], [np.sqrt(-1j), np.sqrt(1j)]]) * np.sqrt(0.5)
SQRT_Y = np.array([[np.sqrt(1j), -np.sqrt(1j)], [np.sqrt(1j), np.sqrt(1j)]]) * np.sqrt(0.5)
SQRT_Z = np.diag([1, 1j])
# Elementary 2x2 matrices: Eij has a single 1 at row i, column j.
E00 = np.diag([1, 0])
E01 = np.array([[0, 1], [0, 0]])
E10 = np.array([[0, 0], [1, 0]])
E11 = np.diag([0, 1])
PAULI_BASIS = cirq.PAULI_BASIS
# Matrix-unit basis keyed by single letters, so two-fold products get two-letter names.
STANDARD_BASIS = {'a': E00, 'b': E01, 'c': E10, 'd': E11}


def _one_hot_matrix(size: int, i: int, j: int) -> np.ndarray:
    """Return a ``size`` x ``size`` float matrix of zeros with a single 1 at ``(i, j)``."""
    result = np.zeros((size, size))
    result[i, j] = 1
    return result


@pytest.mark.parametrize(
    'basis1, basis2, expected_kron_basis',
    (
        (
            PAULI_BASIS,
            PAULI_BASIS,
            {
                'II': np.eye(4),
                'IX': scipy.linalg.block_diag(X, X),
                'IY': scipy.linalg.block_diag(Y, Y),
                'IZ': np.diag([1, -1, 1, -1]),
                'XI': np.array([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
                'XX': np.rot90(np.eye(4)),
                'XY': np.rot90(np.diag([1j, -1j, 1j, -1j])),
                'XZ': np.array([[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, -1, 0, 0]]),
                'YI': np.array([[0, 0, -1j, 0], [0, 0, 0, -1j], [1j, 0, 0, 0], [0, 1j, 0, 0]]),
                'YX': np.rot90(np.diag([1j, 1j, -1j, -1j])),
                'YY': np.rot90(np.diag([-1, 1, 1, -1])),
                'YZ': np.array([[0, 0, -1j, 0], [0, 0, 0, 1j], [1j, 0, 0, 0], [0, -1j, 0, 0]]),
                'ZI': np.diag([1, 1, -1, -1]),
                'ZX': scipy.linalg.block_diag(X, -X),
                'ZY': scipy.linalg.block_diag(Y, -Y),
                'ZZ': np.diag([1, -1, -1, 1]),
            },
        ),
        (
            STANDARD_BASIS,
            STANDARD_BASIS,
            {
                'abcd'[2 * row_outer + col_outer]
                + 'abcd'[2 * row_inner + col_inner]: _one_hot_matrix(
                    4, 2 * row_outer + row_inner, 2 * col_outer + col_inner
                )
                for row_outer in range(2)
                for row_inner in range(2)
                for col_outer in range(2)
                for col_inner in range(2)
            },
        ),
    ),
)
def test_kron_bases(basis1, basis2, expected_kron_basis):
    """kron_bases of two 4-element bases yields the 16 expected Kronecker products."""
    kron_basis = cirq.kron_bases(basis1, basis2)
    assert len(kron_basis) == 16
    assert set(kron_basis.keys()) == set(expected_kron_basis.keys())
    for name, matrix in kron_basis.items():
        # array_equal also rejects shape mismatches that `==` would silently broadcast.
        assert np.array_equal(matrix, expected_kron_basis[name])


@pytest.mark.parametrize(
    'basis1,basis2',
    (
        (PAULI_BASIS, cirq.kron_bases(PAULI_BASIS)),
        (STANDARD_BASIS, cirq.kron_bases(STANDARD_BASIS, repeat=1)),
        (cirq.kron_bases(PAULI_BASIS, PAULI_BASIS), cirq.kron_bases(PAULI_BASIS, repeat=2)),
        (
            cirq.kron_bases(
                cirq.kron_bases(PAULI_BASIS, repeat=2),
                cirq.kron_bases(PAULI_BASIS, repeat=3),
                PAULI_BASIS,
            ),
            cirq.kron_bases(PAULI_BASIS, repeat=6),
        ),
        (
            cirq.kron_bases(
                cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS),
                cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS),
            ),
            cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS, repeat=2),
        ),
    ),
)
def test_kron_bases_consistency(basis1, basis2):
    """Nested kron_bases calls and the equivalent `repeat=` form produce identical bases."""
    assert set(basis1.keys()) == set(basis2.keys())
    for name, matrix in basis1.items():
        assert np.array_equal(matrix, basis2[name])


@pytest.mark.parametrize(
    'basis,repeat', itertools.product((PAULI_BASIS, STANDARD_BASIS), range(1, 5))
)
def test_kron_bases_repeat_sanity_checks(basis, repeat):
    """A repeated product basis has 4**repeat elements that are mutually orthogonal."""
    product_basis = cirq.kron_bases(basis, repeat=repeat)
    assert len(product_basis) == 4**repeat
    for name1, matrix1 in product_basis.items():
        for name2, matrix2 in product_basis.items():
            p = cirq.hilbert_schmidt_inner_product(matrix1, matrix2)
            if name1 != name2:
                assert p == 0
            else:
                assert abs(p) >= 1


@pytest.mark.parametrize(
    'm1,m2,expect_real',
    (
        (X, X, True),
        (X, Y, True),
        (X, H, True),
        (X, SQRT_X, False),
        (I, SQRT_Z, False),
    ),
)
def test_hilbert_schmidt_inner_product_is_conjugate_symmetric(m1, m2, expect_real):
    """Swapping the arguments conjugates the inner product."""
    v1 = cirq.hilbert_schmidt_inner_product(m1, m2)
    v2 = cirq.hilbert_schmidt_inner_product(m2, m1)
    assert np.isclose(v1, np.conj(v2))

    assert np.isreal(v1) == expect_real
    if not expect_real:
        assert v1 != v2


@pytest.mark.parametrize(
    'a,m1,b,m2',
    (
        (1, X, 1, Z),
        (2, X, 3, Y),
        (2j, X, 3, I),
        (2, X, 3, X),
    ),
)
def test_hilbert_schmidt_inner_product_is_linear(a, m1, b, m2):
    """The inner product is linear in its second argument."""
    v1 = cirq.hilbert_schmidt_inner_product(H, (a * m1 + b * m2))
    v2 = a * cirq.hilbert_schmidt_inner_product(H, m1) + b * cirq.hilbert_schmidt_inner_product(
        H, m2
    )
    # The two sides sum floating-point terms in different orders, so compare with a tolerance.
    assert np.isclose(v1, v2)


@pytest.mark.parametrize('m', (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z))
def test_hilbert_schmidt_inner_product_is_positive_definite(m):
    """<m, m> is real and strictly positive for each nonzero test matrix."""
    v = cirq.hilbert_schmidt_inner_product(m, m)
    assert np.isreal(v)
    assert v.real > 0


@pytest.mark.parametrize(
    'm1,m2,expected_value',
    (
        (X, I, 0),
        (X, X, 2),
        (X, Y, 0),
        (X, Z, 0),
        (H, X, np.sqrt(2)),
        (H, Y, 0),
        (H, Z, np.sqrt(2)),
        (Z, E00, 1),
        (Z, E01, 0),
        (Z, E10, 0),
        (Z, E11, -1),
        (SQRT_X, E00, np.sqrt(-0.5j)),
        (SQRT_X, E01, np.sqrt(0.5j)),
        (SQRT_X, E10, np.sqrt(0.5j)),
        (SQRT_X, E11, np.sqrt(-0.5j)),
    ),
)
def test_hilbert_schmidt_inner_product_values(m1, m2, expected_value):
    """Spot-check the inner product against hand-computed values."""
    v = cirq.hilbert_schmidt_inner_product(m1, m2)
    assert np.isclose(v, expected_value)


@pytest.mark.parametrize(
    'm,basis',
    itertools.product(
        (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z),
        (PAULI_BASIS, STANDARD_BASIS),
    ),
)
def test_expand_matrix_in_orthogonal_basis(m, basis):
    """Summing coefficient * basis element over the expansion reconstructs the matrix."""
    expansion = cirq.expand_matrix_in_orthogonal_basis(m, basis)

    reconstructed = np.zeros(m.shape, dtype=complex)
    for name, coefficient in expansion.items():
        reconstructed += coefficient * basis[name]
    assert np.allclose(m, reconstructed)


@pytest.mark.parametrize(
    'expansion',
    (
        {'I': 1},
        {'X': 1},
        {'Y': 1},
        {'Z': 1},
        {'X': 1, 'Z': 1},
        {'I': 0.5, 'X': 0.4, 'Y': 0.3, 'Z': 0.2},
        {'I': 1, 'X': 2, 'Y': 3, 'Z': 4},
    ),
)
def test_matrix_from_basis_coefficients(expansion):
    """Projecting the built matrix back onto each Pauli recovers its coefficient."""
    m = cirq.matrix_from_basis_coefficients(expansion, PAULI_BASIS)

    for name, coefficient in expansion.items():
        element = PAULI_BASIS[name]
        expected_coefficient = cirq.hilbert_schmidt_inner_product(
            m, element
        ) / cirq.hilbert_schmidt_inner_product(element, element)
        assert np.isclose(coefficient, expected_coefficient)


@pytest.mark.parametrize(
    'm1,basis',
    itertools.product(
        (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z, E00, E01, E10, E11),
        (PAULI_BASIS, STANDARD_BASIS),
    ),
)
def test_expand_is_inverse_of_reconstruct(m1, basis):
    """expand -> reconstruct -> expand round-trips both the matrix and its coefficients."""
    c1 = cirq.expand_matrix_in_orthogonal_basis(m1, basis)
    m2 = cirq.matrix_from_basis_coefficients(c1, basis)
    c2 = cirq.expand_matrix_in_orthogonal_basis(m2, basis)
    assert np.allclose(m1, m2)
    # Coefficients are floats that went through a round trip; compare them with a tolerance
    # rather than exact dict equality.
    assert c1.keys() == c2.keys()
    for name, coefficient in c1.items():
        assert np.isclose(coefficient, c2[name])


@pytest.mark.parametrize(
    'coefficients,exponent',
    itertools.product(
        (
            (0, 0, 0, 0),
            (-1, 0, 0, 0),
            (0.5, 0, 0, 0),
            (0.5j, 0, 0, 0),
            (1, 0, 0, 0),
            (2, 0, 0, 0),
            (0, -1, 0, 0),
            (0, 0.5, 0, 0),
            (0, 0.5j, 0, 0),
            (0, 1, 0, 0),
            (0, 2, 0, 0),
            (0, 0, -1, 0),
            (0, 0, 0.5, 0),
            (0, 0, 0.5j, 0),
            (0, 0, 1, 0),
            (0, 0, 2, 0),
            (0, 0, 0, -1),
            (0, 0, 0, 0.5),
            (0, 0, 0, 0.5j),
            (0, 0, 0, 1),
            (0, 0, 0, 2),
            (0, -1, 0, -1),
            (0, 1, 0, 1j),
            (0, 0.5, 0, 0.5),
            (0, 0.5j, 0, 0.5j),
            (0, 0.5, 0, 0.5j),
            (0, 1, 0, 1),
            (0, 2, 0, 2),
            (0, 0.5, 0.5, 0.5),
            (0, 1, 1, 1),
            (0, 1.1j, 0.5 - 0.4j, 0.9),
            (0.7j, 1.1j, 0.5 - 0.4j, 0.9),
            (0.25, 0.25, 0.25, 0.25),
            (0.25j, 0.25j, 0.25j, 0.25j),
            (0.4, 0, 0.5, 0),
            (1, 2, 3, 4),
            (-1, -2, -3, -4),
            (-1, -2, 3, 4),
            (1j, 2j, 3j, 4j),
            (1j, 2j, 3, 4),
        ),
        (0, 1, 2, 3, 4, 5, 100, 101),
    ),
)
def test_pow_pauli_combination(coefficients, exponent):
    """pow_pauli_combination agrees with numpy's matrix_power on the assembled matrix."""
    i = cirq.PAULI_BASIS['I']
    x = cirq.PAULI_BASIS['X']
    y = cirq.PAULI_BASIS['Y']
    z = cirq.PAULI_BASIS['Z']
    ai, ax, ay, az = coefficients

    matrix = ai * i + ax * x + ay * y + az * z
    expected_result = np.linalg.matrix_power(matrix, exponent)

    bi, bx, by, bz = cirq.pow_pauli_combination(ai, ax, ay, az, exponent)
    result = bi * i + bx * x + by * y + bz * z

    assert np.allclose(result, expected_result)