# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for cirq.decompose, cirq.decompose_once and cirq.decompose_once_with_qubits."""

import pytest

import cirq


class NoMethod:
    """Test double with no `_decompose_` method at all."""


class DecomposeNotImplemented:
    """Test double whose `_decompose_` returns NotImplemented."""

    def _decompose_(self, qubits=None):
        return NotImplemented


class DecomposeNone:
    """Test double whose `_decompose_` returns None."""

    def _decompose_(self, qubits=None):
        return None


class DecomposeGiven:
    """Test double whose qubit-less `_decompose_` returns a fixed value."""

    def __init__(self, val):
        self.val = val

    def _decompose_(self):
        return self.val


class DecomposeWithQubitsGiven:
    """Test double whose `_decompose_(qubits)` forwards the qubits to `func`."""

    def __init__(self, func):
        self.func = func

    def _decompose_(self, qubits):
        return self.func(*qubits)


class DecomposeGenerated:
    """Test double whose `_decompose_` is a generator yielding two operations."""

    def _decompose_(self):
        yield cirq.X(cirq.LineQubit(0))
        yield cirq.Y(cirq.LineQubit(1))


class DecomposeQuditGate:
    """Test double whose `_decompose_` names its qubit argument `qids`, not `qubits`."""

    def _decompose_(self, qids):
        yield cirq.identity_each(*qids)


def test_decompose_once():
    """decompose_once raises or returns the default when stuck; otherwise it flattens."""
    # No default value results in descriptive error.
    with pytest.raises(TypeError, match='no _decompose_ method'):
        _ = cirq.decompose_once(NoMethod())
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once(DecomposeNotImplemented())
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once(DecomposeNone())

    # Default value works.
    assert cirq.decompose_once(NoMethod(), 5) == 5
    assert cirq.decompose_once(DecomposeNotImplemented(), None) is None
    assert cirq.decompose_once(NoMethod(), NotImplemented) is NotImplemented
    assert cirq.decompose_once(DecomposeNone(), 0) == 0

    # Flattens into a list.
    op = cirq.X(cirq.NamedQubit('q'))
    assert cirq.decompose_once(DecomposeGiven(op)) == [op]
    assert cirq.decompose_once(DecomposeGiven([[[op]], []])) == [op]
    assert cirq.decompose_once(DecomposeGiven(op for _ in range(2))) == [op, op]
    # A generator result must be materialized into exactly a `list`.
    assert type(cirq.decompose_once(DecomposeGiven(op for _ in range(2)))) is list
    assert cirq.decompose_once(DecomposeGenerated()) == [
        cirq.X(cirq.LineQubit(0)),
        cirq.Y(cirq.LineQubit(1)),
    ]


def test_decompose_once_with_qubits():
    """decompose_once_with_qubits passes qubits through and flattens the result."""
    qs = cirq.LineQubit.range(3)

    # No default value results in descriptive error.
    with pytest.raises(TypeError, match='no _decompose_ method'):
        _ = cirq.decompose_once_with_qubits(NoMethod(), qs)
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once_with_qubits(DecomposeNotImplemented(), qs)
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once_with_qubits(DecomposeNone(), qs)

    # Default value works.
    assert cirq.decompose_once_with_qubits(NoMethod(), qs, 5) == 5
    assert cirq.decompose_once_with_qubits(DecomposeNotImplemented(), qs, None) is None
    assert cirq.decompose_once_with_qubits(NoMethod(), qs, NotImplemented) is NotImplemented

    # Flattens into a list.
    assert cirq.decompose_once_with_qubits(DecomposeWithQubitsGiven(cirq.X.on_each), qs) == [
        cirq.X(cirq.LineQubit(0)),
        cirq.X(cirq.LineQubit(1)),
        cirq.X(cirq.LineQubit(2)),
    ]
    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(lambda *qubits: cirq.Y(qubits[0])), qs
    ) == [cirq.Y(cirq.LineQubit(0))]
    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(lambda *qubits: (cirq.Y(q) for q in qubits)), qs
    ) == [cirq.Y(cirq.LineQubit(0)), cirq.Y(cirq.LineQubit(1)), cirq.Y(cirq.LineQubit(2))]

    # Qudits, _decompose_ argument name is not 'qubits'.
    assert cirq.decompose_once_with_qubits(
        DecomposeQuditGate(), cirq.LineQid.for_qid_shape((1, 2, 3))
    ) == [cirq.identity_each(*cirq.LineQid.for_qid_shape((1, 2, 3)))]

    # Works when qubits are generated.
    def use_qubits_twice(*qubits):
        # Iterates the qubits twice; passing a one-shot generator as `qubits` below
        # checks that decompose_once_with_qubits materializes it first.
        a = list(qubits)
        b = list(qubits)
        yield cirq.X.on_each(*a)
        yield cirq.Y.on_each(*b)

    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(use_qubits_twice), (q for q in qs)
    ) == list(cirq.X.on_each(*qs)) + list(cirq.Y.on_each(*qs))


def test_decompose_general():
    """decompose leaves undecomposable values alone, flattens iterables, keeps order."""
    a, b, c = cirq.LineQubit.range(3)
    no_method = NoMethod()
    assert cirq.decompose(no_method) == [no_method]

    # Flattens iterables.
    assert cirq.decompose([cirq.SWAP(a, b), cirq.SWAP(a, b)]) == 2 * cirq.decompose(cirq.SWAP(a, b))

    # Decomposed circuit should be equivalent. The ordering should be correct.
    ops = cirq.TOFFOLI(a, b, c), cirq.H(a), cirq.CZ(a, c)
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        cirq.Circuit(ops), cirq.Circuit(cirq.decompose(ops)), atol=1e-8
    )


def test_decompose_keep():
    """The `keep` predicate stops recursion at operations it accepts."""
    a, b = cirq.LineQubit.range(2)

    # Recursion can be stopped.
    assert cirq.decompose(cirq.SWAP(a, b), keep=lambda e: isinstance(e.gate, cirq.CNotPowGate)) == [
        cirq.CNOT(a, b),
        cirq.CNOT(b, a),
        cirq.CNOT(a, b),
    ]

    # Recursion continues down to CZ + single-qubit gates.
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.decompose(cirq.SWAP(a, b))),
        """
0: ────────────@───Y^-0.5───@───Y^0.5────@───────────
               │            │            │
1: ───Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───
""",
    )

    # If you're happy with everything, no decomposition happens.
    assert cirq.decompose(cirq.SWAP(a, b), keep=lambda _: True) == [cirq.SWAP(a, b)]
    # Unless it's not an operation (DecomposeGiven is still decomposed despite keep=True).
    assert cirq.decompose(DecomposeGiven(cirq.SWAP(b, a)), keep=lambda _: True) == [cirq.SWAP(b, a)]
    # E.g. lists still get flattened.
    assert cirq.decompose([[[cirq.SWAP(a, b)]]], keep=lambda _: True) == [cirq.SWAP(a, b)]


def test_decompose_on_stuck_raise():
    """`on_stuck_raise` controls what happens when a rejected value can't be decomposed."""
    a, b = cirq.LineQubit.range(2)
    no_method = NoMethod()

    # If you're not happy with anything, you're going to get an error.
    with pytest.raises(ValueError, match="but can't be decomposed"):
        _ = cirq.decompose(no_method, keep=lambda _: False)
    # Unless there's no operations to be unhappy about.
    assert cirq.decompose([], keep=lambda _: False) == []
    # Or you say you're fine.
    assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=None) == [no_method]
    assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=lambda _: None) == [
        no_method
    ]
    # You can customize the error.
    with pytest.raises(TypeError, match='test'):
        _ = cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=TypeError('test'))
    with pytest.raises(NotImplementedError, match='op cirq.CZ'):
        _ = cirq.decompose(
            cirq.CZ(a, b),
            keep=lambda _: False,
            on_stuck_raise=lambda op: NotImplementedError(f'op {op!r}'),
        )

    # There's a descriptive error if you specify `on_stuck_raise` but not `keep`.
    with pytest.raises(ValueError, match='on_stuck_raise'):
        _ = cirq.decompose([], on_stuck_raise=None)
    with pytest.raises(ValueError, match='on_stuck_raise'):
        _ = cirq.decompose([], on_stuck_raise=TypeError('x'))


def test_decompose_intercept():
    """An intercepting decomposer runs first and defers by returning NotImplemented."""
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')

    # Runs instead of normal decomposition.
    actual = cirq.decompose(
        cirq.SWAP(a, b),
        intercepting_decomposer=lambda op: (cirq.X(a) if op == cirq.SWAP(a, b) else NotImplemented),
    )
    assert actual == [cirq.X(a)]

    # Falls back to normal decomposition when NotImplemented.
    actual = cirq.decompose(
        cirq.SWAP(a, b),
        keep=lambda op: isinstance(op.gate, cirq.CNotPowGate),
        intercepting_decomposer=lambda _: NotImplemented,
    )
    assert actual == [cirq.CNOT(a, b), cirq.CNOT(b, a), cirq.CNOT(a, b)]


def test_decompose_preserving_structure():
    """preserve_structure=True keeps nested CircuitOperations and decomposes their contents."""
    a, b = cirq.LineQubit.range(2)
    fc1 = cirq.FrozenCircuit(cirq.SWAP(a, b), cirq.FSimGate(0.1, 0.2).on(a, b))
    cop1_1 = cirq.CircuitOperation(fc1).with_tags('test_tag')
    cop1_2 = cirq.CircuitOperation(fc1).with_qubit_mapping({a: b, b: a})
    fc2 = cirq.FrozenCircuit(cirq.X(a), cop1_1, cop1_2)
    cop2 = cirq.CircuitOperation(fc2)

    circuit = cirq.Circuit(cop2, cirq.measure(a, b, key='m'))
    actual = cirq.Circuit(cirq.decompose(circuit, preserve_structure=True))

    # This should keep the CircuitOperations but decompose their SWAPs.
    fc1_decomp = cirq.FrozenCircuit(cirq.decompose(fc1))
    expected = cirq.Circuit(
        cirq.CircuitOperation(
            cirq.FrozenCircuit(
                cirq.X(a),
                cirq.CircuitOperation(fc1_decomp).with_tags('test_tag'),
                cirq.CircuitOperation(fc1_decomp).with_qubit_mapping({a: b, b: a}),
            )
        ),
        cirq.measure(a, b, key='m'),
    )
    assert actual == expected


# Test both intercepting and fallback decomposers.
@pytest.mark.parametrize('decompose_mode', ['intercept', 'fallback'])
def test_decompose_preserving_structure_forwards_args(decompose_mode):
    """With preserve_structure=True, `keep` and the custom decomposers reach nested circuits."""
    a, b = cirq.LineQubit.range(2)
    fc1 = cirq.FrozenCircuit(cirq.SWAP(a, b), cirq.FSimGate(0.1, 0.2).on(a, b))
    cop1_1 = cirq.CircuitOperation(fc1).with_tags('test_tag')
    cop1_2 = cirq.CircuitOperation(fc1).with_qubit_mapping({a: b, b: a})
    fc2 = cirq.FrozenCircuit(cirq.X(a), cop1_1, cop1_2)
    cop2 = cirq.CircuitOperation(fc2)

    circuit = cirq.Circuit(cop2, cirq.measure(a, b, key='m'))

    def keep_func(op: 'cirq.Operation'):
        # Only decompose SWAP and X.
        return not isinstance(op.gate, (cirq.SwapPowGate, cirq.XPowGate))

    def x_to_hzh(op: 'cirq.Operation'):
        """Rewrite a full X gate as H-Z-H; defer every other operation."""
        if isinstance(op.gate, cirq.XPowGate) and op.gate.exponent == 1:
            return [
                cirq.H(*op.qubits),
                cirq.Z(*op.qubits),
                cirq.H(*op.qubits),
            ]
        # Explicitly defer, matching the NotImplemented convention used elsewhere in this file.
        return NotImplemented

    actual = cirq.Circuit(
        cirq.decompose(
            circuit,
            keep=keep_func,
            intercepting_decomposer=x_to_hzh if decompose_mode == 'intercept' else None,
            fallback_decomposer=x_to_hzh if decompose_mode == 'fallback' else None,
            preserve_structure=True,
        ),
    )

    # This should keep the CircuitOperations but decompose their SWAPs, while the
    # top-level X is rewritten to H-Z-H by the supplied decomposer in either mode.
    fc1_decomp = cirq.FrozenCircuit(
        cirq.decompose(
            fc1,
            keep=keep_func,
            fallback_decomposer=x_to_hzh,
        )
    )
    expected = cirq.Circuit(
        cirq.CircuitOperation(
            cirq.FrozenCircuit(
                cirq.H(a),
                cirq.Z(a),
                cirq.H(a),
                cirq.CircuitOperation(fc1_decomp).with_tags('test_tag'),
                cirq.CircuitOperation(fc1_decomp).with_qubit_mapping({a: b, b: a}),
            )
        ),
        cirq.measure(a, b, key='m'),
    )
    assert actual == expected