1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import (
16    TYPE_CHECKING,
17    Any,
18    Callable,
19    Dict,
20    Iterable,
21    List,
22    Optional,
23    overload,
24    Sequence,
25    Tuple,
26    TypeVar,
27    Union,
28)
29from collections import defaultdict
30
31from typing_extensions import Protocol
32
33from cirq import devices, ops
34from cirq._doc import doc_private
35from cirq.protocols import qid_shape_protocol
36from cirq.type_workarounds import NotImplementedType
37
38if TYPE_CHECKING:
39    import cirq
40
# Type variable for the caller-supplied `default` value accepted by the
# `decompose_once*` overloads in this module.
TDefault = TypeVar('TDefault')

# Type variable restricted to `Exception` subclasses; the `decompose` docstring
# uses this name for caller-provided error types.
TError = TypeVar('TError', bound=Exception)

# Sentinel used as the default for `default` parameters. It is compared by
# identity (`is not`), so any value the caller passes -- including `None` --
# counts as an explicitly provided default.
RaiseTypeErrorIfNotProvided: Any = ([],)

# What a `_decompose_` method may return: `None` / `NotImplemented` (treated as
# "no decomposition" by `decompose_once`) or an op tree.
DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE']
# Signature of the `intercepting_decomposer` / `fallback_decomposer` callbacks.
OpDecomposer = Callable[['cirq.Operation'], DecomposeResult]
49
50
51def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError:
52    return ValueError(f"Operation doesn't satisfy the given `keep` but can't be decomposed: {op!r}")
53
54
class SupportsDecompose(Protocol):
    """An object that can be decomposed into simpler operations.

    All decomposition methods should ultimately terminate on basic 1-qubit and
    2-qubit gates included by default in Cirq. Cirq does not make any guarantees
    about what the final gate set is. Currently, decompositions within Cirq
    happen to converge towards the X, Y, Z, CZ, PhasedX, specified-matrix gates,
    and others. This set will vary from release to release. Because of this
    variability, it is important for consumers of decomposition to look for
    generic properties of gates, such as "two qubit gate with a unitary matrix",
    instead of specific gate types such as CZ gates (though a consumer is
    of course free to handle CZ gates in a special way, and consumers can
    give an `intercepting_decomposer` to `cirq.decompose` that attempts to
    target a specific gate set).

    For example, `cirq.TOFFOLI` has a `_decompose_` method that returns a pair
    of Hadamard gates surrounding a `cirq.CCZ`. Although `cirq.CCZ` is not a
    1-qubit or 2-qubit operation, it specifies its own `_decompose_` method
    that only returns 1-qubit or 2-qubit operations. This means that iteratively
    decomposing `cirq.TOFFOLI` terminates in 1-qubit and 2-qubit operations, and
    so almost all decomposition-aware code will be able to handle `cirq.TOFFOLI`
    instances.

    Callers are responsible for iteratively decomposing until they are given
    operations that they understand. The `cirq.decompose` method is a simple way
    to do this, because it has logic to recursively decompose until a given
    `keep` predicate is satisfied.

    Code implementing `_decompose_` MUST NOT create cycles, such as a gate A
    decomposes into a gate B which decomposes back into gate A. This will result
    in infinite loops when calling `cirq.decompose`.

    It is permitted (though not recommended) for the chain of decompositions
    resulting from an operation to hit a dead end before reaching 1-qubit or
    2-qubit operations. When this happens and `cirq.decompose` was given a
    `keep` predicate that the stuck operation does not satisfy, it raises a
    `ValueError` by default, but it can be configured to ignore the issue or
    raise a caller-provided error. Without a `keep` predicate the stuck
    operation is simply included in the output.
    """

    @doc_private
    def _decompose_(self) -> DecomposeResult:
        """Returns an op tree to replace this object with.

        Returning `None` or `NotImplemented` signals that no decomposition is
        available.
        """
        pass
97
98
class SupportsDecomposeWithQubits(Protocol):
    """An object that can be decomposed into operations on given qubits.

    Returning `NotImplemented` or `None` means "not decomposable". Otherwise an
    operation, list of operations, or generally anything meeting the `OP_TREE`
    contract can be returned.

    For example, a SWAP gate can be turned into three CNOTs. But in order to
    describe those CNOTs one must be able to talk about "the target qubit" and
    "the control qubit". This can only be done once the qubits-to-be-swapped are
    known.

    The main user of this protocol is `GateOperation`, which decomposes itself
    by delegating to its gate. The qubits argument is needed because gates are
    specified independently of target qubits and so must be told the relevant
    qubits. A `GateOperation` implements `SupportsDecompose` as long as its gate
    implements `SupportsDecomposeWithQubits`.
    """

    def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult:
        """Returns an op tree acting on `qubits` to replace this object with.

        Args:
            qubits: The qubits the object is being applied to, as a tuple.

        Returns:
            An op tree, or `None` / `NotImplemented` when no decomposition is
            available.
        """
        pass
120
121
def decompose(
    val: Any,
    *,
    intercepting_decomposer: Optional[OpDecomposer] = None,
    fallback_decomposer: Optional[OpDecomposer] = None,
    keep: Optional[Callable[['cirq.Operation'], bool]] = None,
    on_stuck_raise: Union[
        None, Exception, Callable[['cirq.Operation'], Union[None, Exception]]
    ] = _value_error_describing_bad_operation,
    preserve_structure: bool = False,
) -> List['cirq.Operation']:
    """Recursively decomposes a value into `cirq.Operation`s meeting a criteria.

    Args:
        val: The value to decompose into operations.
        intercepting_decomposer: An optional method that is called before the
            default decomposer (the value's `_decompose_` method). If
            `intercepting_decomposer` is specified and returns a result that
            isn't `NotImplemented` or `None`, that result is used. Otherwise the
            decomposition falls back to the default decomposer.

            Note that `intercepting_decomposer` is only invoked on values that
            are `cirq.Operation` instances; other values skip straight to the
            default decomposer.
        fallback_decomposer: An optional decomposition that is used after the
            `intercepting_decomposer` and the default decomposer (the value's
            `_decompose_` method) both fail. Like `intercepting_decomposer`, it
            is only invoked on `cirq.Operation` instances.
        keep: A predicate that determines if the initial operation or
            intermediate decomposed operations should be kept or else need to be
            decomposed further. If `keep` isn't specified, it defaults to "value
            can't be decomposed anymore".
        on_stuck_raise: If there is an operation that can't be decomposed and
            also can't be kept, `on_stuck_raise` is used to determine what error
            to raise. `on_stuck_raise` can either directly be an `Exception`, or
            a method that takes the problematic operation and returns an
            `Exception`. If `on_stuck_raise` is set to `None` or a method that
            returns `None`, non-decomposable operations are simply silently
            kept. `on_stuck_raise` defaults to a `ValueError` describing the
            unwanted non-decomposable operation. It may only be overridden when
            `keep` is also given.
        preserve_structure: Prevents subcircuits (i.e. `CircuitOperation`s)
            from being decomposed, but decomposes their contents. The work is
            delegated to `_decompose_preserving_structure`, which receives all
            other arguments (including `intercepting_decomposer`) unchanged.

    Returns:
        A list of the items that the given value was decomposed into, in
        order. If `keep` is given and `on_stuck_raise` isn't set to None, all
        operations in the list will satisfy the predicate specified by `keep`.
        If `keep` is not given, items that cannot be decomposed any further
        are returned as they are.

    Raises:
        TypeError:
            Not raised directly by this function.
            NOTE(review): may propagate from `ops.flatten_to_ops` when a
            decomposition yields something that is not an op tree -- confirm.

        ValueError:
            `on_stuck_raise` was specified without `keep`. Also the default
            type of error raised if there's a non-decomposable operation that
            doesn't satisfy the given `keep` predicate.

        TError:
            Custom type of error raised if there's a non-decomposable operation
            that doesn't satisfy the given `keep` predicate.
    """

    if on_stuck_raise is not _value_error_describing_bad_operation and keep is None:
        raise ValueError(
            "Must specify 'keep' if specifying 'on_stuck_raise', because it's "
            "not possible to get stuck if you don't have a criteria on what's "
            "acceptable to keep."
        )

    if preserve_structure:
        return _decompose_preserving_structure(
            val,
            intercepting_decomposer=intercepting_decomposer,
            fallback_decomposer=fallback_decomposer,
            keep=keep,
            on_stuck_raise=on_stuck_raise,
        )

    def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult:
        # User-supplied decomposers are typed to take operations, so anything
        # else is reported as "no decomposition" without calling them.
        if decomposer is None or not isinstance(val, ops.Operation):
            return None
        return decomposer(val)

    output = []
    # Pending work is kept as a LIFO stack whose *last* element is the next
    # item to process. Sub-items are pushed in reverse so they are still
    # handled in their original order, but both push and pop are O(1) instead
    # of the O(n) cost of inserting at / popping from the front of a list.
    pending: List[Any] = [val]
    while pending:
        item = pending.pop()
        if isinstance(item, ops.Operation) and keep is not None and keep(item):
            output.append(item)
            continue

        # Try the three decomposition sources in priority order, moving on
        # whenever one reports "no decomposition" (None / NotImplemented).
        decomposed = try_op_decomposer(item, intercepting_decomposer)

        if decomposed is NotImplemented or decomposed is None:
            decomposed = decompose_once(item, default=None)

        if decomposed is NotImplemented or decomposed is None:
            decomposed = try_op_decomposer(item, fallback_decomposer)

        if decomposed is not NotImplemented and decomposed is not None:
            # Materialize before pushing so that an error while flattening
            # leaves the pending stack untouched.
            pending.extend(reversed(list(ops.flatten_to_ops(decomposed))))
            continue

        # Non-operation iterables (e.g. nested op trees) are unpacked so that
        # their elements are examined individually.
        if not isinstance(item, ops.Operation) and isinstance(item, Iterable):
            pending.extend(reversed(list(ops.flatten_to_ops(item))))
            continue

        # Stuck: the item can be neither decomposed nor (if `keep` was given)
        # kept. Without `keep` there is no acceptance criterion, so it is kept.
        if keep is not None and on_stuck_raise is not None:
            if isinstance(on_stuck_raise, Exception):
                raise on_stuck_raise
            elif callable(on_stuck_raise):
                error = on_stuck_raise(item)
                if error is not None:
                    raise error

        output.append(item)

    return output
239
240
241# pylint: disable=function-redefined
242
243
@overload
def decompose_once(val: Any, **kwargs) -> List['cirq.Operation']:
    ...


@overload
def decompose_once(
    val: Any, default: TDefault, *args, **kwargs
) -> Union[TDefault, List['cirq.Operation']]:
    ...


def decompose_once(val: Any, default=RaiseTypeErrorIfNotProvided, *args, **kwargs):
    """Applies a single level of decomposition to a value, if it supports one.

    Unlike `cirq.decompose`, the operations produced here are not decomposed
    any further.

    Args:
        val: The value whose `_decompose_` method is called, if it has one.
        default: Result to hand back when `val` has no `_decompose_` method or
            when that method returns `NotImplemented` or `None`. When omitted,
            such values cause a `TypeError` instead.
        *args: Positional arguments forwarded to `val._decompose_`. For
            example, this is used to tell gates what qubits they are being
            applied to.
        **kwargs: Keyword arguments forwarded to `val._decompose_`.

    Returns:
        The flattened result of `val._decompose_(*args, **kwargs)` as a list,
        provided the method exists and returned something other than
        `NotImplemented` or `None`. Otherwise `default`, if it was given.

    Raises:
        TypeError: `val` didn't have a `_decompose_` method (or that method
            returned `NotImplemented` or `None`) and `default` wasn't set.
    """
    decomposer = getattr(val, '_decompose_', None)
    if decomposer is not None:
        result = decomposer(*args, **kwargs)
        if result is not None and result is not NotImplemented:
            return list(ops.flatten_op_tree(result))

    # No usable decomposition: fall back to the default, or explain why not.
    if default is not RaiseTypeErrorIfNotProvided:
        return default

    if decomposer is None:
        raise TypeError(f"object of type '{type(val)}' has no _decompose_ method.")

    raise TypeError(
        "object of type '{}' does have a _decompose_ method, "
        "but it returned NotImplemented or None.".format(type(val))
    )
299
300
@overload
def decompose_once_with_qubits(val: Any, qubits: Iterable['cirq.Qid']) -> List['cirq.Operation']:
    ...


@overload
def decompose_once_with_qubits(
    val: Any,
    qubits: Iterable['cirq.Qid'],
    default: Optional[TDefault],
) -> Union[TDefault, List['cirq.Operation']]:
    ...


def decompose_once_with_qubits(
    val: Any, qubits: Iterable['cirq.Qid'], default=RaiseTypeErrorIfNotProvided
):
    """Applies a single level of decomposition to a value on the given qubits.

    Intended for gates, which do not know which qubits they act on until told.
    The resulting operations are not decomposed any further.

    Args:
        val: The value whose `_decompose_` method is called, if it has one.
        qubits: The qubits to decompose onto. They are collected into a tuple
            and passed to `val._decompose_` as its single positional argument.
        default: Result to hand back when `val` has no `_decompose_` method or
            when that method returns `NotImplemented` or `None`. When omitted,
            such values cause a `TypeError` instead.

    Returns:
        The result of `val._decompose_(qubits)`, if `val` has a `_decompose_`
        method and it didn't return `NotImplemented` or `None`. Otherwise
        `default`, if it was given.

    Raises:
        TypeError: `val` didn't have a `_decompose_` method (or that method
            returned `NotImplemented` or `None`) and `default` wasn't set.
    """
    qubit_tuple = tuple(qubits)
    return decompose_once(val, default, qubit_tuple)
345
346
347# pylint: enable=function-redefined
348
349
def _try_decompose_into_operations_and_qubits(
    val: Any,
) -> Tuple[Optional[List['cirq.Operation']], Sequence['cirq.Qid'], Tuple[int, ...]]:
    """Decomposes a value once and reports the qubits and qid shape involved.

    Args:
        val: A gate, an operation, or any other value that may have a
            `_decompose_` method.

    Returns:
        A `(operations, qubits, qid_shape)` triple. `operations` is `None` when
        the value cannot be decomposed. For values that are neither gates nor
        operations and cannot be decomposed, `qubits` and `qid_shape` are both
        empty tuples.
    """
    if isinstance(val, ops.Gate):
        # A gate carries no qubits of its own, so placeholder line qids
        # matching its shape are created to decompose onto.
        gate_shape = qid_shape_protocol.qid_shape(val)
        line_qids: Sequence['cirq.Qid'] = devices.LineQid.for_qid_shape(gate_shape)
        gate_parts = decompose_once_with_qubits(val, line_qids, None)
        return gate_parts, line_qids, gate_shape

    if isinstance(val, ops.Operation):
        op_shape = qid_shape_protocol.qid_shape(val)
        op_parts = decompose_once(val, None)
        return op_parts, val.qubits, op_shape

    parts = decompose_once(val, None)
    if parts is None:
        return None, (), ()

    # Collect every qubit touched by the decomposition, remembering the
    # largest dimension seen for each one (never less than 1).
    seen = set()
    levels: Dict['cirq.Qid', int] = {}
    for part in parts:
        for dim, qid in zip(qid_shape_protocol.qid_shape(part), part.qubits):
            seen.add(qid)
            levels[qid] = max(levels.get(qid, 1), dim)
    ordered = sorted(seen)
    return parts, ordered, tuple(levels[qid] for qid in ordered)
377
378
def _decompose_preserving_structure(
    val: Any,
    *,
    intercepting_decomposer: Optional[OpDecomposer] = None,
    fallback_decomposer: Optional[OpDecomposer] = None,
    keep: Optional[Callable[['cirq.Operation'], bool]] = None,
    on_stuck_raise: Union[
        None, Exception, Callable[['cirq.Operation'], Union[None, Exception]]
    ] = _value_error_describing_bad_operation,
) -> List['cirq.Operation']:
    """Decomposes ops while leaving structure (e.g. subcircuits) in place.

    The contents of each `CircuitOperation` are decomposed and the operation
    is rebuilt around the result, instead of being flattened into its parent.
    This can be used to reduce a circuit to a particular gateset without
    increasing its serialization size. See tests for examples.

    Args:
        val: The value to decompose.
        intercepting_decomposer: Optional decomposer consulted for operations
            that are not `CircuitOperation`s.
        fallback_decomposer: Passed through to `decompose`.
        keep: Optional predicate selecting operations to keep as they are.
        on_stuck_raise: Passed through to `decompose`; ignored when `keep` is
            not given.

    Returns:
        The list of operations produced by `decompose`.
    """
    from cirq.circuits import CircuitOperation, FrozenCircuit

    # `decompose` is always handed a generated `keep` below. When the caller
    # gave no `keep` of their own, getting stuck must not produce errors, so
    # the error hook is switched off.
    if keep is None:
        on_stuck_raise = None

    # Frozen circuits that have already been through decomposition.
    decomposed_circuits = set()

    def _recurse(target: Any) -> List['cirq.Operation']:
        return decompose(
            target,
            intercepting_decomposer=dps_interceptor,
            fallback_decomposer=fallback_decomposer,
            keep=keep_structure,
            on_stuck_raise=on_stuck_raise,
        )

    def keep_structure(op: 'cirq.Operation'):
        inner = getattr(op.untagged, 'circuit', None)
        if inner is None:
            return keep is not None and bool(keep(op))
        # An op wrapping a circuit is kept only once that circuit has been
        # decomposed.
        return inner in decomposed_circuits

    def dps_interceptor(op: 'cirq.Operation'):
        bare = op.untagged
        if isinstance(bare, CircuitOperation):
            new_fc = FrozenCircuit(_recurse(bare.circuit))
            decomposed_circuits.add(new_fc)
            rebuilt = bare.replace(circuit=new_fc)
            # Carry over any tags that were on the original operation.
            return rebuilt.with_tags(*op.tags) if op.tags else rebuilt

        if intercepting_decomposer is None:
            return NotImplemented
        return intercepting_decomposer(op)

    return _recurse(val)
437