1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import pytest
15
16import cirq
17
18
class NoMethod:
    """Object that defines no ``_decompose_`` method at all."""
21
22
class DecomposeNotImplemented:
    """Object whose ``_decompose_`` reports failure by returning ``NotImplemented``."""

    def _decompose_(self, qubits=None):
        """Ignore ``qubits`` and signal that no decomposition is available."""
        # The optional argument lets both decompose_once and
        # decompose_once_with_qubits call this method in the tests below.
        del qubits
        return NotImplemented
26
27
class DecomposeNone:
    """Object whose ``_decompose_`` reports failure by returning ``None``."""

    def _decompose_(self, qubits=None):
        """Ignore ``qubits`` and return ``None`` instead of a decomposition."""
        # The optional argument lets both decompose_once and
        # decompose_once_with_qubits call this method in the tests below.
        del qubits
        return None
31
32
class DecomposeGiven:
    """Object whose ``_decompose_`` returns a fixed, caller-supplied value.

    The value is handed back untouched, so tests can feed in single operations,
    nested lists, or generators and observe how the decompose protocol
    flattens them.
    """

    def __init__(self, val):
        # Stored as-is; not copied or materialized.
        self.val = val

    def _decompose_(self):
        """Return the value given to the constructor, unchanged."""
        result = self.val
        return result
39
40
class DecomposeWithQubitsGiven:
    """Object that builds its decomposition by calling a stored callable on the qubits."""

    def __init__(self, func):
        # Invoked as ``func(*qubits)`` by ``_decompose_``.
        self.func = func

    def _decompose_(self, qubits):
        """Spread ``qubits`` into positional arguments of the stored callable."""
        build = self.func
        return build(*qubits)
47
48
class DecomposeGenerated:
    """Object whose ``_decompose_`` is a generator yielding two operations."""

    def _decompose_(self):
        """Yield ``X`` on line qubit 0, then ``Y`` on line qubit 1."""
        first, second = cirq.LineQubit(0), cirq.LineQubit(1)
        yield cirq.X(first)
        yield cirq.Y(second)
53
54
class DecomposeQuditGate:
    """Object whose ``_decompose_`` parameter is named ``qids`` rather than ``qubits``."""

    def _decompose_(self, qids):
        """Yield one identity operation covering every given qid."""
        identity_op = cirq.identity_each(*qids)
        yield identity_op
58
59
def test_decompose_once():
    """``cirq.decompose_once``: error messages, default values, and flattening."""
    # No default value results in descriptive error.
    with pytest.raises(TypeError, match='no _decompose_ method'):
        _ = cirq.decompose_once(NoMethod())
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once(DecomposeNotImplemented())
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once(DecomposeNone())

    # Default value works.
    assert cirq.decompose_once(NoMethod(), 5) == 5
    assert cirq.decompose_once(DecomposeNotImplemented(), None) is None
    assert cirq.decompose_once(NoMethod(), NotImplemented) is NotImplemented
    assert cirq.decompose_once(DecomposeNone(), 0) == 0

    # Flattens into a list.
    op = cirq.X(cirq.NamedQubit('q'))
    assert cirq.decompose_once(DecomposeGiven(op)) == [op]
    assert cirq.decompose_once(DecomposeGiven([[[op]], []])) == [op]
    assert cirq.decompose_once(DecomposeGiven(op for _ in range(2))) == [op, op]
    # Exact-type check (not isinstance): a generator result must come back as a plain list.
    assert type(cirq.decompose_once(DecomposeGiven(op for _ in range(2)))) is list
    assert cirq.decompose_once(DecomposeGenerated()) == [
        cirq.X(cirq.LineQubit(0)),
        cirq.Y(cirq.LineQubit(1)),
    ]
85
86
def test_decompose_once_with_qubits():
    """``cirq.decompose_once_with_qubits``: errors, defaults, flattening, and qubit input forms."""
    qs = cirq.LineQubit.range(3)

    # No default value results in descriptive error.
    with pytest.raises(TypeError, match='no _decompose_ method'):
        _ = cirq.decompose_once_with_qubits(NoMethod(), qs)
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once_with_qubits(DecomposeNotImplemented(), qs)
    with pytest.raises(TypeError, match='returned NotImplemented or None'):
        _ = cirq.decompose_once_with_qubits(DecomposeNone(), qs)

    # Default value works (for missing methods, NotImplemented results, and None results).
    assert cirq.decompose_once_with_qubits(NoMethod(), qs, 5) == 5
    assert cirq.decompose_once_with_qubits(DecomposeNotImplemented(), qs, None) is None
    assert cirq.decompose_once_with_qubits(NoMethod(), qs, NotImplemented) is NotImplemented
    assert cirq.decompose_once_with_qubits(DecomposeNone(), qs, 0) == 0

    # Flattens into a list.
    assert cirq.decompose_once_with_qubits(DecomposeWithQubitsGiven(cirq.X.on_each), qs) == [
        cirq.X(cirq.LineQubit(0)),
        cirq.X(cirq.LineQubit(1)),
        cirq.X(cirq.LineQubit(2)),
    ]
    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(lambda *qubits: cirq.Y(qubits[0])), qs
    ) == [cirq.Y(cirq.LineQubit(0))]
    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(lambda *qubits: (cirq.Y(q) for q in qubits)), qs
    ) == [cirq.Y(cirq.LineQubit(0)), cirq.Y(cirq.LineQubit(1)), cirq.Y(cirq.LineQubit(2))]

    # Qudits, _decompose_ argument name is not 'qubits'.
    assert cirq.decompose_once_with_qubits(
        DecomposeQuditGate(), cirq.LineQid.for_qid_shape((1, 2, 3))
    ) == [cirq.identity_each(*cirq.LineQid.for_qid_shape((1, 2, 3)))]

    # Works when qubits are generated.
    def use_qubits_twice(*qubits):
        # Builds two separate lists from the received qubits before yielding.
        a = list(qubits)
        b = list(qubits)
        yield cirq.X.on_each(*a)
        yield cirq.Y.on_each(*b)

    assert cirq.decompose_once_with_qubits(
        DecomposeWithQubitsGiven(use_qubits_twice), (q for q in qs)
    ) == list(cirq.X.on_each(*qs)) + list(cirq.Y.on_each(*qs))
131
132
def test_decompose_general():
    """``cirq.decompose`` passes through opaque objects, flattens, and preserves semantics."""
    q0, q1, q2 = cirq.LineQubit.range(3)

    # Objects with nothing to decompose come back unchanged inside a list.
    opaque = NoMethod()
    assert cirq.decompose(opaque) == [opaque]

    # Flattens iterables.
    pair_result = cirq.decompose([cirq.SWAP(q0, q1), cirq.SWAP(q0, q1)])
    single_result = cirq.decompose(cirq.SWAP(q0, q1))
    assert pair_result == single_result * 2

    # Decomposed circuit should be equivalent. The ordering should be correct.
    original_ops = (cirq.TOFFOLI(q0, q1, q2), cirq.H(q0), cirq.CZ(q0, q2))
    original_circuit = cirq.Circuit(original_ops)
    decomposed_circuit = cirq.Circuit(cirq.decompose(original_ops))
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        original_circuit, decomposed_circuit, atol=1e-8
    )
146
147
def test_decompose_keep():
    """The ``keep`` predicate decides where ``cirq.decompose`` stops recursing."""
    q0, q1 = cirq.LineQubit.range(2)
    swap = cirq.SWAP(q0, q1)

    def is_cnot(op):
        return isinstance(op.gate, cirq.CNotPowGate)

    def keep_everything(_):
        return True

    # Recursion can be stopped.
    expected_cnots = [cirq.CNOT(q0, q1), cirq.CNOT(q1, q0), cirq.CNOT(q0, q1)]
    assert cirq.decompose(swap, keep=is_cnot) == expected_cnots

    # Recursion continues down to CZ + single-qubit gates.
    fully_decomposed = cirq.Circuit(cirq.decompose(swap))
    cirq.testing.assert_has_diagram(
        fully_decomposed,
        """
0: ────────────@───Y^-0.5───@───Y^0.5────@───────────
               │            │            │
1: ───Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───
""",
    )

    # If you're happy with everything, no decomposition happens.
    assert cirq.decompose(swap, keep=keep_everything) == [swap]
    # Unless it's not an operation.
    reversed_swap = cirq.SWAP(q1, q0)
    assert cirq.decompose(DecomposeGiven(reversed_swap), keep=keep_everything) == [reversed_swap]
    # E.g. lists still get flattened.
    assert cirq.decompose([[[swap]]], keep=keep_everything) == [swap]
174
175
def test_decompose_on_stuck_raise():
    """``on_stuck_raise`` controls what happens when an unkept item cannot be decomposed."""
    a, b = cirq.LineQubit.range(2)
    no_method = NoMethod()

    # If you're not happy with anything, you're going to get an error.
    with pytest.raises(ValueError, match="but can't be decomposed"):
        _ = cirq.decompose(no_method, keep=lambda _: False)
    # Unless there's no operations to be unhappy about.
    assert cirq.decompose([], keep=lambda _: False) == []
    # Or you say you're fine.
    assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=None) == [no_method]
    assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=lambda _: None) == [
        no_method
    ]
    # You can customize the error.
    with pytest.raises(TypeError, match='test'):
        _ = cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=TypeError('test'))
    with pytest.raises(NotImplementedError, match='op cirq.CZ'):
        _ = cirq.decompose(
            cirq.CZ(a, b),
            keep=lambda _: False,
            on_stuck_raise=lambda op: NotImplementedError(f'op {op!r}'),
        )

    # There's a descriptive error if you specify `on_stuck_raise` but not `keep`.
    # The result is discarded rather than asserted on: `assert []` would turn a
    # missing ValueError into an unrelated AssertionError escaping pytest.raises.
    with pytest.raises(ValueError, match='on_stuck_raise'):
        _ = cirq.decompose([], on_stuck_raise=None)
    with pytest.raises(ValueError, match='on_stuck_raise'):
        _ = cirq.decompose([], on_stuck_raise=TypeError('x'))
205
206
def test_decompose_intercept():
    """An intercepting decomposer runs first; ``NotImplemented`` defers to normal decomposition."""
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    target = cirq.SWAP(a, b)

    def swap_to_x(op):
        if op == target:
            return cirq.X(a)
        return NotImplemented

    def never_intercept(_):
        return NotImplemented

    def is_cnot(op):
        return isinstance(op.gate, cirq.CNotPowGate)

    # Runs instead of normal decomposition.
    intercepted = cirq.decompose(target, intercepting_decomposer=swap_to_x)
    assert intercepted == [cirq.X(a)]

    # Falls back to normal decomposition when NotImplemented.
    deferred = cirq.decompose(target, keep=is_cnot, intercepting_decomposer=never_intercept)
    assert deferred == [cirq.CNOT(a, b), cirq.CNOT(b, a), cirq.CNOT(a, b)]
225
226
def test_decompose_preserving_structure():
    """With ``preserve_structure=True``, nested CircuitOperations survive decomposition."""
    q0, q1 = cirq.LineQubit.range(2)
    inner = cirq.FrozenCircuit(cirq.SWAP(q0, q1), cirq.FSimGate(0.1, 0.2).on(q0, q1))
    outer = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.CircuitOperation(inner).with_tags('test_tag'),
        cirq.CircuitOperation(inner).with_qubit_mapping({q0: q1, q1: q0}),
    )
    circuit = cirq.Circuit(cirq.CircuitOperation(outer), cirq.measure(q0, q1, key='m'))

    actual = cirq.Circuit(cirq.decompose(circuit, preserve_structure=True))

    # The CircuitOperations (with their tags and qubit mappings) are kept, while the
    # inner circuit they wrap is replaced by its own full decomposition.
    inner_decomposed = cirq.FrozenCircuit(cirq.decompose(inner))
    expected_outer = cirq.FrozenCircuit(
        cirq.X(q0),
        cirq.CircuitOperation(inner_decomposed).with_tags('test_tag'),
        cirq.CircuitOperation(inner_decomposed).with_qubit_mapping({q0: q1, q1: q0}),
    )
    expected = cirq.Circuit(
        cirq.CircuitOperation(expected_outer),
        cirq.measure(q0, q1, key='m'),
    )
    assert actual == expected
251
252
# Test both intercepting and fallback decomposers.
@pytest.mark.parametrize('decompose_mode', ['intercept', 'fallback'])
def test_decompose_preserving_structure_forwards_args(decompose_mode):
    """``preserve_structure=True`` forwards ``keep`` and custom decomposers into subcircuits."""
    a, b = cirq.LineQubit.range(2)
    fc1 = cirq.FrozenCircuit(cirq.SWAP(a, b), cirq.FSimGate(0.1, 0.2).on(a, b))
    cop1_1 = cirq.CircuitOperation(fc1).with_tags('test_tag')
    cop1_2 = cirq.CircuitOperation(fc1).with_qubit_mapping({a: b, b: a})
    fc2 = cirq.FrozenCircuit(cirq.X(a), cop1_1, cop1_2)
    cop2 = cirq.CircuitOperation(fc2)

    circuit = cirq.Circuit(cop2, cirq.measure(a, b, key='m'))

    def keep_func(op: 'cirq.Operation'):
        # Only decompose SWAP and X.
        return not isinstance(op.gate, (cirq.SwapPowGate, cirq.XPowGate))

    def x_to_hzh(op: 'cirq.Operation'):
        # Rewrite a full X as H-Z-H. Anything else is explicitly declined with
        # NotImplemented so decomposition falls back to the other strategies.
        if isinstance(op.gate, cirq.XPowGate) and op.gate.exponent == 1:
            return [
                cirq.H(*op.qubits),
                cirq.Z(*op.qubits),
                cirq.H(*op.qubits),
            ]
        return NotImplemented

    actual = cirq.Circuit(
        cirq.decompose(
            circuit,
            keep=keep_func,
            intercepting_decomposer=x_to_hzh if decompose_mode == 'intercept' else None,
            fallback_decomposer=x_to_hzh if decompose_mode == 'fallback' else None,
            preserve_structure=True,
        ),
    )

    # This should keep the CircuitOperations, rewrite the top-level X as H-Z-H,
    # and decompose the SWAPs inside fc1 (stopping wherever keep_func allows).
    fc1_decomp = cirq.FrozenCircuit(
        cirq.decompose(
            fc1,
            keep=keep_func,
            fallback_decomposer=x_to_hzh,
        )
    )
    expected = cirq.Circuit(
        cirq.CircuitOperation(
            cirq.FrozenCircuit(
                cirq.H(a),
                cirq.Z(a),
                cirq.H(a),
                cirq.CircuitOperation(fc1_decomp).with_tags('test_tag'),
                cirq.CircuitOperation(fc1_decomp).with_qubit_mapping({a: b, b: a}),
            )
        ),
        cirq.measure(a, b, key='m'),
    )
    assert actual == expected
308