# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for cirq's matrix predicates (is_diagonal, is_hermitian, is_unitary, ...),
matrix_commutes, allclose_up_to_global_phase and slice_for_qubits_equal_to."""

import cmath
import numpy as np
import pytest

import cirq
from cirq.linalg import matrix_commutes


def test_is_diagonal():
    """Checks is_diagonal on empty, single-row/column and square matrices."""
    assert cirq.is_diagonal(np.empty((0, 0)))
    assert cirq.is_diagonal(np.empty((1, 0)))
    assert cirq.is_diagonal(np.empty((0, 1)))

    assert cirq.is_diagonal(np.array([[1]]))
    assert cirq.is_diagonal(np.array([[-1]]))
    assert cirq.is_diagonal(np.array([[5]]))
    assert cirq.is_diagonal(np.array([[3j]]))

    assert cirq.is_diagonal(np.array([[1, 0]]))
    assert cirq.is_diagonal(np.array([[1], [0]]))
    assert not cirq.is_diagonal(np.array([[1, 1]]))
    assert not cirq.is_diagonal(np.array([[1], [1]]))

    assert cirq.is_diagonal(np.array([[5j, 0], [0, 2]]))
    assert cirq.is_diagonal(np.array([[1, 0], [0, 1]]))
    assert not cirq.is_diagonal(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_diagonal(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_diagonal(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_diagonal(np.array([[1, 0.1], [0.1, 1]]))

    assert cirq.is_diagonal(np.array([[1, 1e-11], [1e-10, 1]]))


def test_is_diagonal_tolerance():
    """Off-diagonal entries up to atol are tolerated, each judged independently."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_diagonal(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert not cirq.is_diagonal(np.array([[1, 0], [-0.6, 1]]), atol=atol)

    # Error isn't accumulated across entries.
    assert cirq.is_diagonal(np.array([[1, 0.5], [-0.5, 1]]), atol=atol)
    assert not cirq.is_diagonal(np.array([[1, 0.5], [-0.6, 1]]), atol=atol)


def test_is_hermitian():
    """Checks is_hermitian; non-square inputs are never Hermitian."""
    assert cirq.is_hermitian(np.empty((0, 0)))
    assert not cirq.is_hermitian(np.empty((1, 0)))
    assert not cirq.is_hermitian(np.empty((0, 1)))

    assert cirq.is_hermitian(np.array([[1]]))
    assert cirq.is_hermitian(np.array([[-1]]))
    assert cirq.is_hermitian(np.array([[5]]))
    assert not cirq.is_hermitian(np.array([[3j]]))

    assert not cirq.is_hermitian(np.array([[0, 0]]))
    assert not cirq.is_hermitian(np.array([[0], [0]]))

    assert not cirq.is_hermitian(np.array([[5j, 0], [0, 2]]))
    assert cirq.is_hermitian(np.array([[5, 0], [0, 2]]))
    assert cirq.is_hermitian(np.array([[1, 0], [0, 1]]))
    assert not cirq.is_hermitian(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_hermitian(np.array([[1, 1], [0, 1]]))
    assert cirq.is_hermitian(np.array([[1, 1], [1, 1]]))
    assert cirq.is_hermitian(np.array([[1, 1j], [-1j, 1]]))
    assert cirq.is_hermitian(np.array([[1, 1j], [-1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_hermitian(np.array([[1, 1j], [1j, 1]]))
    assert not cirq.is_hermitian(np.array([[1, 0.1], [-0.1, 1]]))

    assert cirq.is_hermitian(np.array([[1, 1j + 1e-11], [-1j, 1 + 1j * 1e-9]]))


def test_is_hermitian_tolerance():
    """Deviations from Hermiticity are compared entrywise against atol."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_hermitian(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert cirq.is_hermitian(np.array([[1, 0.25], [-0.25, 1]]), atol=atol)
    assert not cirq.is_hermitian(np.array([[1, 0], [-0.6, 1]]), atol=atol)
    assert not cirq.is_hermitian(np.array([[1, 0.25], [-0.35, 1]]), atol=atol)

    # Error isn't accumulated across entries.
    assert cirq.is_hermitian(np.array([[1, 0.5, 0.5], [0, 1, 0], [0, 0, 1]]), atol=atol)
    assert not cirq.is_hermitian(np.array([[1, 0.5, 0.6], [0, 1, 0], [0, 0, 1]]), atol=atol)
    assert not cirq.is_hermitian(np.array([[1, 0, 0.6], [0, 1, 0], [0, 0, 1]]), atol=atol)


def test_is_unitary():
    """Checks is_unitary; non-square inputs are never unitary."""
    assert cirq.is_unitary(np.empty((0, 0)))
    assert not cirq.is_unitary(np.empty((1, 0)))
    assert not cirq.is_unitary(np.empty((0, 1)))

    assert cirq.is_unitary(np.array([[1]]))
    assert cirq.is_unitary(np.array([[-1]]))
    assert cirq.is_unitary(np.array([[1j]]))
    assert not cirq.is_unitary(np.array([[5]]))
    assert not cirq.is_unitary(np.array([[3j]]))

    assert not cirq.is_unitary(np.array([[1, 0]]))
    assert not cirq.is_unitary(np.array([[1], [0]]))

    assert not cirq.is_unitary(np.array([[1, 0], [0, -2]]))
    assert cirq.is_unitary(np.array([[1, 0], [0, -1]]))
    assert cirq.is_unitary(np.array([[1j, 0], [0, 1]]))
    assert not cirq.is_unitary(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_unitary(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_unitary(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_unitary(np.array([[1, -1], [1, 1]]))
    assert cirq.is_unitary(np.array([[1, -1], [1, 1]]) * np.sqrt(0.5))
    assert cirq.is_unitary(np.array([[1, 1j], [1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_unitary(np.array([[1, -1j], [1j, 1]]) * np.sqrt(0.5))

    assert cirq.is_unitary(np.array([[1, 1j + 1e-11], [1j, 1 + 1j * 1e-9]]) * np.sqrt(0.5))


def test_is_unitary_tolerance():
    """Deviations from unitarity are compared entrywise against atol."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_unitary(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert not cirq.is_unitary(np.array([[1, 0], [-0.6, 1]]), atol=atol)

    # Error isn't accumulated across entries.
    assert cirq.is_unitary(np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]]), atol=atol)
    assert not cirq.is_unitary(np.array([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1.2]]), atol=atol)


def test_is_orthogonal():
    """Checks is_orthogonal; complex entries such as 1j disqualify a matrix."""
    assert cirq.is_orthogonal(np.empty((0, 0)))
    assert not cirq.is_orthogonal(np.empty((1, 0)))
    assert not cirq.is_orthogonal(np.empty((0, 1)))

    assert cirq.is_orthogonal(np.array([[1]]))
    assert cirq.is_orthogonal(np.array([[-1]]))
    assert not cirq.is_orthogonal(np.array([[1j]]))
    assert not cirq.is_orthogonal(np.array([[5]]))
    assert not cirq.is_orthogonal(np.array([[3j]]))

    assert not cirq.is_orthogonal(np.array([[1, 0]]))
    assert not cirq.is_orthogonal(np.array([[1], [0]]))

    assert not cirq.is_orthogonal(np.array([[1, 0], [0, -2]]))
    assert cirq.is_orthogonal(np.array([[1, 0], [0, -1]]))
    assert not cirq.is_orthogonal(np.array([[1j, 0], [0, 1]]))
    assert not cirq.is_orthogonal(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_orthogonal(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_orthogonal(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_orthogonal(np.array([[1, -1], [1, 1]]))
    assert cirq.is_orthogonal(np.array([[1, -1], [1, 1]]) * np.sqrt(0.5))
    assert not cirq.is_orthogonal(np.array([[1, 1j], [1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_orthogonal(np.array([[1, -1j], [1j, 1]]) * np.sqrt(0.5))

    assert cirq.is_orthogonal(np.array([[1, 1e-11], [0, 1 + 1e-11]]))


def test_is_orthogonal_tolerance():
    """Deviations from orthogonality are compared entrywise against atol."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_orthogonal(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert not cirq.is_orthogonal(np.array([[1, 0], [-0.6, 1]]), atol=atol)

    # Error isn't accumulated across entries.
    assert cirq.is_orthogonal(np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]]), atol=atol)
    assert not cirq.is_orthogonal(np.array([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1.2]]), atol=atol)


def test_is_special_orthogonal():
    """Checks is_special_orthogonal, which also requires determinant +1."""
    assert cirq.is_special_orthogonal(np.empty((0, 0)))
    assert not cirq.is_special_orthogonal(np.empty((1, 0)))
    assert not cirq.is_special_orthogonal(np.empty((0, 1)))

    assert cirq.is_special_orthogonal(np.array([[1]]))
    assert not cirq.is_special_orthogonal(np.array([[-1]]))
    assert not cirq.is_special_orthogonal(np.array([[1j]]))
    assert not cirq.is_special_orthogonal(np.array([[5]]))
    assert not cirq.is_special_orthogonal(np.array([[3j]]))

    assert not cirq.is_special_orthogonal(np.array([[1, 0]]))
    assert not cirq.is_special_orthogonal(np.array([[1], [0]]))

    assert not cirq.is_special_orthogonal(np.array([[1, 0], [0, -2]]))
    assert not cirq.is_special_orthogonal(np.array([[1, 0], [0, -1]]))
    assert cirq.is_special_orthogonal(np.array([[-1, 0], [0, -1]]))
    assert not cirq.is_special_orthogonal(np.array([[1j, 0], [0, 1]]))
    assert not cirq.is_special_orthogonal(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_special_orthogonal(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_special_orthogonal(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_special_orthogonal(np.array([[1, -1], [1, 1]]))
    assert cirq.is_special_orthogonal(np.array([[1, -1], [1, 1]]) * np.sqrt(0.5))
    assert not cirq.is_special_orthogonal(np.array([[1, 1], [1, -1]]) * np.sqrt(0.5))
    assert not cirq.is_special_orthogonal(np.array([[1, 1j], [1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_special_orthogonal(np.array([[1, -1j], [1j, 1]]) * np.sqrt(0.5))

    assert cirq.is_special_orthogonal(np.array([[1, 1e-11], [0, 1 + 1e-11]]))


def test_is_special_orthogonal_tolerance():
    """Checks atol handling of is_special_orthogonal, including the determinant condition."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_special_orthogonal(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert not cirq.is_special_orthogonal(np.array([[1, 0], [-0.6, 1]]), atol=atol)

    # Error isn't accumulated across entries, except for determinant factors.
    assert cirq.is_special_orthogonal(
        np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1 / 1.2]]), atol=atol
    )
    assert not cirq.is_special_orthogonal(
        np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]]), atol=atol
    )
    assert not cirq.is_special_orthogonal(
        np.array([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1 / 1.2]]), atol=atol
    )


def test_is_special_unitary():
    """Checks is_special_unitary, which also requires determinant 1."""
    assert cirq.is_special_unitary(np.empty((0, 0)))
    assert not cirq.is_special_unitary(np.empty((1, 0)))
    assert not cirq.is_special_unitary(np.empty((0, 1)))

    assert cirq.is_special_unitary(np.array([[1]]))
    assert not cirq.is_special_unitary(np.array([[-1]]))
    assert not cirq.is_special_unitary(np.array([[5]]))
    assert not cirq.is_special_unitary(np.array([[3j]]))

    assert not cirq.is_special_unitary(np.array([[1, 0], [0, -2]]))
    assert not cirq.is_special_unitary(np.array([[1, 0], [0, -1]]))
    assert cirq.is_special_unitary(np.array([[-1, 0], [0, -1]]))
    assert not cirq.is_special_unitary(np.array([[1j, 0], [0, 1]]))
    assert cirq.is_special_unitary(np.array([[1j, 0], [0, -1j]]))
    assert not cirq.is_special_unitary(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, -1], [1, 1]]))
    assert cirq.is_special_unitary(np.array([[1, -1], [1, 1]]) * np.sqrt(0.5))
    assert cirq.is_special_unitary(np.array([[1, 1j], [1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_special_unitary(np.array([[1, -1j], [1j, 1]]) * np.sqrt(0.5))

    assert cirq.is_special_unitary(np.array([[1, 1j + 1e-11], [1j, 1 + 1j * 1e-9]]) * np.sqrt(0.5))


def test_is_special_unitary_tolerance():
    """Checks atol handling of is_special_unitary, including global-phase and determinant errors."""
    atol = 0.5

    # Pays attention to specified tolerance.
    assert cirq.is_special_unitary(np.array([[1, 0], [-0.5, 1]]), atol=atol)
    assert not cirq.is_special_unitary(np.array([[1, 0], [-0.6, 1]]), atol=atol)
    assert cirq.is_special_unitary(np.array([[1, 0], [0, 1]]) * cmath.exp(1j * 0.1), atol=atol)
    assert not cirq.is_special_unitary(np.array([[1, 0], [0, 1]]) * cmath.exp(1j * 0.3), atol=atol)

    # Error isn't accumulated across entries, except for determinant factors.
    assert cirq.is_special_unitary(np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1 / 1.2]]), atol=atol)
    assert not cirq.is_special_unitary(np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]]), atol=atol)
    assert not cirq.is_special_unitary(
        np.array([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1 / 1.2]]), atol=atol
    )


def test_is_normal():
    """Checks is_normal on scalars, random density/unitary matrices, a nilpotent and a non-square input."""
    assert cirq.is_normal(np.array([[1]]))
    assert cirq.is_normal(np.array([[3j]]))
    assert cirq.is_normal(cirq.testing.random_density_matrix(4))
    assert cirq.is_normal(cirq.testing.random_unitary(5))
    assert not cirq.is_normal(np.array([[0, 1], [0, 0]]))
    assert not cirq.is_normal(np.zeros((1, 0)))


def test_is_normal_tolerance():
    """Deviations from normality are compared entrywise against atol."""
    atol = 0.25

    # Pays attention to specified tolerance.
    assert cirq.is_normal(np.array([[0, 0.5], [0, 0]]), atol=atol)
    assert not cirq.is_normal(np.array([[0, 0.6], [0, 0]]), atol=atol)

    # Error isn't accumulated across entries.
    assert cirq.is_normal(np.array([[0, 0.5, 0], [0, 0, 0.5], [0, 0, 0]]), atol=atol)
    assert not cirq.is_normal(np.array([[0, 0.5, 0], [0, 0, 0.6], [0, 0, 0]]), atol=atol)


def test_is_cptp():
    """Checks is_cptp on known channels, unnormalized Kraus sets and Kraus sets cut from random unitaries."""
    rt2 = np.sqrt(0.5)
    # Amplitude damping with gamma=0.5.
    assert cirq.is_cptp(kraus_ops=[np.array([[1, 0], [0, rt2]]), np.array([[0, rt2], [0, 0]])])
    # Depolarizing channel with p=0.75.
    assert cirq.is_cptp(
        kraus_ops=[
            np.array([[1, 0], [0, 1]]) * 0.5,
            np.array([[0, 1], [1, 0]]) * 0.5,
            np.array([[0, -1j], [1j, 0]]) * 0.5,
            np.array([[1, 0], [0, -1]]) * 0.5,
        ]
    )

    assert not cirq.is_cptp(kraus_ops=[np.array([[1, 0], [0, 1]]), np.array([[0, 1], [0, 0]])])
    assert not cirq.is_cptp(
        kraus_ops=[
            np.array([[1, 0], [0, 1]]),
            np.array([[0, 1], [1, 0]]),
            np.array([[0, -1j], [1j, 0]]),
            np.array([[1, 0], [0, -1]]),
        ]
    )

    # Makes 4 2x2 kraus ops.
    one_qubit_u = cirq.testing.random_unitary(8)
    one_qubit_kraus = np.reshape(one_qubit_u[:, :2], (-1, 2, 2))
    assert cirq.is_cptp(kraus_ops=one_qubit_kraus)

    # Makes 16 4x4 kraus ops.
    two_qubit_u = cirq.testing.random_unitary(64)
    two_qubit_kraus = np.reshape(two_qubit_u[:, :4], (-1, 4, 4))
    assert cirq.is_cptp(kraus_ops=two_qubit_kraus)


def test_is_cptp_tolerance():
    """Slightly unnormalized Kraus operators pass with a loose atol but fail with a tight one."""
    rt2_ish = np.sqrt(0.5) - 0.01
    atol = 0.25
    # Moderately-incorrect amplitude damping with gamma=0.5.
    assert cirq.is_cptp(
        kraus_ops=[np.array([[1, 0], [0, rt2_ish]]), np.array([[0, rt2_ish], [0, 0]])], atol=atol
    )
    assert not cirq.is_cptp(
        kraus_ops=[np.array([[1, 0], [0, rt2_ish]]), np.array([[0, rt2_ish], [0, 0]])], atol=1e-8
    )


def test_commutes():
    """Checks matrix_commutes on empty, scalar, single-qubit Pauli and two-qubit Pauli matrices."""
    assert matrix_commutes(np.empty((0, 0)), np.empty((0, 0)))
    assert not matrix_commutes(np.empty((1, 0)), np.empty((0, 1)))
    assert not matrix_commutes(np.empty((0, 1)), np.empty((1, 0)))
    assert not matrix_commutes(np.empty((1, 0)), np.empty((1, 0)))
    assert not matrix_commutes(np.empty((0, 1)), np.empty((0, 1)))

    assert matrix_commutes(np.array([[1]]), np.array([[2]]))
    assert matrix_commutes(np.array([[1]]), np.array([[0]]))

    x = np.array([[0, 1], [1, 0]])
    y = np.array([[0, -1j], [1j, 0]])
    z = np.array([[1, 0], [0, -1]])
    xx = np.kron(x, x)
    zz = np.kron(z, z)

    assert matrix_commutes(x, x)
    assert matrix_commutes(y, y)
    assert matrix_commutes(z, z)
    assert not matrix_commutes(x, y)
    assert not matrix_commutes(x, z)
    assert not matrix_commutes(y, z)

    assert matrix_commutes(xx, zz)
    assert matrix_commutes(xx, np.diag([1, -1, -1, 1 + 1e-9]))


def test_commutes_tolerance():
    """Commutator entries are compared against the given atol."""
    atol = 0.5

    x = np.array([[0, 1], [1, 0]])
    z = np.array([[1, 0], [0, -1]])

    # Pays attention to specified tolerance.
    assert matrix_commutes(x, x + z * 0.1, atol=atol)
    assert not matrix_commutes(x, x + z * 0.5, atol=atol)


def test_allclose_up_to_global_phase():
    """Checks equality up to a global phase, including shape mismatches."""
    assert cirq.allclose_up_to_global_phase(np.array([1]), np.array([1j]))

    assert not cirq.allclose_up_to_global_phase(np.array([[[1]]]), np.array([1]))

    assert cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[1]]))
    assert cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[-1]]))

    assert cirq.allclose_up_to_global_phase(np.array([[0]]), np.array([[0]]))

    assert cirq.allclose_up_to_global_phase(np.array([[1, 2]]), np.array([[1j, 2j]]))

    assert cirq.allclose_up_to_global_phase(np.array([[1, 2.0000000001]]), np.array([[1j, 2j]]))

    assert not cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[1, 0]]))
    assert not cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[2]]))
    # A relative sign between entries cannot be removed by any single global phase
    # (the first entry needs a factor of 1j, the second a factor of -1j).
    assert not cirq.allclose_up_to_global_phase(np.array([[1, 2]]), np.array([[1j, -2j]]))


def test_binary_sub_tensor_slice():
    """Checks slice_for_qubits_equal_to with a positional value (the first listed axis gets the lowest bit)."""
    a = slice(None)
    e = Ellipsis

    assert cirq.slice_for_qubits_equal_to([], 0) == (e,)
    assert cirq.slice_for_qubits_equal_to([0], 0b0) == (0, e)
    assert cirq.slice_for_qubits_equal_to([0], 0b1) == (1, e)
    assert cirq.slice_for_qubits_equal_to([1], 0b0) == (a, 0, e)
    assert cirq.slice_for_qubits_equal_to([1], 0b1) == (a, 1, e)
    assert cirq.slice_for_qubits_equal_to([2], 0b0) == (a, a, 0, e)
    assert cirq.slice_for_qubits_equal_to([2], 0b1) == (a, a, 1, e)

    assert cirq.slice_for_qubits_equal_to([0, 1], 0b00) == (0, 0, e)
    assert cirq.slice_for_qubits_equal_to([1, 2], 0b00) == (a, 0, 0, e)
    assert cirq.slice_for_qubits_equal_to([1, 3], 0b00) == (a, 0, a, 0, e)
    assert cirq.slice_for_qubits_equal_to([1, 3], 0b10) == (a, 0, a, 1, e)
    assert cirq.slice_for_qubits_equal_to([3, 1], 0b10) == (a, 1, a, 0, e)

    assert cirq.slice_for_qubits_equal_to([2, 1, 0], 0b001) == (0, 0, 1, e)
    assert cirq.slice_for_qubits_equal_to([2, 1, 0], 0b010) == (0, 1, 0, e)
    assert cirq.slice_for_qubits_equal_to([2, 1, 0], 0b100) == (1, 0, 0, e)
    assert cirq.slice_for_qubits_equal_to([0, 1, 2], 0b101) == (1, 0, 1, e)
    assert cirq.slice_for_qubits_equal_to([0, 2, 1], 0b101) == (1, 1, 0, e)

    m = np.array([0] * 16).reshape((2, 2, 2, 2))
    for k in range(16):
        m[cirq.slice_for_qubits_equal_to([3, 2, 1, 0], k)] = k
    assert list(m.reshape(16)) == list(range(16))

    assert cirq.slice_for_qubits_equal_to([0], 0b1, num_qubits=1) == (1,)
    assert cirq.slice_for_qubits_equal_to([1], 0b0, num_qubits=2) == (a, 0)
    assert cirq.slice_for_qubits_equal_to([1], 0b0, num_qubits=3) == (a, 0, a)
    assert cirq.slice_for_qubits_equal_to([2], 0b0, num_qubits=3) == (a, a, 0)


def test_binary_sub_tensor_slice_big_endian():
    """Checks slice_for_qubits_equal_to when the value is given as big_endian_qureg_value."""
    a = slice(None)
    e = Ellipsis
    sfqet = cirq.slice_for_qubits_equal_to

    assert sfqet([], big_endian_qureg_value=0) == (e,)
    assert sfqet([0], big_endian_qureg_value=0b0) == (0, e)
    assert sfqet([0], big_endian_qureg_value=0b1) == (1, e)
    assert sfqet([1], big_endian_qureg_value=0b0) == (a, 0, e)
    assert sfqet([1], big_endian_qureg_value=0b1) == (a, 1, e)
    assert sfqet([2], big_endian_qureg_value=0b0) == (a, a, 0, e)
    assert sfqet([2], big_endian_qureg_value=0b1) == (a, a, 1, e)

    assert sfqet([0, 1], big_endian_qureg_value=0b00) == (0, 0, e)
    assert sfqet([1, 2], big_endian_qureg_value=0b00) == (a, 0, 0, e)
    assert sfqet([1, 3], big_endian_qureg_value=0b00) == (a, 0, a, 0, e)
    assert sfqet([1, 3], big_endian_qureg_value=0b01) == (a, 0, a, 1, e)
    assert sfqet([3, 1], big_endian_qureg_value=0b01) == (a, 1, a, 0, e)

    assert sfqet([2, 1, 0], big_endian_qureg_value=0b100) == (0, 0, 1, e)
    assert sfqet([2, 1, 0], big_endian_qureg_value=0b010) == (0, 1, 0, e)
    assert sfqet([2, 1, 0], big_endian_qureg_value=0b001) == (1, 0, 0, e)
    assert sfqet([0, 1, 2], big_endian_qureg_value=0b101) == (1, 0, 1, e)
    assert sfqet([0, 2, 1], big_endian_qureg_value=0b101) == (1, 1, 0, e)

    m = np.array([0] * 16).reshape((2, 2, 2, 2))
    for k in range(16):
        m[sfqet([0, 1, 2, 3], big_endian_qureg_value=k)] = k
    assert list(m.reshape(16)) == list(range(16))

    assert sfqet([0], big_endian_qureg_value=0b1, num_qubits=1) == (1,)
    assert sfqet([1], big_endian_qureg_value=0b0, num_qubits=2) == (a, 0)
    assert sfqet([1], big_endian_qureg_value=0b0, num_qubits=3) == (a, 0, a)
    assert sfqet([2], big_endian_qureg_value=0b0, num_qubits=3) == (a, a, 0)


def test_qudit_sub_tensor_slice():
    """Checks slice_for_qubits_equal_to with a non-binary qid_shape, plus argument validation errors."""
    a = slice(None)
    sfqet = cirq.slice_for_qubits_equal_to

    assert sfqet([], 0, qid_shape=()) == ()
    assert sfqet([0], 0, qid_shape=(3,)) == (0,)
    assert sfqet([0], 1, qid_shape=(3,)) == (1,)
    assert sfqet([0], 2, qid_shape=(3,)) == (2,)
    assert sfqet([2], 0, qid_shape=(1, 2, 3)) == (a, a, 0)
    assert sfqet([2], 2, qid_shape=(1, 2, 3)) == (a, a, 2)
    assert sfqet([2], big_endian_qureg_value=2, qid_shape=(1, 2, 3)) == (a, a, 2)

    assert sfqet([1, 3], 3 * 2 + 1, qid_shape=(2, 3, 4, 5)) == (a, 1, a, 2)
    assert sfqet([3, 1], 5 * 2 + 1, qid_shape=(2, 3, 4, 5)) == (a, 2, a, 1)
    assert sfqet([2, 1, 0], 9 * 2 + 3 * 1, qid_shape=(3,) * 3) == (2, 1, 0)
    assert sfqet([1, 3], big_endian_qureg_value=5 * 1 + 2, qid_shape=(2, 3, 4, 5)) == (a, 1, a, 2)
    assert sfqet([3, 1], big_endian_qureg_value=3 * 1 + 2, qid_shape=(2, 3, 4, 5)) == (a, 2, a, 1)

    m = np.array([0] * 24).reshape((1, 2, 3, 4))
    for k in range(24):
        m[sfqet([3, 2, 1, 0], k, qid_shape=(1, 2, 3, 4))] = k
    assert list(m.reshape(24)) == list(range(24))

    assert sfqet([0], 1, num_qubits=1, qid_shape=(3,)) == (1,)
    assert sfqet([1], 0, num_qubits=3, qid_shape=(3, 3, 3)) == (a, 0, a)

    with pytest.raises(ValueError, match='len.* !='):
        sfqet([], num_qubits=2, qid_shape=(1, 2, 3))

    with pytest.raises(ValueError, match='exactly one'):
        sfqet([0, 1, 2], 0b101, big_endian_qureg_value=0b101)