# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for cirq.CircuitOperation."""

from typing import List

import pytest
import sympy

import cirq
from cirq.circuits.circuit_operation import _full_join_string_lists


def test_properties():
    """A freshly built CircuitOperation wraps its circuit with empty mappings."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    op = cirq.CircuitOperation(circuit)
    assert op.circuit is circuit
    assert op.qubits == (a, b, c)
    assert op.qubit_map == {}
    assert op.measurement_key_map == {}
    assert op.param_resolver == cirq.ParamResolver()
    assert op.repetitions == 1
    assert op.repetition_ids is None
    # Despite having the same decomposition, these objects are not equal.
    assert op != circuit
    assert op == circuit.to_op()


def test_circuit_type():
    """Wrapping a mutable cirq.Circuit (not a FrozenCircuit) raises TypeError."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.Circuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    with pytest.raises(TypeError, match='Expected circuit of type FrozenCircuit'):
        _ = cirq.CircuitOperation(circuit)


def test_non_invertible_circuit():
    """Negative repetitions of a non-invertible circuit raise ValueError."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    with pytest.raises(ValueError, match='circuit is not invertible'):
        _ = cirq.CircuitOperation(circuit, repetitions=-2)


def test_repetitions_and_ids_length_mismatch():
    """repetition_ids must contain exactly one id per repetition."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    with pytest.raises(ValueError, match='Expected repetition_ids to be a list of length 2'):
        _ = cirq.CircuitOperation(circuit, repetitions=2, repetition_ids=['a', 'b', 'c'])


def test_is_measurement_memoization():
    """cirq.is_measurement memoizes `_has_measurements` on the wrapped circuit."""
    a = cirq.LineQubit(0)
    circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))
    c_op = cirq.CircuitOperation(circuit)
    assert circuit._has_measurements is None
    # Memoize `_has_measurements` in the circuit.
    assert cirq.is_measurement(c_op)
    assert circuit._has_measurements is True


def test_invalid_measurement_keys():
    """Keys containing ':' are rejected unless they are remapped to valid keys."""
    a = cirq.LineQubit(0)
    circuit = cirq.FrozenCircuit(cirq.measure(a, key='m'))
    c_op = cirq.CircuitOperation(circuit)
    # Invalid key remapping
    with pytest.raises(ValueError, match='Mapping to invalid key: m:a'):
        _ = c_op.with_measurement_key_mapping({'m': 'm:a'})

    # Invalid key remapping nested CircuitOperation
    with pytest.raises(ValueError, match='Mapping to invalid key: m:a'):
        _ = cirq.CircuitOperation(cirq.FrozenCircuit(c_op), measurement_key_map={'m': 'm:a'})

    # Originally invalid key
    with pytest.raises(ValueError, match='Invalid key name: m:a'):
        _ = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(a, key='m:a')))

    # Remapped to valid key
    _ = cirq.CircuitOperation(circuit, measurement_key_map={'m:a': 'ma'})


def test_invalid_qubit_mapping():
    """Qubit mappings that change a qid's dimension raise ValueError."""
    q = cirq.LineQubit(0)
    q3 = cirq.LineQid(1, dimension=3)

    # Invalid qid remapping dict in constructor
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        _ = cirq.CircuitOperation(cirq.FrozenCircuit(), qubit_map={q: q3})

    # Invalid qid remapping dict in with_qubit_mapping call
    c_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        _ = c_op.with_qubit_mapping({q: q3})

    # Invalid qid remapping function in with_qubit_mapping call
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        _ = c_op.with_qubit_mapping(lambda q: q3)


def test_circuit_sharing():
    """Operations built from the same FrozenCircuit share it and hash equally."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    op1 = cirq.CircuitOperation(circuit)
    op2 = cirq.CircuitOperation(op1.circuit)
    op3 = circuit.to_op()
    assert op1.circuit is circuit
    assert op2.circuit is circuit
    assert op3.circuit is circuit

    assert hash(op1) == hash(op2)
    assert hash(op1) == hash(op3)


def test_with_qubits():
    """with_qubits / with_qubit_mapping remap qubits and reject invalid mappings."""
    a, b, c, d = cirq.LineQubit.range(4)
    circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b))
    op_base = cirq.CircuitOperation(circuit)

    op_with_qubits = op_base.with_qubits(d, c)
    assert op_with_qubits.base_operation() == op_base
    assert op_with_qubits.qubits == (d, c)
    assert op_with_qubits.qubit_map == {a: d, b: c}

    assert op_base.with_qubit_mapping({a: d, b: c, d: a}) == op_with_qubits

    def map_fn(qubit: 'cirq.Qid') -> 'cirq.Qid':
        if qubit == a:
            return d
        if qubit == b:
            return c
        return qubit

    fn_op = op_base.with_qubit_mapping(map_fn)
    assert fn_op == op_with_qubits
    # map_fn does not affect qubits c and d.
    assert fn_op.with_qubit_mapping(map_fn) == op_with_qubits

    # with_qubits must receive the same number of qubits as the circuit contains.
    with pytest.raises(ValueError, match='Expected 2 qubits, got 3'):
        _ = op_base.with_qubits(c, d, b)

    # Two qubits cannot be mapped onto the same target qubit.
    with pytest.raises(ValueError, match='Collision in qubit map'):
        _ = op_base.with_qubit_mapping({a: b})

    # Two qubits cannot be transformed into the same target qubit.
    with pytest.raises(ValueError, match='Collision in qubit map'):
        _ = op_base.with_qubit_mapping(lambda q: b)

    # with_qubit_mapping only accepts a function or a dict.
    with pytest.raises(TypeError, match='must be a function or dict'):
        _ = op_base.with_qubit_mapping('bad arg')


def test_with_measurement_keys():
    """with_measurement_key_mapping keeps only relevant keys and rejects collisions."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.measure(b, key='mb'),
        cirq.measure(a, key='ma'),
    )
    op_base = cirq.CircuitOperation(circuit)

    op_with_keys = op_base.with_measurement_key_mapping({'ma': 'pa', 'x': 'z'})
    assert op_with_keys.base_operation() == op_base
    assert op_with_keys.measurement_key_map == {'ma': 'pa'}
    assert cirq.measurement_keys(op_with_keys) == {'pa', 'mb'}

    assert cirq.with_measurement_key_mapping(op_base, {'ma': 'pa'}) == op_with_keys

    # Two measurement keys cannot be mapped onto the same target string.
    with pytest.raises(ValueError):
        _ = op_base.with_measurement_key_mapping({'ma': 'mb'})


def test_with_params():
    """with_params keeps only relevant parameters; recursive resolution is rejected."""
    a = cirq.LineQubit(0)
    z_exp = sympy.Symbol('z_exp')
    x_exp = sympy.Symbol('x_exp')
    delta = sympy.Symbol('delta')
    theta = sympy.Symbol('theta')
    circuit = cirq.FrozenCircuit(cirq.Z(a) ** z_exp, cirq.X(a) ** x_exp, cirq.Z(a) ** delta)
    op_base = cirq.CircuitOperation(circuit)

    param_dict = {
        z_exp: 2,
        x_exp: theta,
        sympy.Symbol('k'): sympy.Symbol('phi'),
    }
    op_with_params = op_base.with_params(param_dict)
    assert op_with_params.base_operation() == op_base
    assert op_with_params.param_resolver == cirq.ParamResolver(
        {
            z_exp: 2,
            x_exp: theta,
            # As 'k' is irrelevant to the circuit, it does not appear here.
        }
    )
    assert cirq.parameter_names(op_with_params) == {'theta', 'delta'}

    assert (
        cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict), recursive=False)
        == op_with_params
    )

    # Recursive parameter resolution is rejected.
    with pytest.raises(ValueError, match='Use "recursive=False"'):
        _ = cirq.resolve_parameters(op_base, cirq.ParamResolver(param_dict))


@pytest.mark.parametrize('add_measurements', [True, False])
@pytest.mark.parametrize('use_default_ids_for_initial_rep', [True, False])
def test_repeat(add_measurements, use_default_ids_for_initial_rep):
    """repeat() multiplies repetition counts and joins repetition ids."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b))
    if add_measurements:
        circuit.append([cirq.measure(b, key='mb'), cirq.measure(a, key='ma')])
    op_base = cirq.CircuitOperation(circuit.freeze())
    assert op_base.repeat(1) is op_base
    assert op_base.repeat(1, ['0']) != op_base
    assert op_base.repeat(1, ['0']) == op_base.repeat(repetition_ids=['0'])
    assert op_base.repeat(1, ['0']) == op_base.with_repetition_ids(['0'])

    initial_repetitions = -3
    if add_measurements:
        with pytest.raises(ValueError, match='circuit is not invertible'):
            _ = op_base.repeat(initial_repetitions)
        initial_repetitions = abs(initial_repetitions)

    # Both branches assign op_with_reps and rep_ids, so no pre-initialisation is needed.
    if use_default_ids_for_initial_rep:
        op_with_reps = op_base.repeat(initial_repetitions)
        rep_ids = ['0', '1', '2']
        assert op_base ** initial_repetitions == op_with_reps
    else:
        rep_ids = ['a', 'b', 'c']
        op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
        assert op_base ** initial_repetitions != op_with_reps
        assert (op_base ** initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
    assert op_with_reps.repetitions == initial_repetitions
    assert op_with_reps.repetition_ids == rep_ids
    assert op_with_reps.repeat(1) is op_with_reps

    final_repetitions = 2 * initial_repetitions

    op_with_consecutive_reps = op_with_reps.repeat(2)
    assert op_with_consecutive_reps.repetitions == final_repetitions
    assert op_with_consecutive_reps.repetition_ids == _full_join_string_lists(['0', '1'], rep_ids)
    assert op_base ** final_repetitions != op_with_consecutive_reps

    op_with_consecutive_reps = op_with_reps.repeat(2, ['a', 'b'])
    assert op_with_reps.repeat(repetition_ids=['a', 'b']) == op_with_consecutive_reps
    assert op_with_consecutive_reps.repetitions == final_repetitions
    assert op_with_consecutive_reps.repetition_ids == _full_join_string_lists(['a', 'b'], rep_ids)

    with pytest.raises(ValueError, match='length to be 2'):
        _ = op_with_reps.repeat(2, ['a', 'b', 'c'])

    with pytest.raises(
        ValueError, match='At least one of repetitions and repetition_ids must be set'
    ):
        _ = op_base.repeat()

    with pytest.raises(TypeError, match='Only integer repetitions are allowed'):
        _ = op_base.repeat(1.3)


def test_qid_shape():
    """qid_shape and num_qubits reflect the qids of the wrapped circuit."""
    circuit = cirq.FrozenCircuit(
        cirq.IdentityGate(qid_shape=(q.dimension,)).on(q)
        for q in cirq.LineQid.for_qid_shape((1, 2, 3, 4))
    )
    op = cirq.CircuitOperation(circuit)
    assert cirq.qid_shape(op) == (1, 2, 3, 4)
    assert cirq.num_qubits(op) == 4

    id_circuit = cirq.FrozenCircuit(cirq.I(q) for q in cirq.LineQubit.range(3))
    id_op = cirq.CircuitOperation(id_circuit)
    assert cirq.qid_shape(id_op) == (2, 2, 2)
    assert cirq.num_qubits(id_op) == 3


def test_string_format():
    """str() and repr() output for a variety of CircuitOperations."""
    x, y, z = cirq.LineQubit.range(3)

    # NOTE(review): the padding spaces inside the bracketed diagram lines below were
    # restored from a whitespace-collapsed copy of this file; confirm them against
    # the actual `str(op)` output.
    fc0 = cirq.FrozenCircuit()
    op0 = cirq.CircuitOperation(fc0)
    assert (
        str(op0)
        == f"""\
{op0.circuit.diagram_name()}:
[                         ]"""
    )

    fc0_global_phase_inner = cirq.FrozenCircuit(
        cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
    )
    op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
    fc0_global_phase_outer = cirq.FrozenCircuit(
        op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
    )
    op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
    assert (
        str(op0_global_phase_outer)
        == f"""\
{op0_global_phase_outer.circuit.diagram_name()}:
[                         ]
[                         ]
[ global phase:   -0.5π   ]"""
    )

    fc1 = cirq.FrozenCircuit(cirq.X(x), cirq.H(y), cirq.CX(y, z), cirq.measure(x, y, z, key='m'))
    op1 = cirq.CircuitOperation(fc1)
    assert (
        str(op1)
        == f"""\
{op1.circuit.diagram_name()}:
[ 0: ───X───────M('m')─── ]
[               │         ]
[ 1: ───H───@───M──────── ]
[           │   │         ]
[ 2: ───────X───M──────── ]"""
    )
    assert (
        repr(op1)
        == """\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.H(cirq.LineQubit(1)),
        ),
        cirq.Moment(
            cirq.CNOT(cirq.LineQubit(1), cirq.LineQubit(2)),
        ),
        cirq.Moment(
            cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), key='m'),
        ),
    ]),
)"""
    )

    fc2 = cirq.FrozenCircuit(cirq.X(x), cirq.H(y), cirq.CX(y, x))
    op2 = cirq.CircuitOperation(
        circuit=fc2,
        qubit_map={y: z},
        repetitions=3,
        parent_path=('outer', 'inner'),
        repetition_ids=['a', 'b', 'c'],
    )
    assert (
        str(op2)
        == f"""\
{op2.circuit.diagram_name()}:
[ 0: ───X───X───          ]
[           │             ]
[ 1: ───H───@───          ](qubit_map={{1: 2}}, parent_path=('outer', 'inner'),\
 repetition_ids=['a', 'b', 'c'])"""
    )
    assert (
        repr(op2)
        == """\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.H(cirq.LineQubit(1)),
        ),
        cirq.Moment(
            cirq.CNOT(cirq.LineQubit(1), cirq.LineQubit(0)),
        ),
    ]),
    repetitions=3,
    qubit_map={cirq.LineQubit(1): cirq.LineQubit(2)},
    parent_path=('outer', 'inner'),
    repetition_ids=['a', 'b', 'c'],
)"""
    )

    fc3 = cirq.FrozenCircuit(
        cirq.X(x) ** sympy.Symbol('b'),
        cirq.measure(x, key='m'),
    )
    op3 = cirq.CircuitOperation(
        circuit=fc3,
        qubit_map={x: y},
        measurement_key_map={'m': 'p'},
        param_resolver={sympy.Symbol('b'): 2},
    )
    # The nested circuit repr is indented one level inside the CircuitOperation repr.
    indented_fc3_repr = repr(fc3).replace('\n', '\n    ')
    assert (
        str(op3)
        == f"""\
{op3.circuit.diagram_name()}:
[ 0: ───X^b───M('m')───   ](qubit_map={{0: 1}}, \
key_map={{m: p}}, params={{b: 2}})"""
    )
    assert (
        repr(op3)
        == f"""\
cirq.CircuitOperation(
    circuit={indented_fc3_repr},
    qubit_map={{cirq.LineQubit(0): cirq.LineQubit(1)}},
    measurement_key_map={{'m': 'p'}},
    param_resolver=cirq.ParamResolver({{sympy.Symbol('b'): 2}}),
)"""
    )

    fc4 = cirq.FrozenCircuit(cirq.X(y))
    op4 = cirq.CircuitOperation(fc4)
    fc5 = cirq.FrozenCircuit(cirq.X(x), op4)
    op5 = cirq.CircuitOperation(fc5)
    assert (
        repr(op5)
        == """\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.CircuitOperation(
                circuit=cirq.FrozenCircuit([
                    cirq.Moment(
                        cirq.X(cirq.LineQubit(1)),
                    ),
                ]),
            ),
        ),
    ]),
)"""
    )


def test_json_dict():
    """_json_dict_ emits every field, with qubit_map as a sorted list of pairs."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    op = cirq.CircuitOperation(
        circuit=circuit,
        qubit_map={c: b, b: c},
        measurement_key_map={'m': 'p'},
        param_resolver={'exp': 'theta'},
        parent_path=('nested', 'path'),
    )

    assert op._json_dict_() == {
        'cirq_type': 'CircuitOperation',
        'circuit': circuit,
        'repetitions': 1,
        'qubit_map': sorted(op.qubit_map.items()),
        'measurement_key_map': op.measurement_key_map,
        'param_resolver': op.param_resolver,
        'parent_path': op.parent_path,
        'repetition_ids': None,
    }


def test_terminal_matches():
    """Measurements inside a CircuitOperation count in terminal-measurement checks."""
    a, b = cirq.LineQubit.range(2)
    fc = cirq.FrozenCircuit(
        cirq.H(a),
        cirq.measure(b, key='m1'),
    )
    op = cirq.CircuitOperation(fc)

    c = cirq.Circuit(cirq.X(a), op)
    assert c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(cirq.X(b), op)
    assert c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(cirq.measure(a), op)
    assert not c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(cirq.measure(b), op)
    assert not c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(op, cirq.X(a))
    assert c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(op, cirq.X(b))
    assert not c.are_all_measurements_terminal()
    assert not c.are_any_measurements_terminal()

    c = cirq.Circuit(op, cirq.measure(a))
    assert c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()

    c = cirq.Circuit(op, cirq.measure(b))
    assert not c.are_all_measurements_terminal()
    assert c.are_any_measurements_terminal()


def test_nonterminal_in_subcircuit():
    """A non-terminal measurement inside a subcircuit is detected, even when tagged."""
    a, b = cirq.LineQubit.range(2)
    fc = cirq.FrozenCircuit(
        cirq.H(a),
        cirq.measure(b, key='m1'),
        cirq.X(b),
    )
    op = cirq.CircuitOperation(fc)
    c = cirq.Circuit(cirq.X(a), op)
    assert isinstance(op, cirq.CircuitOperation)
    assert not c.are_all_measurements_terminal()
    assert not c.are_any_measurements_terminal()

    op = op.with_tags('test')
    c = cirq.Circuit(cirq.X(a), op)
    assert not isinstance(op, cirq.CircuitOperation)
    assert not c.are_all_measurements_terminal()
    assert not c.are_any_measurements_terminal()


def test_decompose_applies_maps():
    """decompose_once applies the qubit, measurement-key and parameter maps."""
    a, b, c = cirq.LineQubit.range(3)
    exp = sympy.Symbol('exp')
    theta = sympy.Symbol('theta')
    circuit = cirq.FrozenCircuit(
        cirq.X(a) ** theta,
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** exp,
        cirq.measure(a, b, c, key='m'),
    )
    op = cirq.CircuitOperation(
        circuit=circuit,
        qubit_map={
            c: b,
            b: c,
        },
        measurement_key_map={'m': 'p'},
        param_resolver={exp: theta, theta: exp},
    )

    expected_circuit = cirq.Circuit(
        cirq.X(a) ** exp,
        cirq.Y(c),
        cirq.H(b),
        cirq.CX(a, c) ** theta,
        cirq.measure(a, c, b, key='p'),
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_loops():
    """Repetitions unroll on decomposition; a negative count unrolls the inverse."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.FrozenCircuit(
        cirq.H(a),
        cirq.CX(a, b),
    )
    base_op = cirq.CircuitOperation(circuit)

    op = base_op.with_qubits(b, a).repeat(3)
    expected_circuit = cirq.Circuit(
        cirq.H(b),
        cirq.CX(b, a),
        cirq.H(b),
        cirq.CX(b, a),
        cirq.H(b),
        cirq.CX(b, a),
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit

    op = base_op.repeat(-2)
    expected_circuit = cirq.Circuit(
        cirq.CX(a, b),
        cirq.H(a),
        cirq.CX(a, b),
        cirq.H(a),
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_loops_with_measurements():
    """Repeated measurements are prefixed with their repetition id on decomposition."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.FrozenCircuit(
        cirq.H(a),
        cirq.CX(a, b),
        cirq.measure(a, b, key='m'),
    )
    base_op = cirq.CircuitOperation(circuit)

    op = base_op.with_qubits(b, a).repeat(3)
    expected_circuit = cirq.Circuit(
        cirq.H(b),
        cirq.CX(b, a),
        cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('0:m')),
        cirq.H(b),
        cirq.CX(b, a),
        cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('1:m')),
        cirq.H(b),
        cirq.CX(b, a),
        cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized('2:m')),
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit


def test_decompose_nested():
    """Nested CircuitOperations decompose with parameters resolved at every level."""
    a, b, c, d = cirq.LineQubit.range(4)
    exp1 = sympy.Symbol('exp1')
    exp_half = sympy.Symbol('exp_half')
    exp_one = sympy.Symbol('exp_one')
    exp_two = sympy.Symbol('exp_two')
    circuit1 = cirq.FrozenCircuit(cirq.X(a) ** exp1, cirq.measure(a, key='m1'))
    op1 = cirq.CircuitOperation(circuit1)
    circuit2 = cirq.FrozenCircuit(
        op1.with_qubits(a).with_measurement_key_mapping({'m1': 'ma'}),
        op1.with_qubits(b).with_measurement_key_mapping({'m1': 'mb'}),
        op1.with_qubits(c).with_measurement_key_mapping({'m1': 'mc'}),
        op1.with_qubits(d).with_measurement_key_mapping({'m1': 'md'}),
    )
    op2 = cirq.CircuitOperation(circuit2)
    circuit3 = cirq.FrozenCircuit(
        op2.with_params({exp1: exp_half}),
        op2.with_params({exp1: exp_one}),
        op2.with_params({exp1: exp_two}),
    )
    op3 = cirq.CircuitOperation(circuit3)

    final_op = op3.with_params({exp_half: 0.5, exp_one: 1.0, exp_two: 2.0})

    expected_circuit1 = cirq.Circuit(
        op2.with_params({exp1: 0.5, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
        op2.with_params({exp1: 1.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
        op2.with_params({exp1: 2.0, exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}),
    )

    result_ops1 = cirq.decompose_once(final_op)
    assert cirq.Circuit(result_ops1) == expected_circuit1

    expected_circuit = cirq.Circuit(
        cirq.X(a) ** 0.5,
        cirq.measure(a, key='ma'),
        cirq.X(b) ** 0.5,
        cirq.measure(b, key='mb'),
        cirq.X(c) ** 0.5,
        cirq.measure(c, key='mc'),
        cirq.X(d) ** 0.5,
        cirq.measure(d, key='md'),
        cirq.X(a) ** 1.0,
        cirq.measure(a, key='ma'),
        cirq.X(b) ** 1.0,
        cirq.measure(b, key='mb'),
        cirq.X(c) ** 1.0,
        cirq.measure(c, key='mc'),
        cirq.X(d) ** 1.0,
        cirq.measure(d, key='md'),
        cirq.X(a) ** 2.0,
        cirq.measure(a, key='ma'),
        cirq.X(b) ** 2.0,
        cirq.measure(b, key='mb'),
        cirq.X(c) ** 2.0,
        cirq.measure(c, key='mc'),
        cirq.X(d) ** 2.0,
        cirq.measure(d, key='md'),
    )
    assert cirq.Circuit(cirq.decompose(final_op)) == expected_circuit
    # Verify that mapped_circuit gives the same operations.
    assert final_op.mapped_circuit(deep=True) == expected_circuit


def test_decompose_repeated_nested_measurements():
    """Nested repetition ids and key maps compose into prefixed measurement keys."""
    # Details of this test described at
    # https://tinyurl.com/measurement-repeated-circuitop#heading=h.sbgxcsyin9wt.
    a = cirq.LineQubit(0)

    op1 = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(a, key='A')))
        .with_measurement_key_mapping({'A': 'B'})
        .repeat(2, ['zero', 'one'])
    )

    op2 = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(a, key='P'), op1))
        .with_measurement_key_mapping({'B': 'C', 'P': 'Q'})
        .repeat(2, ['zero', 'one'])
    )

    op3 = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(a, key='X'), op2))
        .with_measurement_key_mapping({'C': 'D', 'X': 'Y'})
        .repeat(2, ['zero', 'one'])
    )

    expected_measurement_keys_in_order = [
        'zero:Y',
        'zero:zero:Q',
        'zero:zero:zero:D',
        'zero:zero:one:D',
        'zero:one:Q',
        'zero:one:zero:D',
        'zero:one:one:D',
        'one:Y',
        'one:zero:Q',
        'one:zero:zero:D',
        'one:zero:one:D',
        'one:one:Q',
        'one:one:zero:D',
        'one:one:one:D',
    ]
    assert cirq.measurement_keys(op3) == set(expected_measurement_keys_in_order)

    expected_circuit = cirq.Circuit()
    for key in expected_measurement_keys_in_order:
        expected_circuit.append(cirq.measure(a, key=cirq.MeasurementKey.parse_serialized(key)))

    assert cirq.Circuit(cirq.decompose(op3)) == expected_circuit
    assert cirq.measurement_keys(expected_circuit) == set(expected_measurement_keys_in_order)

    # Verify that mapped_circuit gives the same operations.
    assert op3.mapped_circuit(deep=True) == expected_circuit


def test_mapped_circuit_preserves_moments():
    """mapped_circuit keeps the moment structure of the wrapped circuit."""
    q0, q1 = cirq.LineQubit.range(2)
    fc = cirq.FrozenCircuit(cirq.Moment(cirq.X(q0)), cirq.Moment(cirq.X(q1)))
    op = cirq.CircuitOperation(fc)
    assert op.mapped_circuit() == fc
    assert op.repeat(3).mapped_circuit(deep=True) == fc * 3


def test_mapped_op():
    """mapped_op folds all mappings into an equivalent unmapped CircuitOperation."""
    q0, q1 = cirq.LineQubit.range(2)
    a, b = (sympy.Symbol(x) for x in 'ab')
    fc1 = cirq.FrozenCircuit(cirq.X(q0) ** a, cirq.measure(q0, q1, key='m'))
    op1 = (
        cirq.CircuitOperation(fc1)
        .with_params({'a': 'b'})
        .with_qubits(q1, q0)
        .with_measurement_key_mapping({'m': 'k'})
    )
    fc2 = cirq.FrozenCircuit(cirq.X(q1) ** b, cirq.measure(q1, q0, key='k'))
    op2 = cirq.CircuitOperation(fc2)

    assert op1.mapped_op() == op2


def test_tag_propagation():
    """Tags on a CircuitOperation are currently not copied onto its decomposition."""
    # Tags are not propagated from the CircuitOperation to its components.
    # TODO: support tag propagation for better serialization.
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.H(b),
        cirq.H(c),
        cirq.CZ(a, c),
    )
    op = cirq.CircuitOperation(circuit)
    test_tag = 'test_tag'
    op = op.with_tags(test_tag)

    assert test_tag in op.tags

    # TODO: Tags must propagate during decomposition.
    sub_ops = cirq.decompose(op)
    # Use a distinct loop name so the tagged operation under test is not shadowed.
    for sub_op in sub_ops:
        assert test_tag not in sub_op.tags


# TODO: Operation has a "gate" property. What is this for a CircuitOperation?