1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Defines GateOperation, the application of a gate to a sequence of qubits."""
16
17import re
18from typing import (
19    AbstractSet,
20    Any,
21    cast,
22    Collection,
23    Dict,
24    FrozenSet,
25    List,
26    Optional,
27    Sequence,
28    Tuple,
29    TypeVar,
30    TYPE_CHECKING,
31    Union,
32)
33
34import numpy as np
35
36from cirq import protocols, value
37from cirq.ops import raw_types, gate_features
38from cirq.type_workarounds import NotImplementedType
39
40if TYPE_CHECKING:
41    import cirq
42
43
# Self-type variable for GateOperation methods (e.g. `with_qubits`) that are
# annotated to return the same type as the instance they are called on.
TSelf = TypeVar('TSelf', bound='GateOperation')
45
46
@value.value_equality(approximate=True)
class GateOperation(raw_types.Operation):
    """An application of a gate to a sequence of qubits.

    Objects of this type are immutable.

    Most protocol hooks on this class (``_unitary_``, ``_mixture_``,
    ``_is_parameterized_``, ...) forward to the identically named method of
    ``self.gate`` when the gate defines one, and otherwise return
    ``NotImplemented`` so the calling protocol can apply its own fallback.
    """

    def __init__(self, gate: 'cirq.Gate', qubits: Sequence['cirq.Qid']) -> None:
        """Inits GateOperation.

        The qubits are checked with ``gate.validate_args`` before anything is
        stored, so the operation is never built on qubits the gate rejects.

        Args:
            gate: The gate to apply.
            qubits: The qubits to operate on.
        """
        gate.validate_args(qubits)
        self._gate = gate
        self._qubits = tuple(qubits)

    @property
    def gate(self) -> 'cirq.Gate':
        """The gate applied by the operation."""
        return self._gate

    @property
    def qubits(self) -> Tuple['cirq.Qid', ...]:
        """The qubits targeted by the operation."""
        return self._qubits

    def _forward_to_gate(self, method_name: str, *args: Any) -> Any:
        """Calls ``self.gate.<method_name>(*args)`` when the gate defines it.

        Shared implementation of the protocol hooks that simply delegate to
        the gate.

        Args:
            method_name: Name of the method to look up on the gate.
            *args: Positional arguments passed through unchanged.

        Returns:
            The gate method's return value, or ``NotImplemented`` when the gate
            has no attribute with that name (or the attribute is ``None``).
        """
        method = getattr(self.gate, method_name, None)
        if method is None:
            return NotImplemented
        return method(*args)

    def with_qubits(self: TSelf, *new_qubits: 'cirq.Qid') -> TSelf:
        """Returns the same gate applied to ``new_qubits`` (via ``gate.on``)."""
        return cast(TSelf, self.gate.on(*new_qubits))

    def with_gate(self, new_gate: 'cirq.Gate') -> 'cirq.Operation':
        """Returns ``new_gate`` applied to this operation's qubits.

        Returns ``self`` when ``new_gate`` is the very same gate object.
        """
        if self.gate is new_gate:
            # As GateOperation is immutable, this can return the original.
            return self
        return new_gate.on(*self.qubits)

    def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
        """Remaps the gate's measurement keys using ``key_map``.

        Returns:
            ``NotImplemented`` if ``protocols.with_measurement_key_mapping``
            returns it for the gate; otherwise ``self.with_gate`` applied to the
            remapped gate (which is ``self`` if the gate object is unchanged).
        """
        new_gate = protocols.with_measurement_key_mapping(self.gate, key_map)
        if new_gate is NotImplemented:
            return NotImplemented
        return self.with_gate(new_gate)

    def _with_key_path_(self, path: Tuple[str, ...]):
        """Passes ``path`` to ``protocols.with_key_path`` for the gate.

        Returns:
            ``NotImplemented`` if the protocol returns it for the gate;
            otherwise ``self.with_gate`` applied to the updated gate (which is
            ``self`` if the gate object is unchanged).
        """
        new_gate = protocols.with_key_path(self.gate, path)
        if new_gate is NotImplemented:
            return NotImplemented
        return self.with_gate(new_gate)

    def __repr__(self):
        # Let the gate supply a custom operation repr if it provides one.
        if hasattr(self.gate, '_op_repr_'):
            result = self.gate._op_repr_(self.qubits)
            if result is not None and result is not NotImplemented:
                return result
        gate_repr = repr(self.gate)
        qubit_args_repr = ', '.join(repr(q) for q in self.qubits)
        # The abbreviated `gate(q0, q1)` form below relies on calling the gate
        # going through the standard Gate.__call__.
        assert type(self.gate).__call__ == raw_types.Gate.__call__

        # Abbreviate when possible. A repr made only of identifiers, dots and
        # parentheses can be followed directly by a call.
        dont_need_on = re.match(r'^[a-zA-Z0-9.()]+$', gate_repr)
        if dont_need_on and self == self.gate.__call__(*self.qubits):
            return f'{gate_repr}({qubit_args_repr})'
        if self == self.gate.on(*self.qubits):
            return f'{gate_repr}.on({qubit_args_repr})'

        return f'cirq.GateOperation(gate={self.gate!r}, qubits=[{qubit_args_repr}])'

    def __str__(self) -> str:
        qubits = ', '.join(str(e) for e in self.qubits)
        return f'{self.gate}({qubits})'

    def _json_dict_(self) -> Dict[str, Any]:
        return protocols.obj_to_dict_helper(self, ['gate', 'qubits'])

    def _group_interchangeable_qubits(
        self,
    ) -> Tuple[Union['cirq.Qid', Tuple[int, FrozenSet['cirq.Qid']]], ...]:
        """Returns the qubits in a form that ignores interchangeable ordering.

        If the gate is not an ``InterchangeableQubitsGate``, this is just the
        qubit tuple. Otherwise qubits are grouped by the gate's equivalence
        group key, and the result is a tuple of ``(key, frozenset(qubits))``
        pairs sorted by key.
        """
        if not isinstance(self.gate, gate_features.InterchangeableQubitsGate):
            return self.qubits

        groups: Dict[int, List['cirq.Qid']] = {}
        for i, q in enumerate(self.qubits):
            key = self.gate.qubit_index_to_equivalence_group_key(i)
            groups.setdefault(key, []).append(q)
        return tuple(sorted((k, frozenset(v)) for k, v in groups.items()))

    def _value_equality_values_(self):
        # Equality compares the gate plus the qubits, where qubits the gate
        # declares interchangeable are compared as unordered groups.
        return self.gate, self._group_interchangeable_qubits()

    def _qid_shape_(self):
        return self.gate._qid_shape_()

    def _num_qubits_(self):
        return len(self._qubits)

    def _decompose_(self) -> 'cirq.OP_TREE':
        return protocols.decompose_once_with_qubits(self.gate, self.qubits, NotImplemented)

    def _pauli_expansion_(self) -> value.LinearDict[str]:
        """Delegates to the gate's ``_pauli_expansion_``, if any."""
        return self._forward_to_gate('_pauli_expansion_')

    def _apply_unitary_(
        self, args: 'protocols.ApplyUnitaryArgs'
    ) -> Union[np.ndarray, None, NotImplementedType]:
        """Delegates to the gate's ``_apply_unitary_(args)``, if any."""
        return self._forward_to_gate('_apply_unitary_', args)

    def _has_unitary_(self) -> bool:
        """Delegates to the gate's ``_has_unitary_``, if any."""
        return self._forward_to_gate('_has_unitary_')

    def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
        """Delegates to the gate's ``_unitary_``, if any."""
        return self._forward_to_gate('_unitary_')

    def _commutes_(
        self, other: Any, atol: Union[int, float] = 1e-8
    ) -> Union[bool, NotImplementedType, None]:
        """Asks the gate first, then falls back to ``Operation._commutes_``."""
        commutes = self.gate._commutes_on_qids_(self.qubits, other, atol=atol)
        if commutes is not NotImplemented:
            return commutes

        return super()._commutes_(other, atol=atol)

    def _has_mixture_(self) -> bool:
        """Delegates to the gate's ``_has_mixture_``, if any."""
        return self._forward_to_gate('_has_mixture_')

    def _mixture_(self) -> Sequence[Tuple[float, Any]]:
        """Delegates to the gate's ``_mixture_``, if any."""
        return self._forward_to_gate('_mixture_')

    def _has_kraus_(self) -> bool:
        """Delegates to the gate's ``_has_kraus_``, if any."""
        return self._forward_to_gate('_has_kraus_')

    def _kraus_(self) -> Union[Tuple[np.ndarray, ...], NotImplementedType]:
        """Delegates to the gate's ``_kraus_``, if any.

        The return annotation uses the variadic ``Tuple[np.ndarray, ...]``:
        a channel may have any number of Kraus operators, not exactly one.
        """
        return self._forward_to_gate('_kraus_')

    def _is_measurement_(self) -> Optional[bool]:
        """Delegates to the gate's ``_is_measurement_``, if any."""
        # Returning NotImplemented lets the protocol handle the fallback.
        return self._forward_to_gate('_is_measurement_')

    def _measurement_key_name_(self) -> Optional[str]:
        """Delegates to the gate's ``_measurement_key_name_``, if any."""
        return self._forward_to_gate('_measurement_key_name_')

    def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
        """Delegates to the gate's ``_measurement_key_names_``, if any."""
        return self._forward_to_gate('_measurement_key_names_')

    def _measurement_key_obj_(self) -> Optional[value.MeasurementKey]:
        """Delegates to the gate's ``_measurement_key_obj_``, if any."""
        return self._forward_to_gate('_measurement_key_obj_')

    def _measurement_key_objs_(self) -> Optional[AbstractSet[value.MeasurementKey]]:
        """Delegates to the gate's ``_measurement_key_objs_``, if any."""
        return self._forward_to_gate('_measurement_key_objs_')

    def _act_on_(self, args: 'cirq.ActOnArgs'):
        """Delegates to the gate's ``_act_on_(args, qubits)``, if any."""
        return self._forward_to_gate('_act_on_', args, self.qubits)

    def _is_parameterized_(self) -> bool:
        """Delegates to the gate's ``_is_parameterized_``, if any."""
        return self._forward_to_gate('_is_parameterized_')

    def _parameter_names_(self) -> AbstractSet[str]:
        """Delegates to the gate's ``_parameter_names_``, if any."""
        return self._forward_to_gate('_parameter_names_')

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'cirq.Operation':
        resolved_gate = protocols.resolve_parameters(self.gate, resolver, recursive)
        return self.with_gate(resolved_gate)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> 'cirq.CircuitDiagramInfo':
        return protocols.circuit_diagram_info(self.gate, args, NotImplemented)

    def _decompose_into_clifford_(self):
        """Delegates to the gate's ``_decompose_into_clifford_with_qubits_``."""
        return self._forward_to_gate('_decompose_into_clifford_with_qubits_', self.qubits)

    def _trace_distance_bound_(self) -> float:
        """Delegates to the gate's ``_trace_distance_bound_``, if any."""
        return self._forward_to_gate('_trace_distance_bound_')

    def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'GateOperation':
        phased_gate = protocols.phase_by(self.gate, phase_turns, qubit_index, default=None)
        if phased_gate is None:
            return NotImplemented
        return GateOperation(phased_gate, self._qubits)

    def __pow__(self, exponent: Any) -> 'cirq.Operation':
        """Raise gate to a power, then reapply to the same qubits.

        Only works if the gate implements cirq.ExtrapolatableEffect.
        For extrapolatable gate G this means the following two are equivalent:

            (G ** 1.5)(qubit)  or  G(qubit) ** 1.5

        Args:
            exponent: The amount to scale the gate's effect by.

        Returns:
            A new operation on the same qubits with the scaled gate, or
            NotImplemented if the gate cannot be raised to ``exponent``.
        """
        new_gate = protocols.pow(self.gate, exponent, NotImplemented)
        if new_gate is NotImplemented:
            return NotImplemented
        return self.with_gate(new_gate)

    def __mul__(self, other: Any) -> Any:
        result = self.gate._mul_with_qubits(self._qubits, other)

        # python will not auto-attempt the reverse order for same type.
        if result is NotImplemented and isinstance(other, GateOperation):
            return other.__rmul__(self)

        return result

    def __rmul__(self, other: Any) -> Any:
        return self.gate._rmul_with_qubits(self._qubits, other)

    def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
        return protocols.qasm(self.gate, args=args, qubits=self.qubits, default=None)

    def _quil_(self, formatter: 'protocols.QuilFormatter') -> Optional[str]:
        return protocols.quil(self.gate, qubits=self.qubits, formatter=formatter)

    def _equal_up_to_global_phase_(
        self, other: Any, atol: Union[int, float] = 1e-8
    ) -> Union[NotImplementedType, bool]:
        """Compares gates up to global phase; qubits must match exactly, in order."""
        if not isinstance(other, type(self)):
            return NotImplemented
        if self.qubits != other.qubits:
            return False
        return protocols.equal_up_to_global_phase(self.gate, other.gate, atol=atol)

    def controlled_by(
        self,
        *control_qubits: 'cirq.Qid',
        control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
    ) -> 'cirq.Operation':
        """Returns this operation's gate controlled by ``control_qubits``.

        Returns ``self`` when no control qubits are given. The control qubits
        are placed before this operation's qubits, and the control qid shape is
        taken from each control qubit's ``dimension``.
        """
        if not control_qubits:
            return self
        qubits = tuple(control_qubits)
        return self._gate.controlled(
            num_controls=len(qubits),
            control_values=control_values,
            control_qid_shape=tuple(q.dimension for q in qubits),
        ).on(*(qubits + self._qubits))
350
351
# Type variable bounded by `raw_types.Gate`.
# NOTE(review): not referenced anywhere in this file; presumably kept for
# external importers -- confirm before removing.
TV = TypeVar('TV', bound=raw_types.Gate)
353