# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the circuit comparison / consistency assertions in cirq.testing."""

import pytest

import numpy as np

import cirq
from cirq.testing.circuit_compare import (
    _assert_apply_unitary_works_when_axes_transposed,
)


def test_sensitive_to_phase():
    """Without measurements, a tiny Z rotation is only tolerated when atol allows it."""
    q = cirq.NamedQubit('q')

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([])]), cirq.Circuit(), atol=0
    )

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.Z(q) ** 0.0001])]), cirq.Circuit(), atol=0
        )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.Z(q) ** 0.0001])]), cirq.Circuit(), atol=0.01
    )


def test_sensitive_to_measurement_but_not_measured_phase():
    """Measurements matter, but diagonal phases on measured qubits do not.

    A phase applied to a qubit that is *not* measured (S on ``b`` when only
    ``a`` is measured) must still be detected.
    """
    q = cirq.NamedQubit('q')

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.measure(q)])]), cirq.Circuit(), atol=1e-8
        )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(q)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.Z(q)]),
                cirq.Moment([cirq.measure(q)]),
            ]
        ),
        atol=1e-8,
    )

    a, b = cirq.LineQubit.range(2)

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(a, b)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.Z(a)]),
                cirq.Moment([cirq.measure(a, b)]),
            ]
        ),
        atol=1e-8,
    )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(a)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.Z(a)]),
                cirq.Moment([cirq.measure(a)]),
            ]
        ),
        atol=1e-8,
    )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(a, b)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.T(a), cirq.S(b)]),
                cirq.Moment([cirq.measure(a, b)]),
            ]
        ),
        atol=1e-8,
    )

    # S(b) acts on an unmeasured qubit, so it is observable and must be detected.
    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.measure(a)])]),
            cirq.Circuit(
                [
                    cirq.Moment([cirq.T(a), cirq.S(b)]),
                    cirq.Moment([cirq.measure(a)]),
                ]
            ),
            atol=1e-8,
        )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(a, b)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.CZ(a, b)]),
                cirq.Moment([cirq.measure(a, b)]),
            ]
        ),
        atol=1e-8,
    )


def test_sensitive_to_measurement_toggle():
    """An X flip or an invert_mask each change results; together they cancel."""
    q = cirq.NamedQubit('q')

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.measure(q)])]),
            cirq.Circuit(
                [
                    cirq.Moment([cirq.X(q)]),
                    cirq.Moment([cirq.measure(q)]),
                ]
            ),
            atol=1e-8,
        )

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.measure(q)])]),
            cirq.Circuit(
                [
                    cirq.Moment([cirq.measure(q, invert_mask=(True,))]),
                ]
            ),
            atol=1e-8,
        )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(q)])]),
        cirq.Circuit(
            [
                cirq.Moment([cirq.X(q)]),
                cirq.Moment([cirq.measure(q, invert_mask=(True,))]),
            ]
        ),
        atol=1e-8,
    )


def test_measuring_qubits():
    """Which qubits are measured matters; measurement grouping and order do not."""
    a, b = cirq.LineQubit.range(2)

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit([cirq.Moment([cirq.measure(a)])]),
            cirq.Circuit([cirq.Moment([cirq.measure(b)])]),
            atol=1e-8,
        )

    # Qubit `a` is inverted in both circuits; only the listed order differs.
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit([cirq.Moment([cirq.measure(a, b, invert_mask=(True,))])]),
        cirq.Circuit([cirq.Moment([cirq.measure(b, a, invert_mask=(False, True))])]),
        atol=1e-8,
    )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit(
            [
                cirq.Moment([cirq.measure(a)]),
                cirq.Moment([cirq.measure(b)]),
            ]
        ),
        cirq.Circuit([cirq.Moment([cirq.measure(a, b)])]),
        atol=1e-8,
    )


@pytest.mark.parametrize(
    'circuit',
    [
        # Seeded so every run (and every pytest worker) checks the same circuits,
        # which keeps any failure reproducible.
        cirq.testing.random_circuit(cirq.LineQubit.range(2), 4, 0.5, random_state=seed)
        for seed in range(5)
    ],
)
def test_random_same_matrix(circuit):
    """A random circuit matches a single MatrixGate holding its unitary.

    Checked both without measurements and with a terminal measurement of ``a``.
    """
    # Work on a copy: the parametrized circuit object is shared, and appending
    # to it in place would leak the extra measurement into any re-run.
    circuit = circuit.copy()
    a, b = cirq.LineQubit.range(2)
    same = cirq.Circuit(
        cirq.MatrixGate(circuit.unitary(qubits_that_should_be_present=[a, b])).on(a, b)
    )

    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, same, atol=1e-8)

    circuit.append(cirq.measure(a))
    same.append(cirq.measure(a))
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(circuit, same, atol=1e-8)


def test_correct_qubit_ordering():
    """Z on the measured qubit is ignorable; Z on the unmeasured qubit is not."""
    a, b = cirq.LineQubit.range(2)
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.measure(b)),
        cirq.Circuit(cirq.Z(a), cirq.measure(b)),
        atol=1e-8,
    )

    with pytest.raises(AssertionError):
        cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
            cirq.Circuit(cirq.Z(a), cirq.Z(b), cirq.measure(b)),
            cirq.Circuit(cirq.Z(b), cirq.measure(b)),
            atol=1e-8,
        )


def test_known_old_failure():
    """Regression case: Z rotations before a measurement of both qubits are ignored."""
    a, b = cirq.LineQubit.range(2)
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        actual=cirq.Circuit(
            cirq.PhasedXPowGate(exponent=0.61351656, phase_exponent=0.8034575038876517).on(b),
            cirq.measure(a, b),
        ),
        reference=cirq.Circuit(
            cirq.PhasedXPowGate(exponent=0.61351656, phase_exponent=0.8034575038876517).on(b),
            cirq.Z(a) ** 0.5,
            cirq.Z(b) ** 0.1,
            cirq.measure(a, b),
        ),
        atol=1e-8,
    )


def test_assert_same_circuits():
    """assert_same_circuits reports the index of the first differing moment."""
    a, b = cirq.LineQubit.range(2)

    cirq.testing.assert_same_circuits(
        cirq.Circuit(cirq.H(a)),
        cirq.Circuit(cirq.H(a)),
    )

    with pytest.raises(AssertionError) as exc_info:
        cirq.testing.assert_same_circuits(
            cirq.Circuit(cirq.H(a)),
            cirq.Circuit(),
        )
    assert 'differing moment:\n0\n' in exc_info.value.args[0]

    with pytest.raises(AssertionError) as exc_info:
        cirq.testing.assert_same_circuits(
            cirq.Circuit(cirq.H(a), cirq.H(a)),
            cirq.Circuit(cirq.H(a), cirq.CZ(a, b)),
        )
    assert 'differing moment:\n1\n' in exc_info.value.args[0]

    # Structurally different operations are rejected even if they are equivalent.
    with pytest.raises(AssertionError):
        cirq.testing.assert_same_circuits(
            cirq.Circuit(cirq.CNOT(a, b)),
            cirq.Circuit(cirq.ControlledGate(cirq.X).on(a, b)),
        )


def test_assert_has_diagram():
    """assert_has_diagram accepts a matching diagram and explains a mismatch."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.CNOT(a, b))
    cirq.testing.assert_has_diagram(
        circuit,
        """
0: ───@───
      │
1: ───X───
""",
    )

    expected_error = """Circuit's text diagram differs from the desired diagram.

Diagram of actual circuit:
0: ───@───
      │
1: ───X───

Desired text diagram:
0: ───@───
      │
1: ───Z───

Highlighted differences:
0: ───@───
      │
1: ───█───

"""

    with pytest.raises(AssertionError) as ex_info:
        cirq.testing.assert_has_diagram(
            circuit,
            """
0: ───@───
      │
1: ───Z───
""",
        )
    assert expected_error in ex_info.value.args[0]


def test_assert_has_consistent_apply_unitary():
    """Exercise assert_has_consistent_apply_unitary against assorted fake gates."""

    # Returns the scratch buffer without writing to it, so the result is garbage.
    class IdentityReturningUnalteredWorkspace:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return args.available_buffer

        def _unitary_(self):
            return np.eye(2)

        def _num_qubits_(self):
            return 1

    with pytest.raises(AssertionError):
        cirq.testing.assert_has_consistent_apply_unitary(IdentityReturningUnalteredWorkspace())

    # _apply_unitary_ swaps |0> and |1>, but _unitary_ claims the identity.
    class DifferentEffect:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            o = args.subspace_index(0)
            i = args.subspace_index(1)
            args.available_buffer[o] = args.target_tensor[i]
            args.available_buffer[i] = args.target_tensor[o]
            return args.available_buffer

        def _unitary_(self):
            return np.eye(2, dtype=np.complex128)

        def _num_qubits_(self):
            return 1

    with pytest.raises(AssertionError):
        cirq.testing.assert_has_consistent_apply_unitary(DifferentEffect())

    # Indexes the tensor's leading axis instead of the target axes in args.axes.
    class IgnoreAxisEffect:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            if args.target_tensor.shape[0] > 1:
                args.available_buffer[0] = args.target_tensor[1]
                args.available_buffer[1] = args.target_tensor[0]
            return args.available_buffer

        def _unitary_(self):
            return np.array([[0, 1], [1, 0]])

        def _num_qubits_(self):
            return 1

    with pytest.raises(AssertionError, match='Not equal|acted differently'):
        cirq.testing.assert_has_consistent_apply_unitary(IgnoreAxisEffect())

    # A correct X: _apply_unitary_ and _unitary_ agree.
    class SameEffect:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            o = args.subspace_index(0)
            i = args.subspace_index(1)
            args.available_buffer[o] = args.target_tensor[i]
            args.available_buffer[i] = args.target_tensor[o]
            return args.available_buffer

        def _unitary_(self):
            return np.array([[0, 1], [1, 0]])

        def _num_qubits_(self):
            return 1

    cirq.testing.assert_has_consistent_apply_unitary(SameEffect())

    # A correct cyclic permutation on a single qutrit (dimension 3).
    class SameQuditEffect:
        def _qid_shape_(self):
            return (3,)

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            args.available_buffer[..., 0] = args.target_tensor[..., 2]
            args.available_buffer[..., 1] = args.target_tensor[..., 0]
            args.available_buffer[..., 2] = args.target_tensor[..., 1]
            return args.available_buffer

        def _unitary_(self):
            return np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])

    cirq.testing.assert_has_consistent_apply_unitary(SameQuditEffect())

    # Consistent at power 1, but __pow__ only changes the _apply_unitary_ effect;
    # _unitary_ ignores self.power, so other exponents disagree.
    class BadExponent:
        def __init__(self, power):
            self.power = power

        def __pow__(self, power):
            return BadExponent(self.power * power)

        def _num_qubits_(self):
            return 1

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            i = args.subspace_index(1)
            args.target_tensor[i] *= self.power * 2
            return args.target_tensor

        def _unitary_(self):
            return np.array([[1, 0], [0, 2]])

    cirq.testing.assert_has_consistent_apply_unitary(BadExponent(1))

    with pytest.raises(AssertionError):
        cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
            BadExponent(1), exponents=[1, 2]
        )

    # No _unitary_ to compare against; the checker accepts this.
    class EffectWithoutUnitary:
        def _num_qubits_(self):
            return 1

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return args.target_tensor

    cirq.testing.assert_has_consistent_apply_unitary(EffectWithoutUnitary())

    # _apply_unitary_ opts out via NotImplemented; the checker accepts this.
    class NoEffect:
        def _num_qubits_(self):
            return 1

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            return NotImplemented

    cirq.testing.assert_has_consistent_apply_unitary(NoEffect())

    # Without a qubit count or qid shape the checker cannot build a test tensor.
    class UnknownCountEffect:
        pass

    with pytest.raises(TypeError, match="no _num_qubits_ or _qid_shape_"):
        cirq.testing.assert_has_consistent_apply_unitary(UnknownCountEffect())

    cirq.testing.assert_has_consistent_apply_unitary(cirq.X)

    cirq.testing.assert_has_consistent_apply_unitary(cirq.X.on(cirq.NamedQubit('q')))


def test_assert_has_consistent_qid_shape():
    """assert_has_consistent_qid_shape cross-checks qubits, num_qubits and qid_shape."""

    class ConsistentGate(cirq.Gate):
        def _num_qubits_(self):
            return 4

        def _qid_shape_(self):
            return 1, 2, 3, 4

    # _num_qubits_ disagrees with len(_qid_shape_()).
    class InconsistentGate(cirq.Gate):
        def _num_qubits_(self):
            return 2

        def _qid_shape_(self):
            return 1, 2, 3, 4

    # Contains a zero dimension.
    class BadShapeGate(cirq.Gate):
        def _num_qubits_(self):
            return 4

        def _qid_shape_(self):
            return 1, 2, 0, 4

    class ConsistentOp(cirq.Operation):
        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        @property
        def qubits(self):
            return cirq.LineQubit.range(4)

        def _num_qubits_(self):
            return 4

        def _qid_shape_(self):
            return (1, 2, 3, 4)

    # The 'coverage: ignore' comments in the InconsistentOp classes are needed
    # because cirq.testing.assert_has_consistent_qid_shape may only need to
    # check two of the three methods before finding an inconsistency and
    # throwing an error.
    class InconsistentOp1(cirq.Operation):
        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

        def _num_qubits_(self):
            return 4  # coverage: ignore

        def _qid_shape_(self):
            return (1, 2, 3, 4)  # coverage: ignore

    class InconsistentOp2(cirq.Operation):
        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        @property
        def qubits(self):
            return cirq.LineQubit.range(4)  # coverage: ignore

        def _num_qubits_(self):
            return 2

        def _qid_shape_(self):
            return (1, 2, 3, 4)  # coverage: ignore

    class InconsistentOp3(cirq.Operation):
        def with_qubits(self, *qubits):
            raise NotImplementedError  # coverage: ignore

        @property
        def qubits(self):
            return cirq.LineQubit.range(4)  # coverage: ignore

        def _num_qubits_(self):
            return 4  # coverage: ignore

        def _qid_shape_(self):
            return 1, 2

    class NoProtocol:
        pass

    cirq.testing.assert_has_consistent_qid_shape(ConsistentGate())
    with pytest.raises(AssertionError, match='disagree'):
        cirq.testing.assert_has_consistent_qid_shape(InconsistentGate())
    with pytest.raises(AssertionError, match='positive'):
        cirq.testing.assert_has_consistent_qid_shape(BadShapeGate())
    cirq.testing.assert_has_consistent_qid_shape(ConsistentOp())
    with pytest.raises(AssertionError, match='disagree'):
        cirq.testing.assert_has_consistent_qid_shape(InconsistentOp1())
    with pytest.raises(AssertionError, match='disagree'):
        cirq.testing.assert_has_consistent_qid_shape(InconsistentOp2())
    with pytest.raises(AssertionError, match='disagree'):
        cirq.testing.assert_has_consistent_qid_shape(InconsistentOp3())
    cirq.testing.assert_has_consistent_qid_shape(NoProtocol())


def test_assert_apply_unitary_works_when_axes_transposed_failure():
    """An _apply_unitary_ that may write into a reshaped copy is caught.

    cirq.unitary(bad_op) looks correct, but the transposed-axes check exposes
    that the writes can land in a copy rather than in args.target_tensor.
    """

    class BadOp:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            # Get a more convenient view of the data.
            a, b = args.axes
            rest = list(range(len(args.target_tensor.shape)))
            rest.remove(a)
            rest.remove(b)
            size = args.target_tensor.size
            view = args.target_tensor.transpose([a, b, *rest])
            view = view.reshape((4, size // 4))  # Oops. Reshape might copy.

            # Apply phase gradient.
            view[1, ...] *= 1j
            view[2, ...] *= -1
            view[3, ...] *= -1j
            return args.target_tensor

        def _num_qubits_(self):
            return 2

    bad_op = BadOp()
    assert cirq.has_unitary(bad_op)

    # Appears to work.
    np.testing.assert_allclose(cirq.unitary(bad_op), np.diag([1, 1j, -1, -1j]))
    # But fails the more discerning test.
    with pytest.raises(AssertionError, match='acted differently on out-of-order axes'):
        for _ in range(100):  # Axis orders chosen at random. Brute force a hit.
            _assert_apply_unitary_works_when_axes_transposed(bad_op)