1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union
16
17import numpy as np
18
19import cirq
20from cirq import protocols, value
21from cirq.ops import raw_types, controlled_operation as cop
22from cirq.type_workarounds import NotImplementedType
23
24
@value.value_equality
class ControlledGate(raw_types.Gate):
    """Augments existing gates to have one or more control qubits.

    This object is typically created via `gate.controlled()`.

    Attributes:
        sub_gate: The gate applied to the target qids when every control is
            enabled. Nested `ControlledGate`s are flattened in `__init__`, so
            this is never itself a `ControlledGate`.
        control_values: One sorted tuple of enabling values per control qid,
            in the same order as the control qids.
        control_qid_shape: The dimension of each control qid, in order.
    """

    def __init__(
        self,
        sub_gate: 'cirq.Gate',
        num_controls: Optional[int] = None,
        control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
        control_qid_shape: Optional[Sequence[int]] = None,
    ) -> None:
        """Initializes the controlled gate. If no arguments are specified for
           the controls, defaults to a single qubit control.

        Args:
            sub_gate: The gate to add a control qubit to.
            num_controls: Total number of control qubits.
            control_values: For which control qubit values to apply the sub
                gate.  A sequence of length `num_controls` where each
                entry is an integer (or set of integers) corresponding to the
                qubit value (or set of possible values) where that control is
                enabled.  When all controls are enabled, the sub gate is
                applied.  If unspecified, control values default to 1.
            control_qid_shape: The qid shape of the controls.  A tuple of the
                expected dimension of each control qid.  Defaults to
                `(2,) * num_controls`.  Specify this argument when using qudits.

        Raises:
            ValueError: If `control_values` or `control_qid_shape` has a length
                different from `num_controls`, or if a control value lies
                outside `[0, dimension)` for its control qid.
        """
        if num_controls is None:
            if control_values is not None:
                num_controls = len(control_values)
            elif control_qid_shape is not None:
                num_controls = len(control_qid_shape)
            else:
                num_controls = 1
        if control_values is None:
            control_values = ((1,),) * num_controls
        if num_controls != len(control_values):
            raise ValueError('len(control_values) != num_controls')

        if control_qid_shape is None:
            control_qid_shape = (2,) * num_controls
        if num_controls != len(control_qid_shape):
            raise ValueError('len(control_qid_shape) != num_controls')
        self.control_qid_shape = tuple(control_qid_shape)

        # Normalize every entry to a sorted tuple. Scalar entries may be plain
        # ints or NumPy integer scalars (e.g. items taken from an np.ndarray);
        # NumPy scalars are not iterable, so they must be recognized here
        # instead of falling through to `sorted(val)`.
        self.control_values = cast(
            Tuple[Tuple[int, ...], ...],
            tuple(
                (int(val),) if isinstance(val, (int, np.integer)) else tuple(sorted(val))
                for val in control_values
            ),
        )
        # Verify control values not out of bounds
        for i, (val, dimension) in enumerate(zip(self.control_values, self.control_qid_shape)):
            if not all(0 <= v < dimension for v in val):
                raise ValueError(
                    'Control values <{!r}> outside of range for control qubit '
                    'number <{!r}>.'.format(val, i)
                )

        # Flatten nested ControlledGates: the outer controls come first, then
        # the inner gate's controls, matching the qubit order used by `on`.
        if isinstance(sub_gate, ControlledGate):
            self.sub_gate = sub_gate.sub_gate  # type: ignore
            self.control_values += sub_gate.control_values
            self.control_qid_shape += sub_gate.control_qid_shape
        else:
            self.sub_gate = sub_gate

    def num_controls(self) -> int:
        """Returns the number of control qids."""
        return len(self.control_qid_shape)

    def _qid_shape_(self) -> Tuple[int, ...]:
        # Control qids come first, followed by the sub gate's qids.
        return self.control_qid_shape + cirq.qid_shape(self.sub_gate)

    def _decompose_(self, qubits):
        """Decomposes the sub gate once and controls each resulting operation."""
        result = protocols.decompose_once_with_qubits(
            self.sub_gate, qubits[self.num_controls() :], NotImplemented
        )

        if result is NotImplemented:
            return NotImplemented

        controls = qubits[: self.num_controls()]
        return [cop.ControlledOperation(controls, op, self.control_values) for op in result]

    def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
        """Applies the gate; the first `num_controls()` qids are the controls.

        Raises:
            ValueError: If no qubits are given or they fail `validate_args`.
        """
        if len(qubits) == 0:
            raise ValueError(f"Applied a gate to an empty set of qubits. Gate: {self!r}")
        self.validate_args(qubits)
        return cop.ControlledOperation(
            qubits[: self.num_controls()],
            self.sub_gate.on(*qubits[self.num_controls() :]),
            self.control_values,
        )

    def _value_equality_values_(self):
        # Controls are positional: the i-th control value applies to the i-th
        # qid the gate is applied to. Equality must therefore respect both the
        # order and the multiplicity of the controls. (A frozenset made e.g.
        # C0C1X equal to C1C0X, and collapsed repeated (values, dim) pairs.)
        return (
            self.sub_gate,
            self.num_controls(),
            self.control_values,
            self.control_qid_shape,
        )

    def _line_qid_controlled_op(self) -> cop.ControlledOperation:
        """Returns this gate applied to `cirq.LineQid`s matching its qid shape.

        The unitary/mixture protocols only need some concrete qids so they can
        delegate to `ControlledOperation`.
        """
        qubits = cirq.LineQid.for_gate(self)
        op = self.sub_gate.on(*qubits[self.num_controls() :])
        return cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)

    def _apply_unitary_(
        self, args: 'protocols.ApplyUnitaryArgs'
    ) -> Union[np.ndarray, NotImplementedType]:
        return protocols.apply_unitary(self._line_qid_controlled_op(), args, default=NotImplemented)

    def _has_unitary_(self) -> bool:
        return protocols.has_unitary(self.sub_gate)

    def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
        return protocols.unitary(self._line_qid_controlled_op(), default=NotImplemented)

    def _has_mixture_(self) -> bool:
        return protocols.has_mixture(self.sub_gate)

    def _mixture_(self) -> Union[Sequence[Tuple[float, np.ndarray]], NotImplementedType]:
        return protocols.mixture(self._line_qid_controlled_op(), default=NotImplemented)

    def __pow__(self, exponent: Any) -> 'ControlledGate':
        """Raises the sub gate to `exponent`, keeping the same controls."""
        new_sub_gate = protocols.pow(self.sub_gate, exponent, NotImplemented)
        if new_sub_gate is NotImplemented:
            return NotImplemented
        return ControlledGate(
            new_sub_gate,
            self.num_controls(),
            control_values=self.control_values,
            control_qid_shape=self.control_qid_shape,
        )

    def _is_parameterized_(self) -> bool:
        return protocols.is_parameterized(self.sub_gate)

    def _parameter_names_(self) -> AbstractSet[str]:
        return protocols.parameter_names(self.sub_gate)

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'ControlledGate':
        new_sub_gate = protocols.resolve_parameters(self.sub_gate, resolver, recursive)
        return ControlledGate(
            new_sub_gate,
            self.num_controls(),
            control_values=self.control_values,
            control_qid_shape=self.control_qid_shape,
        )

    def _trace_distance_bound_(self) -> Union[Optional[float], NotImplementedType]:
        """Bounds the trace distance from the sub gate's unitary eigenphases.

        Returns None when parameterized, and NotImplemented when the sub gate
        has no unitary.
        """
        if self._is_parameterized_():
            return None
        u = protocols.unitary(self.sub_gate, default=None)
        if u is None:
            return NotImplemented
        # The appended 0 accounts for the identity action when the controls
        # are not enabled.
        angle_list = np.append(np.angle(np.linalg.eigvals(u)), 0)
        return protocols.trace_distance_from_angle_list(angle_list)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> 'cirq.CircuitDiagramInfo':
        # Describe only the sub gate's qubits when asking for its diagram info.
        sub_args = protocols.CircuitDiagramInfoArgs(
            known_qubit_count=(
                args.known_qubit_count - self.num_controls()
                if args.known_qubit_count is not None
                else None
            ),
            known_qubits=(
                args.known_qubits[self.num_controls() :] if args.known_qubits is not None else None
            ),
            use_unicode_characters=args.use_unicode_characters,
            precision=args.precision,
            qubit_map=args.qubit_map,
        )
        sub_info = protocols.circuit_diagram_info(self.sub_gate, sub_args, None)
        if sub_info is None:
            return NotImplemented

        def get_symbol(vals):
            # '@' for the standard control-on-1; otherwise list the values.
            if tuple(vals) == (1,):
                return '@'
            return f"({','.join(map(str, vals))})"

        return protocols.CircuitDiagramInfo(
            wire_symbols=(
                *(get_symbol(vals) for vals in self.control_values),
                *sub_info.wire_symbols,
            ),
            exponent=sub_info.exponent,
        )

    def __str__(self) -> str:
        # When every control is the default control-on-1, each control is a
        # bare 'C'; otherwise every control shows its enabling values.
        if set(self.control_values) == {(1,)}:

            def get_prefix(control_vals):
                return 'C'

        else:

            def get_prefix(control_vals):
                control_vals_str = ''.join(map(str, sorted(control_vals)))
                return f'C{control_vals_str}'

        return ''.join(map(get_prefix, self.control_values)) + str(self.sub_gate)

    def __repr__(self) -> str:
        default_values = all(vals == (1,) for vals in self.control_values)
        all_qubits = all(dim == 2 for dim in self.control_qid_shape)
        if default_values and all_qubits:
            # The short forms rely on __init__'s defaults (control value 1,
            # dimension 2), so they are only valid for plain qubit controls.
            if self.num_controls() == 1:
                return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'
            return (
                f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
                f'num_controls={self.num_controls()!r})'
            )
        return (
            f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
            f'control_values={self.control_values!r}, '
            f'control_qid_shape={self.control_qid_shape!r})'
        )

    def _json_dict_(self) -> Dict[str, Any]:
        return {
            'cirq_type': self.__class__.__name__,
            'control_values': self.control_values,
            'control_qid_shape': self.control_qid_shape,
            'sub_gate': self.sub_gate,
        }
264