# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import cirq
from cirq.protocols.apply_unitary_protocol import (
    _incorporate_result_into_target,
)


def test_apply_unitary_presence_absence():
    """Check which kinds of values `cirq.apply_unitary` accepts.

    Values that define neither `_unitary_` nor a working `_apply_unitary_`
    must raise `TypeError` (or yield the supplied `default`); values that
    define either one must produce the expected tensor.
    """
    # diag(1, -1): leaves index 0 alone and negates index 1 of the acted-on axis.
    m = np.diag([1, -1])

    # Defines nothing relevant to the protocol.
    class NoUnitaryEffect:
        pass

    # Only exposes a matrix via `_unitary_`.
    class HasUnitary:
        def _unitary_(self) -> np.ndarray:
            return m

    # Declines in `_apply_unitary_` and offers no `_unitary_` to fall back on.
    class HasApplyReturnsNotImplemented:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

    # Declines in `_apply_unitary_` but also defines `_unitary_`.
    class HasApplyReturnsNotImplementedButHasUnitary:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

        def _unitary_(self) -> np.ndarray:
            return m

    # Writes the result into the scratch buffer and returns the buffer.
    class HasApplyOutputInBuffer:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            zero = args.subspace_index(0)
            one = args.subspace_index(1)
            args.available_buffer[zero] = args.target_tensor[zero]
            args.available_buffer[one] = -args.target_tensor[one]
            return args.available_buffer

    # Mutates the target tensor in place and returns it.
    class HasApplyMutateInline:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            one = args.subspace_index(1)
            args.target_tensor[one] *= -1
            return args.target_tensor

    fails = [
        NoUnitaryEffect(),
        HasApplyReturnsNotImplemented(),
    ]
    passes = [
        HasUnitary(),
        HasApplyReturnsNotImplementedButHasUnitary(),
        HasApplyOutputInBuffer(),
        HasApplyMutateInline(),
    ]

    def make_input():
        # Fresh all-ones tensor per call, since some implementations mutate it.
        return np.ones((2, 2))

    def assert_works(val):
        # expected_outputs[axis] is the all-ones input with index 1 along
        # `axis` negated: row 1 for axis 0, column 1 for axis 1.
        expected_outputs = [
            np.array([1, 1, -1, -1]).reshape((2, 2)),
            np.array([1, -1, 1, -1]).reshape((2, 2)),
        ]
        for axis in range(2):
            result = cirq.apply_unitary(val, cirq.ApplyUnitaryArgs(make_input(), buf, [axis]))
            np.testing.assert_allclose(result, expected_outputs[axis])

    # Shared scratch buffer. `assert_works` above picks it up through its
    # closure, so it only has to exist before the first call.
    buf = np.empty(shape=(2, 2), dtype=np.complex128)

    for f in fails:
        # Without a default the failure surfaces as a TypeError...
        with pytest.raises(TypeError, match='failed to satisfy'):
            _ = cirq.apply_unitary(f, cirq.ApplyUnitaryArgs(make_input(), buf, [0]))
        # ...and with a default, that exact default is handed back.
        assert (
            cirq.apply_unitary(f, cirq.ApplyUnitaryArgs(make_input(), buf, [0]), default=None)
            is None
        )
        assert (
            cirq.apply_unitary(
                f, cirq.ApplyUnitaryArgs(make_input(), buf, [0]), default=NotImplemented
            )
            is NotImplemented
        )
        assert cirq.apply_unitary(f, cirq.ApplyUnitaryArgs(make_input(), buf, [0]), default=1) == 1

    for s in passes:
        assert_works(s)
        # A supplied default must not shadow a successful application.
        assert (
            cirq.apply_unitary(s, cirq.ApplyUnitaryArgs(make_input(), buf, [0]), default=None)
            is not None
        )


def test_apply_unitary_args_tensor_manipulation():
    """Check `_incorporate_result_into_target` against varied result arrays.

    Each helper class implements the same operation but hands back its result
    through a different array: the target tensor, the available buffer, an
    array derived from one of those by transpose/reshape, or a freshly
    allocated array.
    """
    # All below are qubit swap operations with 1j global phase

    # Swaps in place inside target_tensor, using the buffer as scratch space.
    class ModifyTargetTensor:
        def _apply_unitary_(self, args):
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            args.available_buffer[zo] = args.target_tensor[zo]
            args.target_tensor[zo] = args.target_tensor[oz]
            args.target_tensor[oz] = args.available_buffer[zo]
            args.target_tensor[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return args.target_tensor

    # Returns target_tensor transposed so that the two operation axes trade places.
    class TransposeTargetTensor:
        def _apply_unitary_(self, args):
            indices = list(range(len(args.target_tensor.shape)))
            indices[args.axes[0]], indices[args.axes[1]] = (
                indices[args.axes[1]],
                indices[args.axes[0]],
            )
            target = args.target_tensor.transpose(*indices)
            target[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return target

    # Saves the input in the buffer, then writes the swapped data into an array
    # obtained from target_tensor via transpose + reshape.
    class ReshapeTargetTensor:
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            args.available_buffer[zz] = args.target_tensor[zz]
            args.available_buffer[zo] = args.target_tensor[zo]
            args.available_buffer[oz] = args.target_tensor[oz]
            args.available_buffer[oo] = args.target_tensor[oo]
            # Do a pointless reshape and transpose
            target = args.target_tensor.transpose(
                *range(1, len(args.target_tensor.shape)), 0
            ).reshape(args.target_tensor.shape)
            target[zz] = args.available_buffer[zz]
            target[zo] = args.available_buffer[oz]
            target[oz] = args.available_buffer[zo]
            target[oo] = args.available_buffer[oo]
            target[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return target

    # Writes the swapped data into the buffer and returns the buffer.
    class ModifyAvailableBuffer:
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            args.available_buffer[zz] = args.target_tensor[zz]
            args.available_buffer[zo] = args.target_tensor[oz]
            args.available_buffer[oz] = args.target_tensor[zo]
            args.available_buffer[oo] = args.target_tensor[oo]
            args.available_buffer[...] *= 1j
            args.target_tensor[...] = 99  # Destroy target data just in case
            return args.available_buffer

    # Copies the input into the buffer and returns the buffer transposed so
    # that the two operation axes trade places.
    class TransposeAvailableBuffer:
        def _apply_unitary_(self, args):
            indices = list(range(len(args.target_tensor.shape)))
            indices[args.axes[0]], indices[args.axes[1]] = (
                indices[args.axes[1]],
                indices[args.axes[0]],
            )
            output = args.available_buffer.transpose(*indices)
            args.available_buffer[...] = args.target_tensor
            output *= 1j
            args.target_tensor[...] = 99  # Destroy target data just in case
            return output

    # Writes the swapped data into an array obtained from the buffer via
    # transpose + reshape.
    class ReshapeAvailableBuffer:
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            # Do a pointless reshape and transpose
            output = args.available_buffer.transpose(
                *range(1, len(args.available_buffer.shape)), 0
            ).reshape(args.available_buffer.shape)
            output[zz] = args.target_tensor[zz]
            output[zo] = args.target_tensor[oz]
            output[oz] = args.target_tensor[zo]
            output[oo] = args.target_tensor[oo]
            output[...] *= 1j
            args.target_tensor[...] = 99  # Destroy target data just in case
            return output

    # Computes the result with a matrix product into a brand-new array.
    class CreateNewBuffer:
        def _apply_unitary_(self, args):
            # 4x4 swap matrix scaled by 1j.
            u = (
                np.array(
                    [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
                    dtype=args.target_tensor.dtype,
                )
                * 1j
            )  # yapf: disable
            # Flatten last two axes and add a dummy index to the end of
            # target_tensor so np.matmul treats it like an array of two-qubit
            # column vectors.
            new_shape = args.target_tensor.shape[:-2] + (4, 1)
            ret = np.matmul(u, args.target_tensor.reshape(new_shape)).reshape(
                args.target_tensor.shape
            )
            # Destroy both target and buffer data just in case
            args.target_tensor[...] = 99
            args.available_buffer[...] = 98
            return ret

    operations = [
        ModifyTargetTensor(),
        TransposeTargetTensor(),
        ReshapeTargetTensor(),
        ModifyAvailableBuffer(),
        TransposeAvailableBuffer(),
        ReshapeAvailableBuffer(),
        CreateNewBuffer(),
    ]

    def assert_is_swap_simple(val: cirq.SupportsConsistentApplyUnitary) -> None:
        # Small case: a (1, 3, 3) tensor acted on along its last two axes, with
        # the operation declared to have qid shape (2, 2).
        qid_shape = (2, 2)
        op_indices = [0, 1]
        state = np.arange(3 * 3, dtype=np.complex64).reshape((1, 3, 3))
        # Expected: entries [0, 1] and [1, 0] of the last two axes trade places
        # and the leading 2x2 corner picks up the 1j factor; entries at index 2
        # stay as they were.
        expected = state.copy()
        buf = expected[..., 0, 1].copy()
        expected[..., 0, 1] = expected[..., 1, 0]
        expected[..., 1, 0] = buf
        expected[..., :2, :2] *= 1j

        args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [1, 2])
        # NOTE(review): presumably this narrows `args` to the sub-tensor matching
        # the operation's qid shape -- confirm against the protocol module.
        sub_args = args._for_operation_with_qid_shape(
            op_indices, tuple(qid_shape[i] for i in op_indices)
        )
        sub_result = val._apply_unitary_(sub_args)
        result = _incorporate_result_into_target(args, sub_args, sub_result)
        np.testing.assert_allclose(result, expected, atol=1e-8)

    def assert_is_swap(val: cirq.SupportsConsistentApplyUnitary) -> None:
        # Larger case: 7-axis tensor with shuffled axes. op_indices [1, 3] pick
        # entries 4 and 3 out of the axes list [5, 4, 6, 3]; those tensor axes
        # have sizes 3 and 5 while the operation's qid shape there is (2, 2).
        qid_shape = (1, 2, 4, 2)
        op_indices = [1, 3]
        state = np.arange(2 * (1 * 3 * 4 * 5), dtype=np.complex64).reshape((1, 2, 1, 5, 3, 1, 4))
        # Expected: swap the (0, 1) and (1, 0) entries over tensor axes 3 and 4
        # and scale the leading 2x2 corner of those axes by 1j.
        expected = state.copy()
        buf = expected[..., 0, 1, :, :].copy()
        expected[..., 0, 1, :, :] = expected[..., 1, 0, :, :]
        expected[..., 1, 0, :, :] = buf
        expected[..., :2, :2, :, :] *= 1j

        args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [5, 4, 6, 3])
        sub_args = args._for_operation_with_qid_shape(
            op_indices, tuple(qid_shape[i] for i in op_indices)
        )
        sub_result = val._apply_unitary_(sub_args)
        result = _incorporate_result_into_target(args, sub_args, sub_result)
        np.testing.assert_allclose(result, expected, atol=1e-8, verbose=True)

    for op in operations:
        assert_is_swap_simple(op)
        assert_is_swap(op)


def test_big_endian_subspace_index():
    """Check `subspace_index` with little- and big-endian bit arguments."""
    state = np.zeros(shape=(2, 3, 4, 5, 1, 6, 1, 1))
    args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [1, 3])
    s = slice(None)
    # The value 1 lands on axis 1 (the first listed axis) in little-endian
    # order and on axis 3 (the last listed axis) in big-endian order.
    assert args.subspace_index(little_endian_bits_int=1) == (s, 1, s, 0, s, s, s, s)
    assert args.subspace_index(big_endian_bits_int=1) == (s, 0, s, 1, s, s, s, s)


def test_apply_unitaries():
    """Check `cirq.apply_unitaries` over several argument combinations.

    Covers default arguments, a permuted qubit order, explicitly supplied
    `args`, empty input, a non-unitary operation, and `args` whose size
    disagrees with the qubit list.
    """
    a, b, c = cirq.LineQubit.range(3)

    result = cirq.apply_unitaries(
        unitary_values=[cirq.H(a), cirq.CNOT(a, b), cirq.H(c).controlled_by(b)], qubits=[a, b, c]
    )
    # Expected amplitudes: sqrt(0.5) at index 0 and 0.5 at indices 6 and 7.
    np.testing.assert_allclose(
        result.reshape(8),
        [
            np.sqrt(0.5),
            0,
            0,
            0,
            0,
            0,
            0.5,
            0.5,
        ],
        atol=1e-8,
    )

    # Different order.
    # With qubits listed as [a, c, b] the 0.5 amplitudes move to indices 5 and 7.
    result = cirq.apply_unitaries(
        unitary_values=[cirq.H(a), cirq.CNOT(a, b), cirq.H(c).controlled_by(b)], qubits=[a, c, b]
    )
    np.testing.assert_allclose(
        result.reshape(8),
        [
            np.sqrt(0.5),
            0,
            0,
            0,
            0,
            0.5,
            0,
            0.5,
        ],
        atol=1e-8,
    )

    # Explicit arguments.
    # Passing `ApplyUnitaryArgs.default(num_qubits=3)` must give the same
    # amplitudes as the first call, which passed no `args`.
    result = cirq.apply_unitaries(
        unitary_values=[cirq.H(a), cirq.CNOT(a, b), cirq.H(c).controlled_by(b)],
        qubits=[a, b, c],
        args=cirq.ApplyUnitaryArgs.default(num_qubits=3),
    )
    np.testing.assert_allclose(
        result.reshape(8),
        [
            np.sqrt(0.5),
            0,
            0,
            0,
            0,
            0,
            0.5,
            0.5,
        ],
        atol=1e-8,
    )

    # Empty.
    # No values and no qubits: the result is the single amplitude 1, whether
    # or not a default is supplied.
    result = cirq.apply_unitaries(unitary_values=[], qubits=[])
    np.testing.assert_allclose(result, [1])
    result = cirq.apply_unitaries(unitary_values=[], qubits=[], default=None)
    np.testing.assert_allclose(result, [1])

    # Non-unitary operation.
    # Raises without a default; otherwise the default comes back unchanged.
    with pytest.raises(TypeError, match='non-unitary'):
        _ = cirq.apply_unitaries(unitary_values=[cirq.depolarize(0.5).on(a)], qubits=[a])
    assert (
        cirq.apply_unitaries(unitary_values=[cirq.depolarize(0.5).on(a)], qubits=[a], default=None)
        is None
    )
    assert (
        cirq.apply_unitaries(unitary_values=[cirq.depolarize(0.5).on(a)], qubits=[a], default=1)
        == 1
    )

    # Inconsistent arguments.
    # `args` built for one qubit while the qubit list is empty.
    with pytest.raises(ValueError, match='len'):
        _ = cirq.apply_unitaries(
            unitary_values=[], qubits=[], args=cirq.ApplyUnitaryArgs.default(1)
        )


def test_apply_unitaries_mixed_qid_shapes():
    """Check `cirq.apply_unitaries` on qids of dimension 3 and 4.

    The gate sequences mix qubit gates (applied via `with_dimension(2)`) with
    gates of dimension 3 and 4 acting on the same underlying qids.
    """

    # Cyclic shift on a dimension-3 qid: the matrix sends basis state k to
    # (k + 1) mod 3.
    class PlusOneMod3Gate(cirq.SingleQubitGate):
        def _qid_shape_(self):
            return (3,)

        def _unitary_(self):
            return np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])  # yapf: disable

    # Cyclic shift on a dimension-4 qid: the matrix sends basis state k to
    # (k + 1) mod 4.
    class PlusOneMod4Gate(cirq.SingleQubitGate):
        def _qid_shape_(self):
            return (4,)

        def _unitary_(self):
            return np.array(
                [[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
            )  # yapf: disable

    a, b = cirq.LineQid.for_qid_shape((3, 4))

    # No `args` given: the flattened result is expected to have amplitude 1 at
    # index 0 and 0 elsewhere.
    result = cirq.apply_unitaries(
        unitary_values=[
            PlusOneMod3Gate().on(a.with_dimension(3)),
            cirq.X(a.with_dimension(2)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            cirq.X(a.with_dimension(2)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
        ],
        qubits=[a, b],
    )
    np.testing.assert_allclose(result.reshape(12), [1] + [0] * 11, atol=1e-8)

    # Same sequence applied to an identity tensor: the result, viewed as a
    # 12x12 matrix, is expected to be the identity.
    result = cirq.apply_unitaries(
        unitary_values=[
            PlusOneMod3Gate().on(a.with_dimension(3)),
            cirq.X(a.with_dimension(2)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            cirq.X(a.with_dimension(2)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
        ],
        qubits=[a, b],
        args=cirq.ApplyUnitaryArgs(
            target_tensor=cirq.eye_tensor((3, 4), dtype=np.complex64),
            available_buffer=cirq.eye_tensor((3, 4), dtype=np.complex64),
            axes=(0, 1),
        ),
    )
    np.testing.assert_allclose(result.reshape(12, 12), np.eye(12), atol=1e-8)

    # A longer sequence that also uses the dimension-4 gate on `b`; its net
    # effect on an identity tensor is expected to be the permutation below.
    result = cirq.apply_unitaries(
        unitary_values=[
            PlusOneMod3Gate().on(a.with_dimension(3)),
            cirq.X(a.with_dimension(2)),
            PlusOneMod4Gate().on(b.with_dimension(4)),
            PlusOneMod4Gate().on(b.with_dimension(4)),
            cirq.X(b.with_dimension(2)),
            PlusOneMod4Gate().on(b.with_dimension(4)),
            PlusOneMod4Gate().on(b.with_dimension(4)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            PlusOneMod4Gate().on(b.with_dimension(4)),
            cirq.X(b.with_dimension(2)),
            cirq.CNOT(a.with_dimension(2), b.with_dimension(2)),
            cirq.X(a.with_dimension(2)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
            PlusOneMod3Gate().on(a.with_dimension(3)),
        ],
        qubits=[a, b],
        args=cirq.ApplyUnitaryArgs(
            target_tensor=cirq.eye_tensor((3, 4), dtype=np.complex64),
            available_buffer=cirq.eye_tensor((3, 4), dtype=np.complex64),
            axes=(0, 1),
        ),
    )
    np.testing.assert_allclose(
        result.reshape(12, 12),
        np.array(
            [
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            ]
        ),
        atol=1e-8,
    )


def test_incorporate_result_not_view():
    """Check that unrelated sub-args are rejected with a `ValueError`.

    `not_sub_args` wraps a separately allocated tensor rather than anything
    derived from `args`, and the call must fail with a message matching 'view'.
    """
    tensor = np.zeros((2, 2))
    tensor2 = np.zeros((2, 2))
    buffer = np.empty_like(tensor)
    args = cirq.ApplyUnitaryArgs(tensor, buffer, [0])
    not_sub_args = cirq.ApplyUnitaryArgs(tensor2, buffer, [0])
    with pytest.raises(ValueError, match='view'):
        _incorporate_result_into_target(args, not_sub_args, tensor2)


def test_default_method_arguments():
    """Check that `default` rejects a qubit count combined with `qid_shape`."""
    with pytest.raises(TypeError, match='exactly one of'):
        cirq.ApplyUnitaryArgs.default(1, qid_shape=(2,))


def test_apply_unitary_args_with_axes_transposed_to_start():
    """Check `with_axes_transposed_to_start` for shapes, axes and aliasing."""
    target = np.zeros((2, 3, 4, 5))
    buffer = np.zeros((2, 3, 4, 5))
    args = cirq.ApplyUnitaryArgs(target, buffer, [1, 3])

    # Axes 1 and 3 (sizes 3 and 5) move to the front; the remaining axes
    # (sizes 2 and 4) follow in their original order.
    new_args = args.with_axes_transposed_to_start()
    assert new_args.target_tensor.shape == (3, 5, 2, 4)
    assert new_args.available_buffer.shape == (3, 5, 2, 4)
    assert new_args.axes == (0, 1)

    # Confirm aliasing.
    # A write through the transposed arrays must show up in the originals at
    # the correspondingly permuted index.
    new_args.target_tensor[2, 4, 1, 3] = 1
    assert args.target_tensor[1, 2, 3, 4] == 1
    new_args.available_buffer[2, 4, 1, 3] = 2
    assert args.available_buffer[1, 2, 3, 4] == 2