# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict, Any, Sequence, Tuple, Optional, Union

import cirq


class EmptyActOnArgs(cirq.ActOnArgs):
    """Minimal ActOnArgs stub used to observe container join/split behavior.

    It holds only its qubits and a measurement-log dict. Measurement yields no
    results, `_act_on_fallback_` returns True for every action, and the
    copy/kron/transpose/factor hooks and `sample` do nothing.
    """

    def __init__(self, qubits, logs):
        super().__init__(
            qubits=qubits,
            log_of_measurement_results=logs,
        )

    def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]:
        # No simulated state, so there are no measurement outcomes to report.
        return []

    def copy(self) -> 'EmptyActOnArgs':
        """Return a new stub on the same qubits with a shallow copy of the log."""
        return EmptyActOnArgs(
            qubits=self.qubits,
            logs=self.log_of_measurement_results.copy(),
        )

    def _act_on_fallback_(
        self,
        action: Union['cirq.Operation', 'cirq.Gate'],
        qubits: Sequence['cirq.Qid'],
        allow_decompose: bool = True,
    ) -> bool:
        # Accept every action without modeling any state change.
        return True

    def _on_copy(self, args):
        pass

    def _on_kronecker_product(self, other, target):
        pass

    def _on_transpose_to_qubit_order(self, qubits, target):
        pass

    def _on_factor(self, qubits, extracted, remainder, validate=True, atol=1e-07):
        pass

    def sample(self, qubits, repetitions=1, seed=None):
        pass


q0, q1, q2 = qs3 = cirq.LineQubit.range(3)
qs2 = cirq.LineQubit.range(2)


def create_container(
    qubits: Sequence['cirq.Qid'],
    split_untangled_states: bool = True,
) -> cirq.ActOnArgsContainer[EmptyActOnArgs]:
    """Build an ActOnArgsContainer of EmptyActOnArgs over `qubits`.

    Args:
        qubits: The qubits the container tracks.
        split_untangled_states: If True, every qubit gets its own single-qubit
            EmptyActOnArgs, and the `None` key maps to a separate qubit-less
            EmptyActOnArgs. If False, all qubits and the `None` key share one
            EmptyActOnArgs spanning every qubit in `qubits`.

    Returns:
        The container. Every args it holds shares the same measurement-log dict.
    """
    args_map: Dict[Optional['cirq.Qid'], EmptyActOnArgs] = {}
    log: Dict[str, Any] = {}
    if split_untangled_states:
        # NOTE(review): per-qubit entries are inserted in reverse qubit order;
        # confirm whether anything relies on this dict ordering.
        for q in reversed(qubits):
            args_map[q] = EmptyActOnArgs([q], log)
        args_map[None] = EmptyActOnArgs((), log)
    else:
        args = EmptyActOnArgs(qubits, log)
        for q in qubits:
            args_map[q] = args
        # Non-split layout: the qubit-less `None` entry is the same shared args.
        args_map[None] = args
    return cirq.ActOnArgsContainer(args_map, qubits, split_untangled_states, log)


def test_entanglement_causes_join():
    """A CNOT merges the two per-qubit args; the None entry stays separate."""
    args = create_container(qs2)
    # Two per-qubit args plus the qubit-less `None` args.
    assert len(set(args.values())) == 3
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 2
    assert args[q0] is args[q1]
    assert args[None] is not args[q0]


def test_subcircuit_entanglement_causes_join():
    """A CNOT wrapped in a CircuitOperation also merges the qubits' args."""
    args = create_container(qs2)
    assert len(set(args.values())) == 3
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1))))
    assert len(set(args.values())) == 2
    assert args[q0] is args[q1]


def test_subcircuit_entanglement_causes_join_in_subset():
    """Subcircuit CNOTs merge only the qubits they touch, one step at a time."""
    args = create_container(qs3)
    assert len(set(args.values())) == 4
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q1))))
    assert len(set(args.values())) == 3
    assert args[q0] is args[q1]
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CNOT(q0, q2))))
    assert len(set(args.values())) == 2
    assert args[q0] is args[q1] is args[q2]


def test_identity_does_not_join():
    """A two-qubit identity gate leaves the qubits' args separate."""
    args = create_container(qs2)
    assert len(set(args.values())) == 3
    args.apply_operation(cirq.IdentityGate(2)(q0, q1))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]


def test_identity_fallback_does_not_join():
    """Calling the container's fallback directly with cirq.I does not join."""
    args = create_container(qs2)
    assert len(set(args.values())) == 3
    args._act_on_fallback_(cirq.I, (q0, q1))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]


def test_subcircuit_identity_does_not_join():
    """An identity gate wrapped in a CircuitOperation does not join either."""
    args = create_container(qs2)
    assert len(set(args.values())) == 3
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.IdentityGate(2)(q0, q1))))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]


def test_measurement_causes_split():
    """Measuring one qubit of a joined pair splits it back out."""
    args = create_container(qs2)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.measure(q0))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]


def test_subcircuit_measurement_causes_split():
    """A measurement wrapped in a CircuitOperation also splits the pair."""
    args = create_container(qs2)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0))))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]


def test_subcircuit_measurement_causes_split_in_subset():
    """Successive subcircuit measurements split qubits off one at a time."""
    args = create_container(qs3)
    args.apply_operation(cirq.CNOT(q0, q1))
    args.apply_operation(cirq.CNOT(q0, q2))
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0))))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    args.apply_operation(cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q1))))
    assert len(set(args.values())) == 4
    assert args[q0] is not args[q1]
    assert args[q0] is not args[q2]
    assert args[q1] is not args[q2]


def test_reset_causes_split():
    """Resetting one qubit of a joined pair splits it back out."""
    args = create_container(qs2)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.reset(q0))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]


def test_measurement_does_not_split_if_disabled():
    """With splitting disabled, measurement keeps the single shared args."""
    args = create_container(qs2, False)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 1
    args.apply_operation(cirq.measure(q0))
    assert len(set(args.values())) == 1
    assert args[q1] is args[q0]
    assert args[None] is args[q0]


def test_reset_does_not_split_if_disabled():
    """With splitting disabled, reset keeps the single shared args."""
    args = create_container(qs2, False)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 1
    args.apply_operation(cirq.reset(q0))
    assert len(set(args.values())) == 1
    assert args[q1] is args[q0]
    assert args[None] is args[q0]


def test_measurement_of_all_qubits_causes_split():
    """Measuring every qubit of a joined pair splits all of them apart."""
    args = create_container(qs2)
    args.apply_operation(cirq.CNOT(q0, q1))
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.measure(q0, q1))
    assert len(set(args.values())) == 3
    assert args[q0] is not args[q1]
    assert args[q0] is not args[None]


def test_measurement_in_single_qubit_circuit_passes():
    """Measuring the only qubit keeps it distinct from the None-keyed args."""
    args = create_container([q0])
    assert len(set(args.values())) == 2
    args.apply_operation(cirq.measure(q0))
    assert len(set(args.values())) == 2
    assert args[q0] is not args[None]


def test_reorder_succeeds():
    """transpose_to_qubit_order returns args in the requested qubit order."""
    args = create_container(qs2, False)
    reordered = args[q0].transpose_to_qubit_order([q1, q0])
    assert reordered.qubits == (q1, q0)


def test_copy_succeeds():
    """Copying a shared (non-split) args preserves its qubit order."""
    args = create_container(qs2, False)
    copied = args[q0].copy()
    assert copied.qubits == (q0, q1)


def test_merge_succeeds():
    """create_merged_state yields args over all qubits in their order."""
    args = create_container(qs2, False)
    merged = args.create_merged_state()
    assert merged.qubits == (q0, q1)


def test_swap_does_not_merge():
    """SWAP on unentangled qubits exchanges their args instead of merging."""
    args = create_container(qs2)
    old_q0 = args[q0]
    old_q1 = args[q1]
    args.apply_operation(cirq.SWAP(q0, q1))
    assert len(set(args.values())) == 3
    assert args[q0] is not old_q0
    assert args[q1] is old_q0
    assert args[q1] is not old_q1
    assert args[q0] is old_q1
    # The swapped args are relabeled onto their new qubits.
    assert args[q0].qubits == (q0,)
    assert args[q1].qubits == (q1,)
def test_half_swap_does_merge():
    """SWAP ** 0.5 is not a plain swap: the two qubits end up sharing one args."""
    container = create_container(qs2)
    container.apply_operation(cirq.SWAP(q0, q1) ** 0.5)
    distinct = set(container.values())
    assert len(distinct) == 2
    assert container[q0] is container[q1]


def test_swap_after_entangle_reorders():
    """SWAP on already-joined qubits keeps one args but reverses its qubit order."""
    container = create_container(qs2)
    container.apply_operation(cirq.CX(q0, q1))
    assert len(set(container.values())) == 2
    joined = container[q0]
    assert joined.qubits == (q0, q1)

    container.apply_operation(cirq.SWAP(q0, q1))
    assert len(set(container.values())) == 2
    swapped = container[q0]
    assert swapped is container[q1]
    assert swapped.qubits == (q1, q0)


def test_act_on_gate_does_not_join():
    """cirq.act_on with a single-qubit gate leaves the split layout untouched."""
    container = create_container(qs2)
    assert len(set(container.values())) == 3
    cirq.act_on(cirq.X, container, [q0])
    assert len(set(container.values())) == 3
    target = container[q0]
    assert target is not container[q1]
    assert target is not container[None]