# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for cirq.Qid, cirq.Gate, cirq.Operation and cirq.TaggedOperation."""

from typing import AbstractSet, Iterator, Any

import pytest
import numpy as np
import sympy

import cirq


class ValidQubit(cirq.Qid):
    """Dimension-2 test Qid identified and ordered by its name."""

    def __init__(self, name):
        self._name = name

    @property
    def dimension(self):
        return 2

    def _comparison_key(self):
        return self._name

    def __repr__(self):
        return f'ValidQubit({self._name!r})'

    def __str__(self):
        return f'TQ_{self._name!s}'


class ValidQid(cirq.Qid):
    """Test Qid with a caller-chosen dimension, validated in the constructor."""

    def __init__(self, name, dimension):
        self._name = name
        self._dimension = dimension
        self.validate_dimension(dimension)

    @property
    def dimension(self):
        return self._dimension

    def with_dimension(self, dimension):
        return ValidQid(self._name, dimension)

    def _comparison_key(self):
        return self._name


def test_wrapped_qid():
    """`with_dimension` on a qubit wraps it; asking for dimension 2 unwraps it."""
    assert type(ValidQubit('a').with_dimension(3)) is not ValidQubit
    assert type(ValidQubit('a').with_dimension(2)) is ValidQubit
    assert type(ValidQubit('a').with_dimension(5).with_dimension(2)) is ValidQubit
    assert ValidQubit('a').with_dimension(3).with_dimension(4) == ValidQubit('a').with_dimension(4)
    assert ValidQubit('a').with_dimension(3).qubit == ValidQubit('a')
    assert ValidQubit('a').with_dimension(3) == ValidQubit('a').with_dimension(3)
    assert ValidQubit('a').with_dimension(3) < ValidQubit('a').with_dimension(4)
    assert ValidQubit('a').with_dimension(3) < ValidQubit('b').with_dimension(3)
    assert ValidQubit('a').with_dimension(4) < ValidQubit('b').with_dimension(3)

    cirq.testing.assert_equivalent_repr(
        ValidQubit('a').with_dimension(3), global_vals={'ValidQubit': ValidQubit}
    )
    assert str(ValidQubit('a').with_dimension(3)) == 'TQ_a (d=3)'

    assert ValidQubit('zz').with_dimension(3)._json_dict_() == {
        'cirq_type': '_QubitAsQid',
        'qubit': ValidQubit('zz'),
        'dimension': 3,
    }


def test_qid_dimension():
    assert ValidQubit('a').dimension == 2
    assert ValidQubit('a').with_dimension(3).dimension == 3
    with pytest.raises(ValueError, match='Wrong qid dimension'):
        _ = ValidQubit('a').with_dimension(0)
    with pytest.raises(ValueError, match='Wrong qid dimension'):
        _ = ValidQubit('a').with_dimension(-3)

    assert ValidQid('a', 3).dimension == 3
    assert ValidQid('a', 3).with_dimension(2).dimension == 2
    assert ValidQid('a', 3).with_dimension(4) == ValidQid('a', 4)
    with pytest.raises(ValueError, match='Wrong qid dimension'):
        _ = ValidQid('a', 3).with_dimension(0)
    with pytest.raises(ValueError, match='Wrong qid dimension'):
        _ = ValidQid('a', 3).with_dimension(-3)


class ValiGate(cirq.Gate):
    """Two-qubit test gate whose validation is skipped for single-qubit calls."""

    def _num_qubits_(self):
        return 2

    def validate_args(self, qubits):
        if len(qubits) == 1:
            return  # Bypass check for some tests
        super().validate_args(qubits)


def test_gate():
    a, b, c = cirq.LineQubit.range(3)

    g = ValiGate()
    assert cirq.num_qubits(g) == 2

    _ = g.on(a, c)
    with pytest.raises(ValueError, match='Wrong number'):
        _ = g.on(a, c, b)

    _ = g(a)  # Bypassing validate_args
    _ = g(a, c)
    with pytest.raises(ValueError, match='Wrong number'):
        _ = g(c, b, a)
    with pytest.raises(ValueError, match='Wrong shape'):
        _ = g(a, b.with_dimension(3))

    assert g.controlled(0) is g


def test_op():
    a, b, c, d = cirq.LineQubit.range(4)
    g = ValiGate()
    op = g(a, b)
    assert op.controlled_by() is op
    controlled_op = op.controlled_by(c, d)
    assert controlled_op.sub_operation == op
    assert controlled_op.controls == (c, d)


def test_op_validate():
    op = cirq.X(cirq.LineQid(0, 2))
    op2 = cirq.CNOT(*cirq.LineQid.range(2, dimension=2))
    op.validate_args([cirq.LineQid(1, 2)])  # Valid
    op2.validate_args(cirq.LineQid.range(1, 3, dimension=2))  # Valid
    with pytest.raises(ValueError, match='Wrong shape'):
        op.validate_args([cirq.LineQid(1, 9)])
    with pytest.raises(ValueError, match='Wrong number'):
        op.validate_args([cirq.LineQid(1, 2), cirq.LineQid(2, 2)])
    with pytest.raises(ValueError, match='Duplicate'):
        op2.validate_args([cirq.LineQid(1, 2), cirq.LineQid(1, 2)])


def test_default_validation_and_inverse():
    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 2

        def _decompose_(self, qubits):
            a, b = qubits
            yield cirq.Z(a)
            yield cirq.S(b)
            yield cirq.X(a)

        def __eq__(self, other):
            return isinstance(other, TestGate)

        def __repr__(self):
            return 'TestGate()'

    a, b = cirq.LineQubit.range(2)

    with pytest.raises(ValueError, match='number of qubits'):
        TestGate().on(a)

    t = TestGate().on(a, b)
    i = t ** -1
    assert i ** -1 == t
    assert t ** -1 == i
    # The inverse decomposes into the inverted operations in reverse order.
    assert cirq.decompose(i) == [cirq.X(a), cirq.S(b) ** -1, cirq.Z(a)]
    cirq.testing.assert_allclose_up_to_global_phase(
        cirq.unitary(i), cirq.unitary(t).conj().T, atol=1e-8
    )

    cirq.testing.assert_implements_consistent_protocols(i, local_vals={'TestGate': TestGate})


def test_default_inverse():
    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 3

        def _decompose_(self, qubits):
            return (cirq.X ** 0.1).on_each(*qubits)

    assert cirq.inverse(TestGate(), None) is not None
    cirq.testing.assert_has_consistent_qid_shape(cirq.inverse(TestGate()))
    cirq.testing.assert_has_consistent_qid_shape(
        cirq.inverse(TestGate().on(*cirq.LineQubit.range(3)))
    )


def test_no_inverse_if_not_unitary():
    class TestGate(cirq.Gate):
        def _num_qubits_(self):
            return 1

        def _decompose_(self, qubits):
            return cirq.amplitude_damp(0.5).on(qubits[0])

    assert cirq.inverse(TestGate(), None) is None


def test_default_qudit_inverse():
    class TestGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3)

        def _decompose_(self, qubits):
            return (cirq.X ** 0.1).on(qubits[1])

    assert cirq.qid_shape(cirq.inverse(TestGate(), None)) == (1, 2, 3)
    cirq.testing.assert_has_consistent_qid_shape(cirq.inverse(TestGate()))


@pytest.mark.parametrize(
    'expression, expected_result',
    (
        (cirq.X * 2, 2 * cirq.X),
        (cirq.Y * 2, cirq.Y + cirq.Y),
        (cirq.Z - cirq.Z + cirq.Z, cirq.Z.wrap_in_linear_combination()),
        (1j * cirq.S * 1j, -cirq.S),
        (cirq.CZ * 1, cirq.CZ / 1),
        (-cirq.CSWAP * 1j, cirq.CSWAP / 1j),
        (cirq.TOFFOLI * 0.5, cirq.TOFFOLI / 2),
    ),
)
def test_gate_algebra(expression, expected_result):
    assert expression == expected_result


def test_gate_shape():
    class ShapeGate(cirq.Gate):
        def _qid_shape_(self):
            return (1, 2, 3, 4)

    class QubitGate(cirq.Gate):
        def _num_qubits_(self):
            return 3

    class DeprecatedGate(cirq.Gate):
        def num_qubits(self):
            return 3

    shape_gate = ShapeGate()
    assert cirq.qid_shape(shape_gate) == (1, 2, 3, 4)
    assert cirq.num_qubits(shape_gate) == 4
    assert shape_gate.num_qubits() == 4

    qubit_gate = QubitGate()
    assert cirq.qid_shape(qubit_gate) == (2, 2, 2)
    assert cirq.num_qubits(qubit_gate) == 3
    assert qubit_gate.num_qubits() == 3

    dep_gate = DeprecatedGate()
    assert cirq.qid_shape(dep_gate) == (2, 2, 2)
    assert cirq.num_qubits(dep_gate) == 3
    assert dep_gate.num_qubits() == 3


def test_gate_shape_protocol():
    """This test is only needed while the `_num_qubits_` and `_qid_shape_`
    methods are implemented as alternatives.  This can be removed once the
    deprecated `num_qubits` method is removed."""

    class NotImplementedGate1(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

        def _qid_shape_(self):
            return NotImplemented

    class NotImplementedGate2(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

    class NotImplementedGate3(cirq.Gate):
        def _qid_shape_(self):
            return NotImplemented

    class ShapeGate(cirq.Gate):
        def _num_qubits_(self):
            return NotImplemented

        def _qid_shape_(self):
            return (1, 2, 3)

    class QubitGate(cirq.Gate):
        def _num_qubits_(self):
            return 2

        def _qid_shape_(self):
            return NotImplemented

    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.qid_shape(NotImplementedGate1())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.num_qubits(NotImplementedGate1())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        _ = NotImplementedGate1().num_qubits()  # Deprecated
    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.qid_shape(NotImplementedGate2())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.num_qubits(NotImplementedGate2())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        _ = NotImplementedGate2().num_qubits()  # Deprecated
    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.qid_shape(NotImplementedGate3())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        cirq.num_qubits(NotImplementedGate3())
    with pytest.raises(TypeError, match='returned NotImplemented'):
        _ = NotImplementedGate3().num_qubits()  # Deprecated
    assert cirq.qid_shape(ShapeGate()) == (1, 2, 3)
    assert cirq.num_qubits(ShapeGate()) == 3
    assert ShapeGate().num_qubits() == 3  # Deprecated
    assert cirq.qid_shape(QubitGate()) == (2, 2)
    assert cirq.num_qubits(QubitGate()) == 2
    assert QubitGate().num_qubits() == 2  # Deprecated


def test_operation_shape():
    class FixedQids(cirq.Operation):
        def with_qubits(self, *new_qids):
            raise NotImplementedError  # coverage: ignore

    class QubitOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

    class NumQubitOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(3)

        def _num_qubits_(self):
            return 3

    class ShapeOp(FixedQids):
        @property
        def qubits(self):
            return cirq.LineQubit.range(4)

        def _qid_shape_(self):
            return (1, 2, 3, 4)

    qubit_op = QubitOp()
    assert len(qubit_op.qubits) == 2
    assert cirq.qid_shape(qubit_op) == (2, 2)
    assert cirq.num_qubits(qubit_op) == 2

    num_qubit_op = NumQubitOp()
    assert len(num_qubit_op.qubits) == 3
    assert cirq.qid_shape(num_qubit_op) == (2, 2, 2)
    assert cirq.num_qubits(num_qubit_op) == 3

    shape_op = ShapeOp()
    assert len(shape_op.qubits) == 4
    assert cirq.qid_shape(shape_op) == (1, 2, 3, 4)
    assert cirq.num_qubits(shape_op) == 4


def test_gate_json_dict():
    g = cirq.CSWAP  # not an eigen gate (which has its own _json_dict_)
    assert g._json_dict_() == {
        'cirq_type': 'CSwapGate',
    }


def test_inverse_composite_diagram_info():
    class Gate(cirq.Gate):
        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

        def num_qubits(self) -> int:
            return 1

    c = cirq.inverse(Gate())
    assert cirq.circuit_diagram_info(c, default=None) is None

    class Gate2(cirq.Gate):
        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

        def num_qubits(self) -> int:
            return 1

        def _circuit_diagram_info_(self, args):
            return 's!'

    c = cirq.inverse(Gate2())
    assert cirq.circuit_diagram_info(c) == cirq.CircuitDiagramInfo(
        wire_symbols=('s!',), exponent=-1
    )


def test_tagged_operation_equality():
    eq = cirq.testing.EqualsTester()
    q1 = cirq.GridQubit(1, 1)
    op = cirq.X(q1)
    op2 = cirq.Y(q1)

    eq.add_equality_group(op)
    eq.add_equality_group(op.with_tags('tag1'), cirq.TaggedOperation(op, 'tag1'))
    eq.add_equality_group(op2.with_tags('tag1'), cirq.TaggedOperation(op2, 'tag1'))
    eq.add_equality_group(op.with_tags('tag2'), cirq.TaggedOperation(op, 'tag2'))
    eq.add_equality_group(
        op.with_tags('tag1', 'tag2'),
        op.with_tags('tag1').with_tags('tag2'),
        cirq.TaggedOperation(op, 'tag1', 'tag2'),
    )


def test_tagged_operation():
    q1 = cirq.GridQubit(1, 1)
    q2 = cirq.GridQubit(2, 2)
    op = cirq.X(q1).with_tags('tag1')
    op_repr = "cirq.X(cirq.GridQubit(1, 1))"
    assert repr(op) == f"cirq.TaggedOperation({op_repr}, 'tag1')"

    assert op.qubits == (q1,)
    assert op.tags == ('tag1',)
    assert op.gate == cirq.X
    assert op.with_qubits(q2) == cirq.X(q2).with_tags('tag1')
    assert op.with_qubits(q2).qubits == (q2,)
    assert not cirq.is_measurement(op)


def test_with_tags_returns_same_instance_if_possible():
    untagged = cirq.X(cirq.GridQubit(1, 1))
    assert untagged.with_tags() is untagged

    tagged = untagged.with_tags('foo')
    assert tagged.with_tags() is tagged


def test_tagged_measurement():
    assert not cirq.is_measurement(cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0'))

    a = cirq.LineQubit(0)
    op = cirq.measure(a, key='m').with_tags('tag')
    assert cirq.is_measurement(op)

    remap_op = cirq.with_measurement_key_mapping(op, {'m': 'k'})
    assert remap_op.tags == ('tag',)
    assert cirq.is_measurement(remap_op)
    assert cirq.measurement_key_names(remap_op) == {'k'}
    # A mapping that does not mention the op's key leaves it unchanged.
    assert cirq.with_measurement_key_mapping(op, {'x': 'k'}) == op


def test_cannot_remap_non_measurement_gate():
    a = cirq.LineQubit(0)
    op = cirq.X(a).with_tags('tag')

    assert cirq.with_measurement_key_mapping(op, {'m': 'k'}) is NotImplemented


def test_circuit_diagram():
    class TaggyTag:
        """Tag with a custom repr function to test circuit diagrams."""

        def __repr__(self):
            return 'TaggyTag()'

    h = cirq.H(cirq.GridQubit(1, 1))
    tagged_h = h.with_tags('tag1')
    non_string_tag_h = h.with_tags(TaggyTag())

    expected = cirq.CircuitDiagramInfo(
        wire_symbols=("H['tag1']",),
        exponent=1.0,
        connected=True,
        exponent_qubit_index=None,
        auto_exponent_parens=True,
    )
    # The final positional argument is include_tags=False, so the tags are dropped.
    args = cirq.CircuitDiagramInfoArgs(None, None, None, None, None, False)
    assert cirq.circuit_diagram_info(tagged_h) == expected
    assert cirq.circuit_diagram_info(tagged_h, args) == cirq.circuit_diagram_info(h)

    c = cirq.Circuit(tagged_h)
    diagram_with_tags = "(1, 1): ───H['tag1']───"
    diagram_without_tags = "(1, 1): ───H───"
    assert str(c) == diagram_with_tags
    assert c.to_text_diagram() == diagram_with_tags
    assert c.to_text_diagram(include_tags=False) == diagram_without_tags

    c = cirq.Circuit(non_string_tag_h)
    diagram_with_non_string_tag = "(1, 1): ───H[TaggyTag()]───"
    assert c.to_text_diagram() == diagram_with_non_string_tag
    assert c.to_text_diagram(include_tags=False) == diagram_without_tags


def test_circuit_diagram_tagged_global_phase():
    # Tests global phase operation
    q = cirq.NamedQubit('a')
    global_phase = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')

    # Just global phase in a circuit
    assert cirq.circuit_diagram_info(global_phase, default='default') == 'default'
    cirq.testing.assert_has_diagram(
        cirq.Circuit(global_phase), "\n\nglobal phase:   π['tag0']", use_unicode_characters=True
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(global_phase),
        "\n\nglobal phase:   π",
        use_unicode_characters=True,
        include_tags=False,
    )

    expected = cirq.CircuitDiagramInfo(
        wire_symbols=(),
        exponent=1.0,
        connected=True,
        exponent_qubit_index=None,
        auto_exponent_parens=True,
    )

    # Operation with no qubits and returns diagram info with no wire symbols
    class NoWireSymbols(cirq.GlobalPhaseOperation):
        def _circuit_diagram_info_(
            self, args: 'cirq.CircuitDiagramInfoArgs'
        ) -> 'cirq.CircuitDiagramInfo':
            return expected

    no_wire_symbol_op = NoWireSymbols(coefficient=-1.0).with_tags('tag0')
    assert cirq.circuit_diagram_info(no_wire_symbol_op, default='default') == expected
    cirq.testing.assert_has_diagram(
        cirq.Circuit(no_wire_symbol_op),
        "\n\nglobal phase:   π['tag0']",
        use_unicode_characters=True,
    )

    # Two global phases in one moment
    tag1 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag1')
    tag2 = cirq.GlobalPhaseOperation(coefficient=1j).with_tags('tag2')
    c = cirq.Circuit([cirq.X(q), tag1, tag2])
    cirq.testing.assert_has_diagram(
        c,
        """\
a: ─────────────X───────────────────

global phase:   π['tag1', 'tag2']""",
        use_unicode_characters=True,
        precision=2,
    )

    # Two moments with global phase, one with another tagged gate
    c = cirq.Circuit([cirq.X(q).with_tags('x_tag'), tag1])
    c.append(cirq.Moment([cirq.X(q), tag2]))
    cirq.testing.assert_has_diagram(
        c,
        """\
a: ─────────────X['x_tag']─────X──────────────

global phase:   0.5π['tag1']   0.5π['tag2']
""",
        use_unicode_characters=True,
        include_tags=True,
    )


def test_circuit_diagram_no_circuit_diagram():
    class NoCircuitDiagram(cirq.Gate):
        def num_qubits(self) -> int:
            return 1

        def __repr__(self):
            return 'guess-i-will-repr'

    q = cirq.GridQubit(1, 1)
    expected = "(1, 1): ───guess-i-will-repr───"
    assert cirq.Circuit(NoCircuitDiagram()(q)).to_text_diagram() == expected
    expected = "(1, 1): ───guess-i-will-repr['taggy']───"
    assert cirq.Circuit(NoCircuitDiagram()(q).with_tags('taggy')).to_text_diagram() == expected


def test_tagged_operation_forwards_protocols():
    """The results of all protocols applied to an operation with a tag should
    be equivalent to the result without tags.
    """
    q1 = cirq.GridQubit(1, 1)
    q2 = cirq.GridQubit(1, 2)
    h = cirq.H(q1)
    tag = 'tag1'
    tagged_h = cirq.H(q1).with_tags(tag)

    np.testing.assert_equal(cirq.unitary(tagged_h), cirq.unitary(h))
    assert cirq.has_unitary(tagged_h)
    assert cirq.decompose(tagged_h) == cirq.decompose(h)
    assert cirq.pauli_expansion(tagged_h) == cirq.pauli_expansion(h)
    assert cirq.equal_up_to_global_phase(h, tagged_h)
    assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

    assert cirq.measurement_key_name(cirq.measure(q1, key='blah').with_tags(tag)) == 'blah'
    assert cirq.measurement_key_obj(
        cirq.measure(q1, key='blah').with_tags(tag)
    ) == cirq.MeasurementKey('blah')

    parameterized_op = cirq.XPowGate(exponent=sympy.Symbol('t'))(q1).with_tags(tag)
    assert cirq.is_parameterized(parameterized_op)
    resolver = cirq.study.ParamResolver({'t': 0.25})
    assert cirq.resolve_parameters(parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(
        q1
    ).with_tags(tag)
    assert cirq.resolve_parameters_once(parameterized_op, resolver) == cirq.XPowGate(exponent=0.25)(
        q1
    ).with_tags(tag)

    y = cirq.Y(q1)
    tagged_y = cirq.Y(q1).with_tags(tag)
    assert tagged_y ** 0.5 == cirq.YPowGate(exponent=0.5)(q1)
    assert tagged_y * 2 == (y * 2)
    assert 3 * tagged_y == (3 * y)
    assert cirq.phase_by(y, 0.125, 0) == cirq.phase_by(tagged_y, 0.125, 0)
    controlled_y = tagged_y.controlled_by(q2)
    assert controlled_y.qubits == (
        q2,
        q1,
    )
    assert isinstance(controlled_y, cirq.Operation)
    assert not isinstance(controlled_y, cirq.TaggedOperation)

    clifford_x = cirq.SingleQubitCliffordGate.X(q1)
    tagged_x = cirq.SingleQubitCliffordGate.X(q1).with_tags(tag)
    assert cirq.commutes(clifford_x, clifford_x)
    assert cirq.commutes(tagged_x, clifford_x)
    assert cirq.commutes(clifford_x, tagged_x)
    assert cirq.commutes(tagged_x, tagged_x)

    assert cirq.trace_distance_bound(y ** 0.001) == cirq.trace_distance_bound(
        (y ** 0.001).with_tags(tag)
    )

    flip = cirq.bit_flip(0.5)(q1)
    tagged_flip = cirq.bit_flip(0.5)(q1).with_tags(tag)
    assert cirq.has_mixture(tagged_flip)
    assert cirq.has_kraus(tagged_flip)

    flip_mixture = cirq.mixture(flip)
    tagged_mixture = cirq.mixture(tagged_flip)
    assert len(tagged_mixture) == 2
    assert len(tagged_mixture[0]) == 2
    assert len(tagged_mixture[1]) == 2
    assert tagged_mixture[0][0] == flip_mixture[0][0]
    assert np.isclose(tagged_mixture[0][1], flip_mixture[0][1]).all()
    assert tagged_mixture[1][0] == flip_mixture[1][0]
    assert np.isclose(tagged_mixture[1][1], flip_mixture[1][1]).all()

    qubit_map = {q1: 'q1'}
    qasm_args = cirq.QasmArgs(qubit_id_map=qubit_map)
    assert cirq.qasm(h, args=qasm_args) == cirq.qasm(tagged_h, args=qasm_args)

    cirq.testing.assert_has_consistent_apply_unitary(tagged_h)


class ParameterizableTag:
    """Tag wrapping a possibly-symbolic value, implementing the parameter protocols."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Return NotImplemented for other types so that comparing against an
        # unrelated tag (e.g. a str) yields False instead of an AttributeError.
        if not isinstance(other, ParameterizableTag):
            return NotImplemented
        return self.value == other.value

    def _is_parameterized_(self) -> bool:
        return cirq.is_parameterized(self.value)

    def _parameter_names_(self) -> AbstractSet[str]:
        return cirq.parameter_names(self.value)

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'ParameterizableTag':
        return ParameterizableTag(cirq.resolve_parameters(self.value, resolver, recursive))


@pytest.mark.parametrize('resolve_fn', [cirq.resolve_parameters, cirq.resolve_parameters_once])
def test_tagged_operation_resolves_parameterized_tags(resolve_fn):
    q = cirq.GridQubit(0, 0)
    tag = ParameterizableTag(sympy.Symbol('t'))
    assert cirq.is_parameterized(tag)
    assert cirq.parameter_names(tag) == {'t'}
    op = cirq.Z(q).with_tags(tag)
    assert cirq.is_parameterized(op)
    assert cirq.parameter_names(op) == {'t'}
    resolved_op = resolve_fn(op, {'t': 10})
    assert resolved_op == cirq.Z(q).with_tags(ParameterizableTag(10))
    assert not cirq.is_parameterized(resolved_op)
    assert cirq.parameter_names(resolved_op) == set()


def test_inverse_composite_standards():
    @cirq.value_equality
    class Gate(cirq.Gate):
        def _decompose_(self, qubits):
            return cirq.S.on(qubits[0])

        def num_qubits(self) -> int:
            return 1

        def _has_unitary_(self):
            return True

        def _value_equality_values_(self):
            return ()

        def __repr__(self):
            return 'C()'

    cirq.testing.assert_implements_consistent_protocols(
        cirq.inverse(Gate()), global_vals={'C': Gate}
    )


def test_tagged_act_on():
    class YesActOn(cirq.Gate):
        def _num_qubits_(self) -> int:
            return 1

        def _act_on_(self, args, qubits):
            return True

    class NoActOn(cirq.Gate):
        def _num_qubits_(self) -> int:
            return 1

        def _act_on_(self, args, qubits):
            return NotImplemented

    class MissingActOn(cirq.Operation):
        def with_qubits(self, *new_qubits):
            raise NotImplementedError()

        @property
        def qubits(self):
            pass

    q = cirq.LineQubit(1)
    from cirq.protocols.act_on_protocol_test import DummyActOnArgs

    args = DummyActOnArgs()
    cirq.act_on(YesActOn()(q).with_tags("test"), args)
    with pytest.raises(TypeError, match="Failed to act"):
        cirq.act_on(NoActOn()(q).with_tags("test"), args)
    with pytest.raises(TypeError, match="Failed to act"):
        cirq.act_on(MissingActOn().with_tags("test"), args)


def test_single_qubit_gate_validates_on_each():
    class Dummy(cirq.SingleQubitGate):
        def matrix(self):
            pass

    g = Dummy()
    assert g.num_qubits() == 1

    test_qubits = [cirq.NamedQubit(str(i)) for i in range(3)]

    _ = g.on_each(*test_qubits)
    _ = g.on_each(test_qubits)

    # Non-Qid targets must be rejected whether passed as varargs or as a list.
    test_non_qubits = [str(i) for i in range(3)]
    with pytest.raises(ValueError):
        _ = g.on_each(*test_non_qubits)
    with pytest.raises(ValueError):
        _ = g.on_each(test_non_qubits)


def test_on_each():
    class CustomGate(cirq.SingleQubitGate):
        pass

    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    c = CustomGate()

    assert c.on_each() == []
    assert c.on_each(a) == [c(a)]
    assert c.on_each(a, b) == [c(a), c(b)]
    assert c.on_each(b, a) == [c(b), c(a)]

    assert c.on_each([]) == []
    assert c.on_each([a]) == [c(a)]
    assert c.on_each([a, b]) == [c(a), c(b)]
    assert c.on_each([b, a]) == [c(b), c(a)]
    assert c.on_each([a, [b, a], b]) == [c(a), c(b), c(a), c(b)]

    with pytest.raises(ValueError):
        c.on_each('abcd')
    with pytest.raises(ValueError):
        c.on_each(['abcd'])
    with pytest.raises(ValueError):
        c.on_each([a, 'abcd'])

    qubit_iterator = (q for q in [a, b, a, b])
    assert isinstance(qubit_iterator, Iterator)
    assert c.on_each(qubit_iterator) == [c(a), c(b), c(a), c(b)]


def test_on_each_two_qubits():
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    g = cirq.testing.TwoQubitGate()

    assert g.on_each([]) == []
    assert g.on_each([(a, b)]) == [g(a, b)]
    assert g.on_each([[a, b]]) == [g(a, b)]
    assert g.on_each([(b, a)]) == [g(b, a)]
    assert g.on_each([(a, b), (b, a)]) == [g(a, b), g(b, a)]
    assert g.on_each(zip([a, b], [b, a])) == [g(a, b), g(b, a)]
    assert g.on_each() == []
    assert g.on_each((b, a)) == [g(b, a)]
    assert g.on_each((a, b), (a, b)) == [g(a, b), g(a, b)]
    assert g.on_each(*zip([a, b], [b, a])) == [g(a, b), g(b, a)]
    with pytest.raises(TypeError, match='object is not iterable'):
        g.on_each(a)
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each(a, b)
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each([12])
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each([(a, b), 12])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, b), [(a, b)]])
    with pytest.raises(ValueError, match='Expected 2 qubits'):
        g.on_each([()])
    with pytest.raises(ValueError, match='Expected 2 qubits'):
        g.on_each([(a,)])
    with pytest.raises(ValueError, match='Expected 2 qubits'):
        g.on_each([(a, b, a)])
    with pytest.raises(ValueError, match='Expected 2 qubits'):
        g.on_each(zip([a, a]))
    with pytest.raises(ValueError, match='Expected 2 qubits'):
        g.on_each(zip([a, a], [b, b], [a, a]))
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each('ab')
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each(('ab',))
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([('ab',)])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, 'ab')])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, 'b')])

    qubit_iterator = (qs for qs in [[a, b], [a, b]])
    assert isinstance(qubit_iterator, Iterator)
    assert g.on_each(qubit_iterator) == [g(a, b), g(a, b)]


def test_on_each_three_qubits():
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    c = cirq.NamedQubit('c')
    g = cirq.testing.ThreeQubitGate()

    assert g.on_each([]) == []
    assert g.on_each([(a, b, c)]) == [g(a, b, c)]
    assert g.on_each([[a, b, c]]) == [g(a, b, c)]
    assert g.on_each([(c, b, a)]) == [g(c, b, a)]
    assert g.on_each([(a, b, c), (c, b, a)]) == [g(a, b, c), g(c, b, a)]
    assert g.on_each(zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
    assert g.on_each() == []
    assert g.on_each((c, b, a)) == [g(c, b, a)]
    assert g.on_each((a, b, c), (c, b, a)) == [g(a, b, c), g(c, b, a)]
    assert g.on_each(*zip([a, c], [b, b], [c, a])) == [g(a, b, c), g(c, b, a)]
    with pytest.raises(TypeError, match='object is not iterable'):
        g.on_each(a)
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each(a, b, c)
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each([12])
    with pytest.raises(ValueError, match='Inputs to multi-qubit gates must be Sequence'):
        g.on_each([(a, b, c), 12])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, b, c), [(a, b, c)]])
    with pytest.raises(ValueError, match='Expected 3 qubits'):
        g.on_each([(a,)])
    with pytest.raises(ValueError, match='Expected 3 qubits'):
        g.on_each([(a, b)])
    with pytest.raises(ValueError, match='Expected 3 qubits'):
        g.on_each([(a, b, c, a)])
    with pytest.raises(ValueError, match='Expected 3 qubits'):
        g.on_each(zip([a, a], [b, b]))
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each('abc')
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each(('abc',))
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([('abc',)])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, 'abc')])
    with pytest.raises(ValueError, match='All values in sequence should be Qids'):
        g.on_each([(a, 'bc')])

    qubit_iterator = (qs for qs in [[a, b, c], [a, b, c]])
    assert isinstance(qubit_iterator, Iterator)
    assert g.on_each(qubit_iterator) == [g(a, b, c), g(a, b, c)]


def test_on_each_iterable_qid():
    """A Qid that is also iterable must be treated as a single Qid, not unpacked."""

    class QidIter(cirq.Qid):
        @property
        def dimension(self) -> int:
            return 2

        def _comparison_key(self) -> Any:
            return 1

        def __iter__(self):
            raise NotImplementedError()

    assert cirq.H.on_each(QidIter())[0] == cirq.H.on(QidIter())