1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import List
16
17import pytest, sympy
18
19import cirq
20from cirq.circuits.circuit_operation import _full_join_string_lists
21
22
def test_properties():
    """Check the default attribute values of a freshly built CircuitOperation."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    frozen = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.Y(q1),
        cirq.H(q2),
        cirq.CX(q0, q1) ** sympy.Symbol('exp'),
        cirq.measure(q0, q1, q2, key='m'),
    )
    circuit_op = cirq.CircuitOperation(frozen)

    # The wrapped circuit is held by reference, not copied.
    assert circuit_op.circuit is frozen
    assert circuit_op.qubits == (q0, q1, q2)

    # With no optional arguments every map is empty and nothing is repeated.
    assert circuit_op.qubit_map == {}
    assert circuit_op.measurement_key_map == {}
    assert circuit_op.param_resolver == cirq.ParamResolver()
    assert circuit_op.repetitions == 1
    assert circuit_op.repetition_ids is None

    # The operation is not equal to the bare circuit it wraps, even though both
    # decompose identically; it does equal the circuit's own `to_op()`.
    assert circuit_op != frozen
    assert circuit_op == frozen.to_op()
43
44
def test_circuit_type():
    """A mutable cirq.Circuit is rejected; CircuitOperation needs a FrozenCircuit."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    mutable_circuit = cirq.Circuit(
        cirq.X(q0),
        cirq.Y(q1),
        cirq.H(q2),
        cirq.CX(q0, q1) ** sympy.Symbol('exp'),
        cirq.measure(q0, q1, q2, key='m'),
    )
    with pytest.raises(TypeError, match='Expected circuit of type FrozenCircuit'):
        cirq.CircuitOperation(mutable_circuit)
56
57
def test_non_invertible_circuit():
    """Negative repetitions are rejected for a circuit that contains a measurement."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    measured = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.Y(q1),
        cirq.H(q2),
        cirq.CX(q0, q1) ** sympy.Symbol('exp'),
        cirq.measure(q0, q1, q2, key='m'),
    )
    # A negative count asks for repeated inverses, which a measurement blocks.
    with pytest.raises(ValueError, match='circuit is not invertible'):
        cirq.CircuitOperation(measured, repetitions=-2)
69
70
def test_repetitions_and_ids_length_mismatch():
    """repetition_ids must provide exactly one id per repetition."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    frozen = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.Y(q1),
        cirq.H(q2),
        cirq.CX(q0, q1) ** sympy.Symbol('exp'),
        cirq.measure(q0, q1, q2, key='m'),
    )
    # Three ids for two repetitions is a length mismatch.
    with pytest.raises(ValueError, match='Expected repetition_ids to be a list of length 2'):
        cirq.CircuitOperation(frozen, repetitions=2, repetition_ids=['a', 'b', 'c'])
82
83
def test_is_measurement_memoization():
    """cirq.is_measurement on the operation caches its answer on the inner circuit."""
    qubit = cirq.LineQubit(0)
    frozen = cirq.FrozenCircuit(cirq.measure(qubit, key='m'))
    circuit_op = cirq.CircuitOperation(frozen)
    # Nothing has been computed yet, so the cached slot is still empty.
    assert frozen._has_measurements is None
    # Querying the operation fills in the circuit's `_has_measurements`.
    assert cirq.is_measurement(circuit_op)
    assert frozen._has_measurements is True
92
93
def test_invalid_measurement_keys():
    """Measurement keys containing ':' are rejected, whether original or mapped."""
    qubit = cirq.LineQubit(0)
    frozen = cirq.FrozenCircuit(cirq.measure(qubit, key='m'))
    circuit_op = cirq.CircuitOperation(frozen)

    # Remapping an existing key onto an invalid name fails...
    with pytest.raises(ValueError, match='Mapping to invalid key: m:a'):
        circuit_op.with_measurement_key_mapping({'m': 'm:a'})

    # ...and so does the same remapping applied around a nested CircuitOperation.
    with pytest.raises(ValueError, match='Mapping to invalid key: m:a'):
        cirq.CircuitOperation(cirq.FrozenCircuit(circuit_op), measurement_key_map={'m': 'm:a'})

    # Wrapping a measurement whose own key is invalid fails.
    with pytest.raises(ValueError, match='Invalid key name: m:a'):
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(qubit, key='m:a')))

    # An invalid name on the *source* side of the map is accepted.
    # NOTE(review): `frozen` only measures key 'm', so this entry never applies;
    # the check only shows that the source key itself is not validated.
    cirq.CircuitOperation(frozen, measurement_key_map={'m:a': 'ma'})
112
113
def test_invalid_qubit_mapping():
    """Mapping a qubit onto a qid of a different dimension raises ValueError."""
    q = cirq.LineQubit(0)
    q3 = cirq.LineQid(1, dimension=3)

    # Invalid qid remapping dict in the constructor (even for an empty circuit).
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        cirq.CircuitOperation(cirq.FrozenCircuit(), qubit_map={q: q3})

    # Invalid qid remapping dict in a with_qubit_mapping call.
    c_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(q)))
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        c_op.with_qubit_mapping({q: q3})

    # Invalid qid remapping function in a with_qubit_mapping call. The function
    # ignores its argument; naming it `_` avoids shadowing the outer `q`.
    with pytest.raises(ValueError, match='Qid dimension conflict'):
        c_op.with_qubit_mapping(lambda _: q3)
130
131
def test_circuit_sharing():
    """Operations built from one FrozenCircuit share it by reference and hash alike."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    shared = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.Y(q1),
        cirq.H(q2),
        cirq.CX(q0, q1) ** sympy.Symbol('exp'),
        cirq.measure(q0, q1, q2, key='m'),
    )
    direct = cirq.CircuitOperation(shared)
    rewrapped = cirq.CircuitOperation(direct.circuit)
    via_to_op = shared.to_op()

    # None of the three construction paths copies the circuit.
    for candidate in (direct, rewrapped, via_to_op):
        assert candidate.circuit is shared

    assert hash(direct) == hash(rewrapped)
    assert hash(direct) == hash(via_to_op)
150
151
def test_with_qubits():
    """Remap qubits via with_qubits / with_qubit_mapping, and check the error cases."""
    a, b, c, d = cirq.LineQubit.range(4)
    circuit = cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b))
    op_base = cirq.CircuitOperation(circuit)

    op_with_qubits = op_base.with_qubits(d, c)
    assert op_with_qubits.base_operation() == op_base
    assert op_with_qubits.qubits == (d, c)
    assert op_with_qubits.qubit_map == {a: d, b: c}

    # The extra entry for `d` (a qubit the circuit does not use) has no effect.
    assert op_base.with_qubit_mapping({a: d, b: c, d: a}) == op_with_qubits

    def map_fn(qubit: 'cirq.Qid') -> 'cirq.Qid':
        if qubit == a:
            return d
        if qubit == b:
            return c
        return qubit

    fn_op = op_base.with_qubit_mapping(map_fn)
    assert fn_op == op_with_qubits
    # map_fn does not affect qubits c and d, so applying it again is a no-op.
    assert fn_op.with_qubit_mapping(map_fn) == op_with_qubits

    # with_qubits must receive the same number of qubits as the circuit contains.
    with pytest.raises(ValueError, match='Expected 2 qubits, got 3'):
        _ = op_base.with_qubits(c, d, b)

    # Two qubits cannot be mapped onto the same target qubit.
    with pytest.raises(ValueError, match='Collision in qubit map'):
        _ = op_base.with_qubit_mapping({a: b})

    # Two qubits cannot be transformed into the same target qubit.
    with pytest.raises(ValueError, match='Collision in qubit map'):
        _ = op_base.with_qubit_mapping(lambda _qubit: b)

    # with_qubit_mapping only accepts a dict or a callable.
    with pytest.raises(TypeError, match='must be a function or dict'):
        _ = op_base.with_qubit_mapping('bad arg')
190
191
def test_with_measurement_keys():
    """Key remapping keeps only relevant entries and rejects key collisions."""
    qa, qb = cirq.LineQubit.range(2)
    frozen = cirq.FrozenCircuit(
        cirq.X(qa),
        cirq.measure(qb, key='mb'),
        cirq.measure(qa, key='ma'),
    )
    base = cirq.CircuitOperation(frozen)

    # 'x' is not a key in the circuit, so its entry is dropped from the map.
    remapped = base.with_measurement_key_mapping({'ma': 'pa', 'x': 'z'})
    assert remapped.base_operation() == base
    assert remapped.measurement_key_map == {'ma': 'pa'}
    assert cirq.measurement_keys(remapped) == {'pa', 'mb'}

    # The protocol-level helper produces the same operation.
    assert cirq.with_measurement_key_mapping(base, {'ma': 'pa'}) == remapped

    # Renaming 'ma' onto the already-used 'mb' would merge two keys.
    with pytest.raises(ValueError):
        base.with_measurement_key_mapping({'ma': 'mb'})
211
212
def test_with_params():
    """with_params keeps only relevant symbols and forbids recursive resolution."""
    qubit = cirq.LineQubit(0)
    z_exp, x_exp, delta, theta = sympy.symbols('z_exp x_exp delta theta')
    frozen = cirq.FrozenCircuit(
        cirq.Z(qubit) ** z_exp, cirq.X(qubit) ** x_exp, cirq.Z(qubit) ** delta
    )
    base = cirq.CircuitOperation(frozen)

    bindings = {
        z_exp: 2,
        x_exp: theta,
        sympy.Symbol('k'): sympy.Symbol('phi'),
    }
    bound = base.with_params(bindings)
    assert bound.base_operation() == base
    # 'k' does not occur in the circuit, so it is filtered out of the resolver.
    assert bound.param_resolver == cirq.ParamResolver({z_exp: 2, x_exp: theta})
    # z_exp became a number, x_exp was renamed to theta, delta is untouched.
    assert cirq.parameter_names(bound) == {'theta', 'delta'}

    non_recursive = cirq.resolve_parameters(base, cirq.ParamResolver(bindings), recursive=False)
    assert non_recursive == bound

    # Recursive parameter resolution is rejected.
    with pytest.raises(ValueError, match='Use "recursive=False"'):
        cirq.resolve_parameters(base, cirq.ParamResolver(bindings))
246
247
@pytest.mark.parametrize('add_measurements', [True, False])
@pytest.mark.parametrize('use_default_ids_for_initial_rep', [True, False])
def test_repeat(add_measurements, use_default_ids_for_initial_rep):
    """Exercise CircuitOperation.repeat: repetition ids, nested repeats, and errors.

    Args:
        add_measurements: If True, the circuit ends in measurements, which makes
            a negative (inverting) repetition count invalid.
        use_default_ids_for_initial_rep: If True, the first repeat relies on the
            default repetition ids; otherwise explicit ids are passed.
    """
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(a), cirq.CX(a, b))
    if add_measurements:
        circuit.append([cirq.measure(b, key='mb'), cirq.measure(a, key='ma')])
    op_base = cirq.CircuitOperation(circuit.freeze())
    # repeat(1) without ids returns the same object; adding ids makes a new one.
    assert op_base.repeat(1) is op_base
    assert op_base.repeat(1, ['0']) != op_base
    assert op_base.repeat(1, ['0']) == op_base.repeat(repetition_ids=['0'])
    assert op_base.repeat(1, ['0']) == op_base.with_repetition_ids(['0'])

    initial_repetitions = -3
    if add_measurements:
        # Measurements block inversion, so continue with a positive count.
        with pytest.raises(ValueError, match='circuit is not invertible'):
            _ = op_base.repeat(initial_repetitions)
        initial_repetitions = abs(initial_repetitions)

    # Both names are assigned in every branch below; declare their types up
    # front instead of seeding them with throwaway values.
    op_with_reps: cirq.CircuitOperation
    rep_ids: List[str]
    if use_default_ids_for_initial_rep:
        op_with_reps = op_base.repeat(initial_repetitions)
        rep_ids = ['0', '1', '2']
        assert op_base ** initial_repetitions == op_with_reps
    else:
        rep_ids = ['a', 'b', 'c']
        op_with_reps = op_base.repeat(initial_repetitions, rep_ids)
        assert op_base ** initial_repetitions != op_with_reps
        assert (op_base ** initial_repetitions).replace(repetition_ids=rep_ids) == op_with_reps
    assert op_with_reps.repetitions == initial_repetitions
    assert op_with_reps.repetition_ids == rep_ids
    assert op_with_reps.repeat(1) is op_with_reps

    final_repetitions = 2 * initial_repetitions

    # Repeating an already-repeated op multiplies the counts and combines the
    # new (outer) ids with the existing ones via _full_join_string_lists.
    op_with_consecutive_reps = op_with_reps.repeat(2)
    assert op_with_consecutive_reps.repetitions == final_repetitions
    assert op_with_consecutive_reps.repetition_ids == _full_join_string_lists(['0', '1'], rep_ids)
    assert op_base ** final_repetitions != op_with_consecutive_reps

    op_with_consecutive_reps = op_with_reps.repeat(2, ['a', 'b'])
    assert op_with_reps.repeat(repetition_ids=['a', 'b']) == op_with_consecutive_reps
    assert op_with_consecutive_reps.repetitions == final_repetitions
    assert op_with_consecutive_reps.repetition_ids == _full_join_string_lists(['a', 'b'], rep_ids)

    with pytest.raises(ValueError, match='length to be 2'):
        _ = op_with_reps.repeat(2, ['a', 'b', 'c'])

    with pytest.raises(
        ValueError, match='At least one of repetitions and repetition_ids must be set'
    ):
        _ = op_base.repeat()

    with pytest.raises(TypeError, match='Only integer repetitions are allowed'):
        _ = op_base.repeat(1.3)
304
305
def test_qid_shape():
    """qid_shape and num_qubits reflect the dimensions of the wrapped qids."""
    shape = (1, 2, 3, 4)
    qids = cirq.LineQid.for_qid_shape(shape)
    mixed_dims = cirq.FrozenCircuit(
        cirq.IdentityGate(qid_shape=(qid.dimension,)).on(qid) for qid in qids
    )
    mixed_op = cirq.CircuitOperation(mixed_dims)
    assert cirq.qid_shape(mixed_op) == shape
    assert cirq.num_qubits(mixed_op) == len(shape)

    qubits = cirq.LineQubit.range(3)
    identity_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.I(qubit) for qubit in qubits))
    assert cirq.qid_shape(identity_op) == (2, 2, 2)
    assert cirq.num_qubits(identity_op) == 3
319
320
def test_string_format():
    """Check str() and repr() of CircuitOperations.

    str() renders the wrapped circuit's diagram, prefixed by its diagram_name(),
    optionally followed by a parenthesized summary of settings such as
    qubit_map, key_map, params, parent_path and repetition_ids. In the cases
    below, repr() renders a constructor call listing only non-default arguments.
    """
    x, y, z = cirq.LineQubit.range(3)

    # Empty circuit: just the name and an empty bracketed diagram.
    fc0 = cirq.FrozenCircuit()
    op0 = cirq.CircuitOperation(fc0)
    assert (
        str(op0)
        == f"""\
{op0.circuit.diagram_name()}:
[                         ]"""
    )

    # Global phases from the nested operation and the outer circuit appear as a
    # single "global phase" row: (1j * 1j) * 1j = -1j, shown as -0.5π.
    fc0_global_phase_inner = cirq.FrozenCircuit(
        cirq.GlobalPhaseOperation(1j), cirq.GlobalPhaseOperation(1j)
    )
    op0_global_phase_inner = cirq.CircuitOperation(fc0_global_phase_inner)
    fc0_global_phase_outer = cirq.FrozenCircuit(
        op0_global_phase_inner, cirq.GlobalPhaseOperation(1j)
    )
    op0_global_phase_outer = cirq.CircuitOperation(fc0_global_phase_outer)
    assert (
        str(op0_global_phase_outer)
        == f"""\
{op0_global_phase_outer.circuit.diagram_name()}:
[                         ]
[                         ]
[ global phase:   -0.5π   ]"""
    )

    # Operation with all-default settings: str() has no trailing summary and
    # repr() contains only the circuit argument.
    fc1 = cirq.FrozenCircuit(cirq.X(x), cirq.H(y), cirq.CX(y, z), cirq.measure(x, y, z, key='m'))
    op1 = cirq.CircuitOperation(fc1)
    assert (
        str(op1)
        == f"""\
{op1.circuit.diagram_name()}:
[ 0: ───X───────M('m')─── ]
[               │         ]
[ 1: ───H───@───M──────── ]
[           │   │         ]
[ 2: ───────X───M──────── ]"""
    )
    assert (
        repr(op1)
        == f"""\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.H(cirq.LineQubit(1)),
        ),
        cirq.Moment(
            cirq.CNOT(cirq.LineQubit(1), cirq.LineQubit(2)),
        ),
        cirq.Moment(
            cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), cirq.LineQubit(2), key='m'),
        ),
    ]),
)"""
    )

    # Repetitions, qubit map, parent path and repetition ids. repetitions=3
    # appears in repr() but not in the str() summary, which shows the ids.
    fc2 = cirq.FrozenCircuit(cirq.X(x), cirq.H(y), cirq.CX(y, x))
    op2 = cirq.CircuitOperation(
        circuit=fc2,
        qubit_map=({y: z}),
        repetitions=3,
        parent_path=('outer', 'inner'),
        repetition_ids=['a', 'b', 'c'],
    )
    assert (
        str(op2)
        == f"""\
{op2.circuit.diagram_name()}:
[ 0: ───X───X───          ]
[           │             ]
[ 1: ───H───@───          ](qubit_map={{1: 2}}, parent_path=('outer', 'inner'),\
 repetition_ids=['a', 'b', 'c'])"""
    )
    assert (
        repr(op2)
        == """\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.H(cirq.LineQubit(1)),
        ),
        cirq.Moment(
            cirq.CNOT(cirq.LineQubit(1), cirq.LineQubit(0)),
        ),
    ]),
    repetitions=3,
    qubit_map={cirq.LineQubit(1): cirq.LineQubit(2)},
    parent_path=('outer', 'inner'),
    repetition_ids=['a', 'b', 'c'],
)"""
    )

    # Qubit map, measurement key map and parameter resolver. The circuit's own
    # repr is re-indented so it nests inside the constructor call.
    fc3 = cirq.FrozenCircuit(
        cirq.X(x) ** sympy.Symbol('b'),
        cirq.measure(x, key='m'),
    )
    op3 = cirq.CircuitOperation(
        circuit=fc3,
        qubit_map={x: y},
        measurement_key_map={'m': 'p'},
        param_resolver={sympy.Symbol('b'): 2},
    )
    indented_fc3_repr = repr(fc3).replace('\n', '\n    ')
    assert (
        str(op3)
        == f"""\
{op3.circuit.diagram_name()}:
[ 0: ───X^b───M('m')───   ](qubit_map={{0: 1}}, \
key_map={{m: p}}, params={{b: 2}})"""
    )
    assert (
        repr(op3)
        == f"""\
cirq.CircuitOperation(
    circuit={indented_fc3_repr},
    qubit_map={{cirq.LineQubit(0): cirq.LineQubit(1)}},
    measurement_key_map={{'m': 'p'}},
    param_resolver=cirq.ParamResolver({{sympy.Symbol('b'): 2}}),
)"""
    )

    # A CircuitOperation nested inside another one is repr'd in place, with
    # its own constructor call indented one level further.
    fc4 = cirq.FrozenCircuit(cirq.X(y))
    op4 = cirq.CircuitOperation(fc4)
    fc5 = cirq.FrozenCircuit(cirq.X(x), op4)
    op5 = cirq.CircuitOperation(fc5)
    assert (
        repr(op5)
        == f"""\
cirq.CircuitOperation(
    circuit=cirq.FrozenCircuit([
        cirq.Moment(
            cirq.X(cirq.LineQubit(0)),
            cirq.CircuitOperation(
                circuit=cirq.FrozenCircuit([
                    cirq.Moment(
                        cirq.X(cirq.LineQubit(1)),
                    ),
                ]),
            ),
        ),
    ]),
)"""
    )
469
470
def test_json_dict():
    """_json_dict_ lists every field; qubit_map is emitted as sorted (key, value) pairs."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.Y(b),
        cirq.H(c),
        cirq.CX(a, b) ** sympy.Symbol('exp'),
        cirq.measure(a, b, c, key='m'),
    )
    op = cirq.CircuitOperation(
        circuit=circuit,
        qubit_map={c: b, b: c},
        measurement_key_map={'m': 'p'},
        param_resolver={'exp': 'theta'},
        parent_path=('nested', 'path'),
    )

    assert op._json_dict_() == {
        'cirq_type': 'CircuitOperation',
        'circuit': circuit,
        'repetitions': 1,
        # Serialized as a sorted list of pairs rather than a dict (presumably
        # because JSON object keys must be strings, and these keys are qubits).
        'qubit_map': sorted(op.qubit_map.items()),
        'measurement_key_map': op.measurement_key_map,
        'param_resolver': op.param_resolver,
        'parent_path': op.parent_path,
        'repetition_ids': None,
    }
498
499
def test_terminal_matches():
    """Terminal-measurement checks look inside CircuitOperations."""
    a, b = cirq.LineQubit.range(2)
    op = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            cirq.H(a),
            cirq.measure(b, key='m1'),
        )
    )

    # Each case: (circuit contents, all measurements terminal?, any terminal?).
    cases = [
        # Non-measurement gates before the subcircuit leave its measurement terminal.
        ((cirq.X(a), op), True, True),
        ((cirq.X(b), op), True, True),
        # A measurement on a qubit the subcircuit later touches is not terminal.
        ((cirq.measure(a), op), False, True),
        ((cirq.measure(b), op), False, True),
        # After the subcircuit, only operations on `b` (its measured qubit) matter.
        ((op, cirq.X(a)), True, True),
        ((op, cirq.X(b)), False, False),
        ((op, cirq.measure(a)), True, True),
        ((op, cirq.measure(b)), False, True),
    ]
    for contents, all_terminal, any_terminal in cases:
        circuit = cirq.Circuit(*contents)
        assert bool(circuit.are_all_measurements_terminal()) == all_terminal
        assert bool(circuit.are_any_measurements_terminal()) == any_terminal
539
540
def test_nonterminal_in_subcircuit():
    """A measurement that is non-terminal inside a subcircuit stays non-terminal.

    This holds both for the bare CircuitOperation and for a tagged copy of it.
    """
    a, b = cirq.LineQubit.range(2)
    subcircuit = cirq.FrozenCircuit(
        cirq.H(a),
        cirq.measure(b, key='m1'),
        cirq.X(b),
    )
    untagged = cirq.CircuitOperation(subcircuit)
    tagged = untagged.with_tags('test')
    # The tagged operation is no longer a CircuitOperation instance.
    assert isinstance(untagged, cirq.CircuitOperation)
    assert not isinstance(tagged, cirq.CircuitOperation)

    for wrapped in (untagged, tagged):
        circuit = cirq.Circuit(cirq.X(a), wrapped)
        assert not circuit.are_all_measurements_terminal()
        assert not circuit.are_any_measurements_terminal()
559
560
def test_decompose_applies_maps():
    """decompose_once applies the qubit, key and parameter maps to every operation."""
    a, b, c = cirq.LineQubit.range(3)
    exp, theta = sympy.symbols('exp theta')
    op = cirq.CircuitOperation(
        circuit=cirq.FrozenCircuit(
            cirq.X(a) ** theta,
            cirq.Y(b),
            cirq.H(c),
            cirq.CX(a, b) ** exp,
            cirq.measure(a, b, c, key='m'),
        ),
        # Swap qubits b and c, rename key m, and swap the two symbols.
        qubit_map={c: b, b: c},
        measurement_key_map={'m': 'p'},
        param_resolver={exp: theta, theta: exp},
    )

    # The same circuit with those substitutions written out by hand.
    expected = cirq.Circuit(
        cirq.X(a) ** exp,
        cirq.Y(c),
        cirq.H(b),
        cirq.CX(a, c) ** theta,
        cirq.measure(a, c, b, key='p'),
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected
590
591
def test_decompose_loops():
    """Positive repetitions unroll the circuit; negative ones unroll its inverse."""
    a, b = cirq.LineQubit.range(2)
    base_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.H(a), cirq.CX(a, b)))

    # Three forward copies, with the two qubits swapped.
    swapped_x3 = base_op.with_qubits(b, a).repeat(3)
    assert cirq.Circuit(cirq.decompose_once(swapped_x3)) == cirq.Circuit(
        [cirq.H(b), cirq.CX(b, a)] * 3
    )

    # Two copies of the inverse: gate order is reversed (H and CX are self-inverse).
    inverted_x2 = base_op.repeat(-2)
    assert cirq.Circuit(cirq.decompose_once(inverted_x2)) == cirq.Circuit(
        [cirq.CX(a, b), cirq.H(a)] * 2
    )
619
620
def test_decompose_loops_with_measurements():
    """Each unrolled repetition prefixes its measurement key with '<index>:'."""
    a, b = cirq.LineQubit.range(2)
    base_op = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            cirq.H(a),
            cirq.CX(a, b),
            cirq.measure(a, b, key='m'),
        )
    )

    op = base_op.with_qubits(b, a).repeat(3)
    expected_circuit = cirq.Circuit(
        [
            cirq.H(b),
            cirq.CX(b, a),
            cirq.measure(b, a, key=cirq.MeasurementKey.parse_serialized(f'{rep}:m')),
        ]
        for rep in range(3)
    )
    assert cirq.Circuit(cirq.decompose_once(op)) == expected_circuit
643
644
def test_decompose_nested():
    """Nested CircuitOperations resolve parameters and keys level by level."""
    a, b, c, d = cirq.LineQubit.range(4)
    exp1, exp_half, exp_one, exp_two = sympy.symbols('exp1 exp_half exp_one exp_two')
    qubit_keys = [(a, 'ma'), (b, 'mb'), (c, 'mc'), (d, 'md')]
    exponents = [0.5, 1.0, 2.0]

    # Level 1: a parameterized X followed by a measurement.
    op1 = cirq.CircuitOperation(
        cirq.FrozenCircuit(cirq.X(a) ** exp1, cirq.measure(a, key='m1'))
    )
    # Level 2: op1 placed on each qubit, each with its own measurement key.
    op2 = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            op1.with_qubits(q).with_measurement_key_mapping({'m1': key})
            for q, key in qubit_keys
        )
    )
    # Level 3: op2 three times, binding exp1 to a different symbol each time.
    op3 = cirq.CircuitOperation(
        cirq.FrozenCircuit(op2.with_params({exp1: sym}) for sym in (exp_half, exp_one, exp_two))
    )

    all_bindings = {exp_half: 0.5, exp_one: 1.0, exp_two: 2.0}
    final_op = op3.with_params(all_bindings)

    # One decomposition step pushes the outer bindings into each op2 copy.
    expected_circuit1 = cirq.Circuit(
        op2.with_params({exp1: value, **all_bindings}) for value in exponents
    )
    assert cirq.Circuit(cirq.decompose_once(final_op)) == expected_circuit1

    # Full decomposition: for each exponent, X**exponent and a measurement per qubit.
    expected_circuit = cirq.Circuit(
        [cirq.X(q) ** value, cirq.measure(q, key=key)]
        for value in exponents
        for q, key in qubit_keys
    )
    assert cirq.Circuit(cirq.decompose(final_op)) == expected_circuit
    # mapped_circuit(deep=True) must yield the same operations.
    assert final_op.mapped_circuit(deep=True) == expected_circuit
707
708
def test_decompose_repeated_nested_measurements():
    """Repetition ids and key maps compose correctly across three nesting levels."""
    # Details of this test described at
    # https://tinyurl.com/measurement-repeated-circuitop#heading=h.sbgxcsyin9wt.
    a = cirq.LineQubit(0)

    def wrap(circuit, key_map):
        """Wrap `circuit`, rename its keys with `key_map`, and repeat it as zero/one."""
        return (
            cirq.CircuitOperation(circuit)
            .with_measurement_key_mapping(key_map)
            .repeat(2, ['zero', 'one'])
        )

    op1 = wrap(cirq.FrozenCircuit(cirq.measure(a, key='A')), {'A': 'B'})
    op2 = wrap(cirq.FrozenCircuit(cirq.measure(a, key='P'), op1), {'B': 'C', 'P': 'Q'})
    op3 = wrap(cirq.FrozenCircuit(cirq.measure(a, key='X'), op2), {'C': 'D', 'X': 'Y'})

    # Outer repetition ids come first in each key path.
    expected_keys_in_order = [
        'zero:Y',
        'zero:zero:Q',
        'zero:zero:zero:D',
        'zero:zero:one:D',
        'zero:one:Q',
        'zero:one:zero:D',
        'zero:one:one:D',
        'one:Y',
        'one:zero:Q',
        'one:zero:zero:D',
        'one:zero:one:D',
        'one:one:Q',
        'one:one:zero:D',
        'one:one:one:D',
    ]
    assert cirq.measurement_keys(op3) == set(expected_keys_in_order)

    expected_circuit = cirq.Circuit(
        cirq.measure(a, key=cirq.MeasurementKey.parse_serialized(key))
        for key in expected_keys_in_order
    )
    assert cirq.Circuit(cirq.decompose(op3)) == expected_circuit
    assert cirq.measurement_keys(expected_circuit) == set(expected_keys_in_order)

    # mapped_circuit(deep=True) must yield the same operations.
    assert op3.mapped_circuit(deep=True) == expected_circuit
759
760
def test_mapped_circuit_preserves_moments():
    """mapped_circuit keeps the original moment structure, also when repeated."""
    q0, q1 = cirq.LineQubit.range(2)
    # The two gates act on different qubits, so a re-packed circuit could merge
    # them into one moment; mapped_circuit must keep them in separate moments.
    two_moments = cirq.FrozenCircuit(cirq.Moment(cirq.X(q0)), cirq.Moment(cirq.X(q1)))
    circuit_op = cirq.CircuitOperation(two_moments)
    assert circuit_op.mapped_circuit() == two_moments
    assert circuit_op.repeat(3).mapped_circuit(deep=True) == two_moments * 3
767
768
def test_mapped_op():
    """mapped_op yields a plain CircuitOperation with all maps applied to its circuit."""
    q0, q1 = cirq.LineQubit.range(2)
    a, b = sympy.symbols('a b')
    source_circuit = cirq.FrozenCircuit(cirq.X(q0) ** a, cirq.measure(q0, q1, key='m'))
    mapped = (
        cirq.CircuitOperation(source_circuit)
        .with_params({'a': 'b'})
        .with_qubits(q1, q0)
        .with_measurement_key_mapping({'m': 'k'})
    )
    # The same circuit with parameters, qubits and keys substituted by hand.
    expected = cirq.CircuitOperation(
        cirq.FrozenCircuit(cirq.X(q1) ** b, cirq.measure(q1, q0, key='k'))
    )
    assert mapped.mapped_op() == expected
783
784
def test_tag_propagation():
    """Tags on a CircuitOperation are currently not copied onto its decomposition."""
    # TODO: support tag propagation for better serialization.
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.FrozenCircuit(
        cirq.X(a),
        cirq.H(b),
        cirq.H(c),
        cirq.CZ(a, c),
    )
    test_tag = 'test_tag'
    op = cirq.CircuitOperation(circuit).with_tags(test_tag)

    assert test_tag in op.tags

    # TODO: Tags must propagate during decomposition. Until then, this pins the
    # current behavior: no sub-operation carries the tag.
    sub_ops = cirq.decompose(op)
    # Guard against the loop below passing vacuously.
    assert sub_ops
    # Use a distinct loop name so the tagged `op` is not rebound.
    for sub_op in sub_ops:
        assert test_tag not in sub_op.tags
805
806
# TODO: cirq.Operation has a "gate" property; decide what it should be for a CircuitOperation.
808