1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import List, Dict, Any, Sequence, Tuple, Optional, Union
15
16import cirq
17
18
class EmptyActOnArgs(cirq.ActOnArgs):
    """A stateless `cirq.ActOnArgs` stub for observing container join/split behavior.

    It carries no simulation state:
    - measurements produce no bits;
    - every action is reported as handled without doing anything;
    - the copy/kron/transpose/factor hooks are no-ops.
    Whatever remains (qubit bookkeeping, logging) comes from `cirq.ActOnArgs`.
    """

    def __init__(self, qubits: Sequence['cirq.Qid'], logs: Dict[str, Any]):
        super().__init__(log_of_measurement_results=logs, qubits=qubits)

    def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]:
        """Return an empty list of bits regardless of which qubits are measured."""
        return []

    def copy(self) -> 'EmptyActOnArgs':
        """Return a new stub over the same qubits with its own copy of the log."""
        log_copy = self.log_of_measurement_results.copy()
        return EmptyActOnArgs(logs=log_copy, qubits=self.qubits)

    def _act_on_fallback_(
        self,
        action: Union['cirq.Operation', 'cirq.Gate'],
        qubits: Sequence['cirq.Qid'],
        allow_decompose: bool = True,
    ) -> bool:
        """Claim every action as handled while leaving everything untouched."""
        return True

    def _on_copy(self, args):
        """No-op: there is no state to duplicate."""

    def _on_kronecker_product(self, other, target):
        """No-op: there is no state to combine."""

    def _on_transpose_to_qubit_order(self, qubits, target):
        """No-op: there is no state to permute."""

    def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
        """No-op: there is no state to factor."""

    def sample(self, qubits, repetitions=1, seed=None):
        """Unsupported by this stub; always returns None."""
57
58
# Module-wide qubit fixtures: three line qubits (also unpacked as q0, q1, q2)
# and a separate two-qubit range.
qs3 = cirq.LineQubit.range(3)
q0, q1, q2 = qs3
qs2 = cirq.LineQubit.range(2)
61
62
def create_container(
    qubits: Sequence['cirq.Qid'],
    split_untangled_states: bool = True,
) -> cirq.ActOnArgsContainer[EmptyActOnArgs]:
    """Build an `ActOnArgsContainer` of `EmptyActOnArgs` over `qubits`.

    Args:
        qubits: The qubits the container tracks.
        split_untangled_states: Controls how the args are laid out.
            When True, each qubit gets its own single-qubit `EmptyActOnArgs`
            and the `None` key gets a separate qubit-less one.
            When False, every qubit and the `None` key map to one shared
            `EmptyActOnArgs` spanning all of `qubits`.
            The flag is also forwarded to the container.

    Returns:
        The container. The same `log` dict is passed to every args object and
        to the container.
    """
    args_map: Dict[Optional['cirq.Qid'], EmptyActOnArgs] = {}
    log: Dict[str, Any] = {}
    if split_untangled_states:
        # NOTE(review): inserted in reverse, presumably so tests don't rely on
        # dict order matching `qubits` — confirm intent.
        for q in reversed(qubits):
            args_map[q] = EmptyActOnArgs([q], log)
        args_map[None] = EmptyActOnArgs((), log)
    else:
        args = EmptyActOnArgs(qubits, log)
        for q in qubits:
            args_map[q] = args
        # Splitting is disabled on this branch, so the shared args also fill the
        # qubit-less `None` slot (the old conditional here could never pick its
        # other arm).
        args_map[None] = args
    return cirq.ActOnArgsContainer(args_map, qubits, split_untangled_states, log)
79
80
def test_entanglement_causes_join():
    """A CNOT across two split qubits merges their args; the None args stays apart."""
    container = create_container(qs2)
    # One args per qubit plus the qubit-less `None` args.
    assert len(set(container.values())) == 3
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 2
    joined = container[q0]
    assert container[q1] is joined
    assert container[None] is not joined
88
89
def test_subcircuit_entanglement_causes_join():
    """A CNOT wrapped in a CircuitOperation still joins the two qubits' args."""
    container = create_container(qs2)
    assert len(set(container.values())) == 3
    wrapped_cnot = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))
    container.apply_operation(wrapped_cnot)
    assert len(set(container.values())) == 2
    assert container[q1] is container[q0]
96
97
def test_subcircuit_entanglement_causes_join_in_subset():
    """Subcircuit CNOTs merge only the qubits they touch, growing the group stepwise."""
    container = create_container(qs3)
    assert len(set(container.values())) == 4

    first = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1)))
    container.apply_operation(first)
    assert len(set(container.values())) == 3
    assert container[q0] is container[q1]

    second = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q2)))
    container.apply_operation(second)
    assert len(set(container.values())) == 2
    assert container[q0] is container[q1]
    assert container[q1] is container[q2]
107
108
def test_identity_does_not_join():
    """Applying a two-qubit identity leaves every args object separate."""
    container = create_container(qs2)
    assert len(set(container.values())) == 3
    identity = cirq.IdentityGate(2)
    container.apply_operation(identity(q0, q1))
    assert len(set(container.values())) == 3
    assert container[q1] is not container[q0]
    assert container[None] is not container[q0]
116
117
def test_identity_fallback_does_not_join():
    """Calling the container's fallback directly with an identity must not join args."""
    args = create_container(qs2)
    assert len(set(args.values())) == 3
    # Use a two-qubit identity so the gate's arity matches the two qubits it is
    # given; a one-qubit `cirq.I` here only works if arity is never validated.
    args._act_on_fallback_(cirq.IdentityGate(2), (q0, q1))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]
125
126
def test_subcircuit_identity_does_not_join():
    """An identity inside a CircuitOperation leaves the qubits' args split."""
    container = create_container(qs2)
    assert len(set(container.values())) == 3
    wrapped_identity = cirq.CircuitOperation(
        cirq.FrozenCircuit(cirq.IdentityGate(2)(q0, q1))
    )
    container.apply_operation(wrapped_identity)
    assert len(set(container.values())) == 3
    assert container[q1] is not container[q0]
133
134
def test_measurement_causes_split():
    """Measuring one qubit of a joined pair splits that qubit back out."""
    container = create_container(qs2)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 2
    container.apply_operation(cirq.measure(q0))
    assert len(set(container.values())) == 3
    measured = container[q0]
    assert measured is not container[q1]
    assert measured is not container[None]
143
144
def test_subcircuit_measurement_causes_split():
    """A measurement wrapped in a CircuitOperation also splits a joined pair."""
    container = create_container(qs2)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 2
    wrapped_measure = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0)))
    container.apply_operation(wrapped_measure)
    assert len(set(container.values())) == 3
    assert container[q1] is not container[q0]
152
153
def test_subcircuit_measurement_causes_split_in_subset():
    """Subcircuit measurements peel qubits off a three-qubit group one at a time."""
    container = create_container(qs3)
    for target in (q1, q2):
        container.apply_operation(cirq.CNOT(q0, target))
    assert len(set(container.values())) == 2

    container.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0))))
    assert len(set(container.values())) == 3
    assert container[q0] is not container[q1]

    container.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q1))))
    assert len(set(container.values())) == 4
    a0, a1, a2 = container[q0], container[q1], container[q2]
    assert a0 is not a1
    assert a0 is not a2
    assert a1 is not a2
167
168
def test_reset_causes_split():
    """Resetting one qubit of a joined pair splits that qubit back out."""
    container = create_container(qs2)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 2
    container.apply_operation(cirq.reset(q0))
    assert len(set(container.values())) == 3
    reset_args = container[q0]
    assert reset_args is not container[q1]
    assert reset_args is not container[None]
177
178
def test_measurement_does_not_split_if_disabled():
    """With splitting disabled, a measurement keeps every key on one shared args."""
    container = create_container(qs2, split_untangled_states=False)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 1
    container.apply_operation(cirq.measure(q0))
    assert len(set(container.values())) == 1
    shared = container[q0]
    assert container[q1] is shared
    assert container[None] is shared
187
188
def test_reset_does_not_split_if_disabled():
    """With splitting disabled, a reset keeps every key on one shared args."""
    container = create_container(qs2, split_untangled_states=False)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 1
    container.apply_operation(cirq.reset(q0))
    assert len(set(container.values())) == 1
    shared = container[q0]
    assert container[q1] is shared
    assert container[None] is shared
197
198
def test_measurement_of_all_qubits_causes_split():
    """Measuring both qubits of a joined pair splits them apart again."""
    container = create_container(qs2)
    container.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(container.values())) == 2
    container.apply_operation(cirq.measure(q0, q1))
    assert len(set(container.values())) == 3
    first = container[q0]
    assert first is not container[q1]
    assert first is not container[None]
207
208
def test_measurement_in_single_qubit_circuit_passes():
    """Measuring the only qubit works and keeps it distinct from the None args."""
    container = create_container([q0])
    assert len(set(container.values())) == 2
    container.apply_operation(cirq.measure(q0))
    assert len(set(container.values())) == 2
    assert container[None] is not container[q0]
215
216
def test_reorder_succeeds():
    """Transposing the shared args to reversed order reports the new order."""
    container = create_container(qs2, split_untangled_states=False)
    transposed = container[q0].transpose_to_qubit_order([q1, q0])
    expected_order = (q1, q0)
    assert transposed.qubits == expected_order
221
222
def test_copy_succeeds():
    """Copying the shared (unsplit) args preserves its full qubit tuple."""
    container = create_container(qs2, split_untangled_states=False)
    duplicate = container[q0].copy()
    assert duplicate.qubits == (q0, q1)
227
228
def test_merge_succeeds():
    """Merging an unsplit container yields args over all qubits in order."""
    container = create_container(qs2, split_untangled_states=False)
    merged_state = container.create_merged_state()
    assert merged_state.qubits == (q0, q1)
233
234
def test_swap_does_not_merge():
    """A full SWAP exchanges the per-qubit args instead of joining them."""
    container = create_container(qs2)
    before_q0, before_q1 = container[q0], container[q1]
    container.apply_operation(cirq.SWAP(q0, q1))
    assert len(set(container.values())) == 3
    after_q0, after_q1 = container[q0], container[q1]
    assert after_q0 is not before_q0
    assert after_q1 is before_q0
    assert after_q1 is not before_q1
    assert after_q0 is before_q1
    # Each moved args object now reports the qubit it sits on.
    assert after_q0.qubits == (q0,)
    assert after_q1.qubits == (q1,)
247
248
def test_half_swap_does_merge():
    """A square-root SWAP is not a plain permutation, so the qubits' args are joined."""
    container = create_container(qs2)
    sqrt_swap = cirq.SWAP(q0, q1) ** 0.5
    container.apply_operation(sqrt_swap)
    assert len(set(container.values())) == 2
    assert container[q1] is container[q0]
254
255
def test_swap_after_entangle_reorders():
    """A SWAP on already-joined qubits keeps one args but reverses its qubit order."""
    container = create_container(qs2)
    container.apply_operation(cirq.CX(q0, q1))
    assert len(set(container.values())) == 2
    assert container[q0].qubits == (q0, q1)
    container.apply_operation(cirq.SWAP(q0, q1))
    assert len(set(container.values())) == 2
    joined = container[q0]
    assert container[q1] is joined
    assert joined.qubits == (q1, q0)
265
266
def test_act_on_gate_does_not_join():
    """Acting with a single-qubit gate through cirq.act_on leaves all args separate."""
    container = create_container(qs2)
    assert len(set(container.values())) == 3
    cirq.act_on(cirq.X, container, [q0])
    assert len(set(container.values())) == 3
    assert container[q1] is not container[q0]
    assert container[None] is not container[q0]
274