1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Functionality for grouping and validating Cirq Gates"""
16
17from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
18from cirq.ops import global_phase_op, op_tree, raw_types
19from cirq import protocols, value
20
21if TYPE_CHECKING:
22    import cirq
23
24
25def _gate_str(
26    gate: Union[raw_types.Gate, Type[raw_types.Gate], 'cirq.GateFamily'],
27    gettr: Callable[[Any], str] = str,
28) -> str:
29    return gettr(gate) if not isinstance(gate, type) else f'{gate.__module__}.{gate.__name__}'
30
31
@value.value_equality(distinct_child_types=True)
class GateFamily:
    """Wrapper around gate instances/types describing a set of accepted gates.

    GateFamily supports initialization via
        a) Non-parameterized instances of `cirq.Gate` (Instance Family).
        b) Python types inheriting from `cirq.Gate` (Type Family).

    By default, the containment checks depend on the initialization type:
        a) Instance Family: Containment check is done via `cirq.equal_up_to_global_phase`.
        b) Type Family: Containment check is done by type comparison.

    For example:
        a) Instance Family:
            >>> gate_family = cirq.GateFamily(cirq.X)
            >>> assert cirq.X in gate_family
            >>> assert cirq.Rx(rads=np.pi) in gate_family
            >>> assert cirq.X ** sympy.Symbol("theta") not in gate_family

        b) Type Family:
            >>> gate_family = cirq.GateFamily(cirq.XPowGate)
            >>> assert cirq.X in gate_family
            >>> assert cirq.Rx(rads=np.pi) in gate_family
            >>> assert cirq.X ** sympy.Symbol("theta") in gate_family

    In order to create gate families with constraints on parameters of a gate
    type, users should derive from the `cirq.GateFamily` class and override the
    `_predicate` method used to check for gate containment.
    """

    def __init__(
        self,
        gate: Union[Type[raw_types.Gate], raw_types.Gate],
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
        ignore_global_phase: bool = True,
    ) -> None:
        """Init GateFamily.

        Args:
            gate: A python `type` inheriting from `cirq.Gate` for type based membership checks, or
                a non-parameterized instance of a `cirq.Gate` for equality based membership checks.
            name: The name of the gate family. A falsy value (None or an empty string)
                selects the default name.
            description: Human readable description of the gate family. A falsy value
                (None or an empty string) selects the default description.
            ignore_global_phase: If True, value equality is checked via
                `cirq.equal_up_to_global_phase`; otherwise plain `==` is used.
                Only relevant for instance families.

        Raises:
            ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
            ValueError: if `gate` is a parameterized instance of `cirq.Gate`.
        """
        if not (
            isinstance(gate, raw_types.Gate)
            or (isinstance(gate, type) and issubclass(gate, raw_types.Gate))
        ):
            raise ValueError(f'Gate {gate} must be an instance or subclass of `cirq.Gate`.')
        if isinstance(gate, raw_types.Gate) and protocols.is_parameterized(gate):
            raise ValueError(f'Gate {gate} must be a non-parameterized instance of `cirq.Gate`.')

        # `_gate` must be assigned first: the default name/description read `self.gate`.
        self._gate = gate
        self._name = name if name else self._default_name()
        self._description = description if description else self._default_description()
        self._ignore_global_phase = ignore_global_phase

    def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
        """Renders the wrapped gate via the module level `_gate_str` helper.

        Args:
            gettr: Formatter used for gate instances (e.g. `str` or `repr`).
        """
        return _gate_str(self.gate, gettr)

    def _default_name(self) -> str:
        """Returns the name used when no explicit `name` is supplied."""
        family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type'
        return f'{family_type} GateFamily: {self._gate_str()}'

    def _default_description(self) -> str:
        """Returns the description used when no explicit `description` is supplied."""
        check_type = r'g == {}' if isinstance(self.gate, raw_types.Gate) else r'isinstance(g, {})'
        return f'Accepts `cirq.Gate` instances `g` s.t. `{check_type.format(self._gate_str())}`'

    @property
    def gate(self) -> Union[Type[raw_types.Gate], raw_types.Gate]:
        """The gate type or gate instance this family was initialized with."""
        return self._gate

    @property
    def name(self) -> str:
        """The name of this gate family."""
        return self._name

    @property
    def description(self) -> str:
        """Human readable description of this gate family."""
        return self._description

    def _predicate(self, gate: raw_types.Gate) -> bool:
        """Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily.

        The default predicate depends on the gate family initialization type:
            a) Instance Family: `cirq.equal_up_to_global_phase(gate, self.gate)`
                                 if self._ignore_global_phase else `gate == self.gate`.
            b) Type Family: `isinstance(gate, self.gate)`.

        Args:
            gate: `cirq.Gate` instance which should be checked for containment.

        Returns:
            True if `gate` is accepted by this family, False otherwise.
        """
        if isinstance(self.gate, raw_types.Gate):
            return (
                protocols.equal_up_to_global_phase(gate, self.gate)
                if self._ignore_global_phase
                else gate == self._gate
            )
        return isinstance(gate, self.gate)

    def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
        """Checks whether a gate, or the gate underlying an operation, is in this family.

        Args:
            item: A `cirq.Gate`, or a `cirq.Operation` whose `gate` is checked instead.

        Returns:
            False for operations without an underlying gate, otherwise the result
            of `self._predicate` on the gate.
        """
        if isinstance(item, raw_types.Operation):
            if item.gate is None:
                return False
            item = item.gate
        return self._predicate(item)

    def __str__(self) -> str:
        return f'{self.name}\n{self.description}'

    def __repr__(self) -> str:
        name_and_description = ''
        # Name and description are only emitted when at least one deviates from its default.
        if self.name != self._default_name() or self.description != self._default_description():
            # `!r` yields correctly quoted and escaped string literals, so names or
            # descriptions containing quotes, backslashes or newlines still give a
            # repr that is valid python.
            name_and_description = f'name={self.name!r}, description={self.description!r}, '
        return (
            f'cirq.GateFamily('
            f'gate={self._gate_str(repr)}, '
            f'{name_and_description}'
            f'ignore_global_phase={self._ignore_global_phase})'
        )

    def _value_equality_values_(self) -> Any:
        # The leading `isinstance` flag ensures that a gate type and a gate instance
        # are never compared with each other.
        return (
            isinstance(self.gate, raw_types.Gate),
            self.gate,
            self.name,
            self.description,
            self._ignore_global_phase,
        )
169
170
@value.value_equality()
class Gateset:
    """Gatesets represent a collection of `cirq.GateFamily` objects.

    Gatesets are useful for
        a) Describing the set of allowed gates in a human readable format
        b) Validating a given gate / optree against the set of allowed gates

    Gatesets rely on the underlying `cirq.GateFamily` for both description and
    validation purposes.
    """

    def __init__(
        self,
        *gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
        name: Optional[str] = None,
        unroll_circuit_op: bool = True,
        accept_global_phase_op: bool = True,
    ) -> None:
        """Init Gateset.

        Accepts a list of gates, each of which should be either
            a) `cirq.Gate` subclass
            b) `cirq.Gate` instance
            c) `cirq.GateFamily` instance

        `cirq.Gate` subclasses and instances are converted to the default
        `cirq.GateFamily(gate=g)` instance and thus a default name and
        description is populated.

        Args:
            *gates: A list of `cirq.Gate` subclasses / `cirq.Gate` instances /
                `cirq.GateFamily` instances to initialize the Gateset.
            name: (Optional) Name for the Gateset. Useful for description.
            unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
                validated by validating the underlying `cirq.Circuit`.
            accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
        """
        self._name = name
        self._unroll_circuit_op = unroll_circuit_op
        self._accept_global_phase_op = accept_global_phase_op
        self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
        self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
        self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
        # `dict.fromkeys` drops duplicate families while keeping first-seen order.
        # The ordered list is retained so that `with_params` can rebuild a copy
        # with the same ordering.
        self._unique_gate_list: List[GateFamily] = list(
            dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
        )
        for g in self._unique_gate_list:
            # Only exact `GateFamily` objects are indexed for the fast lookups in
            # `__contains__`; subclasses are always checked via the linear fallback.
            if type(g) is GateFamily:
                if isinstance(g.gate, raw_types.Gate):
                    self._instance_gate_families[g.gate] = g
                else:
                    self._type_gate_families[g.gate] = g
        self._gates_str_str = "\n\n".join([str(g) for g in self._unique_gate_list])
        self._gates = frozenset(self._unique_gate_list)

    @property
    def name(self) -> Optional[str]:
        """The name of this Gateset, or None if it was not given one."""
        return self._name

    @property
    def gates(self) -> FrozenSet[GateFamily]:
        """The unique `cirq.GateFamily` objects forming this Gateset."""
        return self._gates

    def with_params(
        self,
        *,
        name: Optional[str] = None,
        unroll_circuit_op: Optional[bool] = None,
        accept_global_phase_op: Optional[bool] = None,
    ) -> 'Gateset':
        """Returns a copy of this Gateset with identical gates and new values for named arguments.

        If a named argument is None then corresponding value of this Gateset is used instead.

        Args:
            name: New name for the Gateset.
            unroll_circuit_op: If True, new Gateset will recursively validate
                `cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
            accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.

        Returns:
            `self` if all new values are None or identical to the values of current Gateset.
            else a new Gateset with identical gates and new values for named arguments.
        """

        def val_if_none(var: Any, val: Any) -> Any:
            return var if var is not None else val

        name = val_if_none(name, self._name)
        unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
        accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
        if (
            name == self._name
            and unroll_circuit_op == self._unroll_circuit_op
            and accept_global_phase_op == self._accept_global_phase_op
        ):
            return self
        # Rebuild from the ordered list instead of the `gates` frozenset: iterating a
        # frozenset yields an arbitrary order, which would leak into the copy's
        # `repr` / `str` output.
        return Gateset(
            *self._unique_gate_list,
            name=name,
            unroll_circuit_op=cast(bool, unroll_circuit_op),
            accept_global_phase_op=cast(bool, accept_global_phase_op),
        )

    def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
        """Check for containment of a given Gate/Operation in this Gateset.

        Containment checks are handled as follows:
            a) For Gates or Operations that have an underlying gate (i.e. op.gate is not None):
                - Forwards the containment check to the underlying `cirq.GateFamily` objects.
                - Examples of such operations include `cirq.GateOperations` and their controlled
                    and tagged variants (i.e. instances of `cirq.TaggedOperation`,
                    `cirq.ControlledOperation` where `op.gate` is not None) etc.

            b) For Operations that do not have an underlying gate:
                - Forwards the containment check to `self._validate_operation(item)`.
                - Examples of such operations include `cirq.CircuitOperations` and their controlled
                    and tagged variants (i.e. instances of `cirq.TaggedOperation`,
                    `cirq.ControlledOperation` where `op.gate` is None) etc.

        The complexity of the method is:
            a) O(1) when any default `cirq.GateFamily` instance accepts the given item, except
                for an Instance GateFamily trying to match an item with a different global phase.
            b) O(n) for all other cases: matching against custom gate families, matching across
                global phase for the default Instance GateFamily, no match against any underlying
                gate family.

        Args:
            item: The `cirq.Gate` or `cirq.Operation` instance to check containment for.

        Returns:
            True if `item` is accepted by this Gateset, False otherwise.
        """
        if isinstance(item, raw_types.Operation) and item.gate is None:
            return self._validate_operation(item)

        g = item if isinstance(item, raw_types.Gate) else item.gate
        assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

        # Fast path 1: hash lookup of the gate among the default instance families.
        if g in self._instance_gate_families:
            # Internal consistency check: a family found by key must accept the item.
            assert item in self._instance_gate_families[g], (
                f"{item} instance matches {self._instance_gate_families[g]} but "
                f"is not accepted by it."
            )
            return True

        # Fast path 2: look up each class in the gate's MRO among the default type families.
        for gate_mro_type in type(g).mro():
            if gate_mro_type in self._type_gate_families:
                assert item in self._type_gate_families[gate_mro_type], (
                    f"{g} type {gate_mro_type} matches Type GateFamily:"
                    f"{self._type_gate_families[gate_mro_type]} but is not accepted by it."
                )
                return True

        # Fallback: ask every gate family in turn.
        return any(item in gate_family for gate_family in self._gates)

    def validate(
        self,
        circuit_or_optree: Union['cirq.AbstractCircuit', op_tree.OP_TREE],
    ) -> bool:
        """Validates that every operation in `circuit_or_optree` is accepted by this Gateset.

        Args:
            circuit_or_optree: The `cirq.Circuit` or `cirq.OP_TREE` to validate.

        Returns:
            True if `self._validate_operation` accepts every operation obtained by
            flattening the input (vacuously True when there are none), else False.
        """
        # To avoid circular import.
        from cirq.circuits import circuit

        optree = circuit_or_optree
        if isinstance(circuit_or_optree, circuit.AbstractCircuit):
            optree = circuit_or_optree.all_operations()
        return all(self._validate_operation(op) for op in op_tree.flatten_to_ops(optree))

    def _validate_operation(self, op: raw_types.Operation) -> bool:
        """Validates whether the given `cirq.Operation` is contained in this Gateset.

        The containment checks are handled as follows:

        a) For any operation which has an underlying gate (i.e. `op.gate` is not None):
            - Containment is checked via `self.__contains__` which further checks for containment
                in any of the underlying gate families.

        b) For all other types of operations (eg: `cirq.CircuitOperation`,
        `cirq.GlobalPhaseOperation` etc):
            - The behavior is controlled via flags passed to the constructor.

        Users should override this method to define custom behavior for operations that do not
        have an underlying `cirq.Gate`.

        Args:
            op: The `cirq.Operation` instance to check containment for.

        Returns:
            True if `op` is accepted by this Gateset, False otherwise.
        """

        # To avoid circular import.
        from cirq.circuits import circuit_operation

        if op.gate is not None:
            return op in self

        if isinstance(op, raw_types.TaggedOperation):
            # Tags are ignored: the wrapped operation decides.
            return self._validate_operation(op.sub_operation)
        elif isinstance(op, circuit_operation.CircuitOperation) and self._unroll_circuit_op:
            # Apply the operation's parameter resolver and qubit map to a mutable copy
            # of its circuit, then validate the resulting circuit.
            op_circuit = protocols.resolve_parameters(
                op.circuit.unfreeze(), op.param_resolver, recursive=False
            )
            op_circuit = op_circuit.transform_qubits(
                lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q)
            )
            return self.validate(op_circuit)
        elif isinstance(op, global_phase_op.GlobalPhaseOperation):
            return self._accept_global_phase_op
        else:
            return False

    def _value_equality_values_(self) -> Any:
        return (
            self.gates,
            self.name,
            self._unroll_circuit_op,
            self._accept_global_phase_op,
        )

    def __repr__(self) -> str:
        # Skip the gate list (and its trailing comma) for an empty Gateset, which
        # would otherwise render as the invalid `cirq.Gateset(, ...)`.
        gates_str = f'{self._gates_repr_str}, ' if self._gates_repr_str else ''
        # `!r` quotes and escapes the name so that the repr stays valid python.
        name_str = f'name = {self.name!r}, ' if self.name is not None else ''
        return (
            f'cirq.Gateset('
            f'{gates_str}'
            f'{name_str}'
            f'unroll_circuit_op = {self._unroll_circuit_op}, '
            f'accept_global_phase_op = {self._accept_global_phase_op})'
        )

    def __str__(self) -> str:
        header = 'Gateset: '
        if self.name:
            header += self.name
        return f'{header}\n' + self._gates_str_str
406