# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functionality for grouping and validating Cirq Gates"""

from typing import Any, Callable, cast, Dict, FrozenSet, List, Optional, Type, TYPE_CHECKING, Union
from cirq.ops import global_phase_op, op_tree, raw_types
from cirq import protocols, value

if TYPE_CHECKING:
    import cirq


def _gate_str(
    gate: Union[raw_types.Gate, Type[raw_types.Gate], 'cirq.GateFamily'],
    gettr: Callable[[Any], str] = str,
) -> str:
    """Returns a string representation of `gate`.

    Args:
        gate: A gate instance, a gate type, or a `cirq.GateFamily`.
        gettr: Function used to stringify non-type values (e.g. `str` or `repr`).

    Returns:
        `module.ClassName` when `gate` is a type, otherwise `gettr(gate)`.
    """
    if isinstance(gate, type):
        return f'{gate.__module__}.{gate.__name__}'
    return gettr(gate)


@value.value_equality(distinct_child_types=True)
class GateFamily:
    """Wrapper around gate instances/types describing a set of accepted gates.

    GateFamily supports initialization via
        a) Non-parameterized instances of `cirq.Gate` (Instance Family).
        b) Python types inheriting from `cirq.Gate` (Type Family).

    By default, the containment checks depend on the initialization type:
        a) Instance Family: Containment check is done via `cirq.equal_up_to_global_phase`.
        b) Type Family: Containment check is done by type comparison.

    For example:
        a) Instance Family:
            >>> gate_family = cirq.GateFamily(cirq.X)
            >>> assert cirq.X in gate_family
            >>> assert cirq.Rx(rads=np.pi) in gate_family
            >>> assert cirq.X ** sympy.Symbol("theta") not in gate_family

        b) Type Family:
            >>> gate_family = cirq.GateFamily(cirq.XPowGate)
            >>> assert cirq.X in gate_family
            >>> assert cirq.Rx(rads=np.pi) in gate_family
            >>> assert cirq.X ** sympy.Symbol("theta") in gate_family

    In order to create gate families with constraints on parameters of a gate
    type, users should derive from the `cirq.GateFamily` class and override the
    `_predicate` method used to check for gate containment.
    """

    def __init__(
        self,
        gate: Union[Type[raw_types.Gate], raw_types.Gate],
        *,
        name: Optional[str] = None,
        description: Optional[str] = None,
        ignore_global_phase: bool = True,
    ) -> None:
        """Init GateFamily.

        Args:
            gate: A python `type` inheriting from `cirq.Gate` for type based membership checks, or
                a non-parameterized instance of a `cirq.Gate` for equality based membership checks.
            name: The name of the gate family. A falsy value selects the default name.
            description: Human readable description of the gate family. A falsy value selects
                the default description.
            ignore_global_phase: If True, value equality is checked via
                `cirq.equal_up_to_global_phase`.

        Raises:
            ValueError: if `gate` is not a `cirq.Gate` instance or subclass.
            ValueError: if `gate` is a parameterized instance of `cirq.Gate`.
        """
        if not (
            isinstance(gate, raw_types.Gate)
            or (isinstance(gate, type) and issubclass(gate, raw_types.Gate))
        ):
            raise ValueError(f'Gate {gate} must be an instance or subclass of `cirq.Gate`.')
        if isinstance(gate, raw_types.Gate) and protocols.is_parameterized(gate):
            raise ValueError(f'Gate {gate} must be a non-parameterized instance of `cirq.Gate`.')

        self._gate = gate
        # `_default_name` / `_default_description` read `self.gate`, so `_gate` must be set first.
        self._name = name if name else self._default_name()
        self._description = description if description else self._default_description()
        self._ignore_global_phase = ignore_global_phase

    def _gate_str(self, gettr: Callable[[Any], str] = str) -> str:
        """Stringifies the wrapped gate (instance via `gettr`, type via `module.Name`)."""
        return _gate_str(self.gate, gettr)

    def _default_name(self) -> str:
        """Name used when none is supplied, e.g. 'Instance GateFamily: X'."""
        family_type = 'Instance' if isinstance(self.gate, raw_types.Gate) else 'Type'
        return f'{family_type} GateFamily: {self._gate_str()}'

    def _default_description(self) -> str:
        """Description used when none is supplied; states the default membership check."""
        check_type = r'g == {}' if isinstance(self.gate, raw_types.Gate) else r'isinstance(g, {})'
        return f'Accepts `cirq.Gate` instances `g` s.t. `{check_type.format(self._gate_str())}`'

    @property
    def gate(self) -> Union[Type[raw_types.Gate], raw_types.Gate]:
        """The wrapped gate instance or gate type."""
        return self._gate

    @property
    def name(self) -> str:
        """Name of this gate family."""
        return self._name

    @property
    def description(self) -> str:
        """Human readable description of this gate family."""
        return self._description

    def _predicate(self, gate: raw_types.Gate) -> bool:
        """Checks whether `cirq.Gate` instance `gate` belongs to this GateFamily.

        The default predicate depends on the gate family initialization type:
            a) Instance Family: `cirq.equal_up_to_global_phase(gate, self.gate)`
                if self._ignore_global_phase else `gate == self.gate`.
            b) Type Family: `isinstance(gate, self.gate)`.

        Args:
            gate: `cirq.Gate` instance which should be checked for containment.

        Returns:
            True if `gate` is accepted by this family.
        """
        if isinstance(self.gate, raw_types.Gate):
            if self._ignore_global_phase:
                return protocols.equal_up_to_global_phase(gate, self.gate)
            return gate == self.gate
        return isinstance(gate, self.gate)

    def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
        """Checks a gate, or an operation's underlying gate, against `_predicate`.

        Operations without an underlying gate (`item.gate is None`) are never contained.
        """
        if isinstance(item, raw_types.Operation):
            if item.gate is None:
                return False
            item = item.gate
        return self._predicate(item)

    def __str__(self) -> str:
        return f'{self.name}\n{self.description}'

    def __repr__(self) -> str:
        name_and_description = ''
        if self.name != self._default_name() or self.description != self._default_description():
            # `!r` quotes and escapes the strings, so names/descriptions containing quotes,
            # backslashes or newlines still produce an evaluable repr.
            name_and_description = f'name={self.name!r}, description={self.description!r}, '
        return (
            f'cirq.GateFamily('
            f'gate={self._gate_str(repr)}, '
            f'{name_and_description}'
            f'ignore_global_phase={self._ignore_global_phase})'
        )

    def _value_equality_values_(self) -> Any:
        # `isinstance` is used to ensure that a gate type and a gate instance are never
        # considered equal to each other.
        return (
            isinstance(self.gate, raw_types.Gate),
            self.gate,
            self.name,
            self.description,
            self._ignore_global_phase,
        )


@value.value_equality()
class Gateset:
    """Gatesets represent a collection of `cirq.GateFamily` objects.

    Gatesets are useful for
        a) Describing the set of allowed gates in a human readable format
        b) Validating a given gate / optree against the set of allowed gates

    Gatesets rely on the underlying `cirq.GateFamily` for both description and
    validation purposes.
    """

    def __init__(
        self,
        *gates: Union[Type[raw_types.Gate], raw_types.Gate, GateFamily],
        name: Optional[str] = None,
        unroll_circuit_op: bool = True,
        accept_global_phase_op: bool = True,
    ) -> None:
        """Init Gateset.

        Accepts a list of gates, each of which should be either
            a) `cirq.Gate` subclass
            b) `cirq.Gate` instance
            c) `cirq.GateFamily` instance

        `cirq.Gate` subclasses and instances are converted to the default
        `cirq.GateFamily(gate=g)` instance and thus a default name and
        description is populated.

        Args:
            *gates: A list of `cirq.Gate` subclasses / `cirq.Gate` instances /
                `cirq.GateFamily` instances to initialize the Gateset.
            name: (Optional) Name for the Gateset. Useful for description.
            unroll_circuit_op: If True, `cirq.CircuitOperation` is recursively
                validated by validating the underlying `cirq.Circuit`.
            accept_global_phase_op: If True, `cirq.GlobalPhaseOperation` is accepted.
        """
        self._name = name
        self._unroll_circuit_op = unroll_circuit_op
        self._accept_global_phase_op = accept_global_phase_op
        # Lookup tables enabling the O(1) fast paths in `__contains__`.
        self._instance_gate_families: Dict[raw_types.Gate, GateFamily] = {}
        self._type_gate_families: Dict[Type[raw_types.Gate], GateFamily] = {}
        self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
        # `dict.fromkeys` de-duplicates while preserving the user-supplied order.
        unique_gate_list: List[GateFamily] = list(
            dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
        )
        for g in unique_gate_list:
            # Only exact `GateFamily` instances are indexed: subclasses may override
            # `_predicate`, so a key match would not imply acceptance for them.
            if type(g) is GateFamily:
                if isinstance(g.gate, raw_types.Gate):
                    self._instance_gate_families[g.gate] = g
                else:
                    self._type_gate_families[g.gate] = g
        self._gates_str_str = "\n\n".join([str(g) for g in unique_gate_list])
        self._gates = frozenset(unique_gate_list)

    @property
    def name(self) -> Optional[str]:
        """Optional name of this Gateset."""
        return self._name

    @property
    def gates(self) -> FrozenSet[GateFamily]:
        """The unique gate families making up this Gateset."""
        return self._gates

    def with_params(
        self,
        *,
        name: Optional[str] = None,
        unroll_circuit_op: Optional[bool] = None,
        accept_global_phase_op: Optional[bool] = None,
    ) -> 'Gateset':
        """Returns a copy of this Gateset with identical gates and new values for named arguments.

        If a named argument is None then corresponding value of this Gateset is used instead.

        Args:
            name: New name for the Gateset.
            unroll_circuit_op: If True, new Gateset will recursively validate
                `cirq.CircuitOperation` by validating the underlying `cirq.Circuit`.
            accept_global_phase_op: If True, new Gateset will accept `cirq.GlobalPhaseOperation`.

        Returns:
            `self` if all new values are None or identical to the values of current Gateset.
            else a new Gateset with identical gates and new values for named arguments.
        """

        def val_if_none(var: Any, val: Any) -> Any:
            return var if var is not None else val

        name = val_if_none(name, self._name)
        unroll_circuit_op = val_if_none(unroll_circuit_op, self._unroll_circuit_op)
        accept_global_phase_op = val_if_none(accept_global_phase_op, self._accept_global_phase_op)
        if (
            name == self._name
            and unroll_circuit_op == self._unroll_circuit_op
            and accept_global_phase_op == self._accept_global_phase_op
        ):
            return self
        return Gateset(
            *self.gates,
            name=name,
            unroll_circuit_op=cast(bool, unroll_circuit_op),
            accept_global_phase_op=cast(bool, accept_global_phase_op),
        )

    def __contains__(self, item: Union[raw_types.Gate, raw_types.Operation]) -> bool:
        """Check for containment of a given Gate/Operation in this Gateset.

        Containment checks are handled as follows:
            a) For Gates or Operations that have an underlying gate (i.e. op.gate is not None):
                - Forwards the containment check to the underlying `cirq.GateFamily` objects.
                - Examples of such operations include `cirq.GateOperations` and their controlled
                    and tagged variants (i.e. instances of `cirq.TaggedOperation`,
                    `cirq.ControlledOperation` where `op.gate` is not None) etc.

            b) For Operations that do not have an underlying gate:
                - Forwards the containment check to `self._validate_operation(item)`.
                - Examples of such operations include `cirq.CircuitOperations` and their controlled
                    and tagged variants (i.e. instances of `cirq.TaggedOperation`,
                    `cirq.ControlledOperation` where `op.gate` is None) etc.

        The complexity of the method is:
            a) O(1) when any default `cirq.GateFamily` instance accepts the given item, except
                for an Instance GateFamily trying to match an item with a different global phase.
            b) O(n) for all other cases: matching against custom gate families, matching across
                global phase for the default Instance GateFamily, no match against any underlying
                gate family.

        Args:
            item: The `cirq.Gate` or `cirq.Operation` instance to check containment for.

        Returns:
            True if `item` is accepted by this Gateset.
        """
        if isinstance(item, raw_types.Operation) and item.gate is None:
            return self._validate_operation(item)

        g = item if isinstance(item, raw_types.Gate) else item.gate
        assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

        # Fast path 1: exact (hash-based) match against a default Instance GateFamily.
        if g in self._instance_gate_families:
            assert item in self._instance_gate_families[g], (
                f"{item} instance matches {self._instance_gate_families[g]} but "
                f"is not accepted by it."
            )
            return True

        # Fast path 2: any class in the gate's MRO registered as a default Type GateFamily.
        for gate_mro_type in type(g).mro():
            if gate_mro_type in self._type_gate_families:
                assert item in self._type_gate_families[gate_mro_type], (
                    f"{g} type {gate_mro_type} matches Type GateFamily: "
                    f"{self._type_gate_families[gate_mro_type]} but is not accepted by it."
                )
                return True

        # Slow path: ask every family (custom predicates, global-phase matches).
        return any(item in gate_family for gate_family in self._gates)

    def validate(
        self,
        circuit_or_optree: Union['cirq.AbstractCircuit', op_tree.OP_TREE],
    ) -> bool:
        """Validates that every operation in `circuit_or_optree` is contained in this Gateset.

        Args:
            circuit_or_optree: The `cirq.Circuit` or `cirq.OP_TREE` to validate.

        Returns:
            True if `_validate_operation` accepts every operation, False otherwise.
        """
        # To avoid circular import.
        from cirq.circuits import circuit

        optree = circuit_or_optree
        if isinstance(circuit_or_optree, circuit.AbstractCircuit):
            optree = circuit_or_optree.all_operations()
        return all(self._validate_operation(op) for op in op_tree.flatten_to_ops(optree))

    def _validate_operation(self, op: raw_types.Operation) -> bool:
        """Validates whether the given `cirq.Operation` is contained in this Gateset.

        The containment checks are handled as follows:

        a) For any operation which has an underlying gate (i.e. `op.gate` is not None):
            - Containment is checked via `self.__contains__` which further checks for containment
                in any of the underlying gate families.

        b) For all other types of operations (e.g. `cirq.CircuitOperation`,
        `cirq.GlobalPhaseOperation` etc):
            - The behavior is controlled via flags passed to the constructor.

        Users should override this method to define custom behavior for operations that do not
        have an underlying `cirq.Gate`.

        Args:
            op: The `cirq.Operation` instance to check containment for.

        Returns:
            True if `op` is accepted by this Gateset.
        """

        # To avoid circular import.
        from cirq.circuits import circuit_operation

        if op.gate is not None:
            return op in self

        if isinstance(op, raw_types.TaggedOperation):
            return self._validate_operation(op.sub_operation)
        elif isinstance(op, circuit_operation.CircuitOperation) and self._unroll_circuit_op:
            # Validate the concrete circuit the CircuitOperation stands for: apply its
            # parameter resolver and qubit map before recursing.
            op_circuit = protocols.resolve_parameters(
                op.circuit.unfreeze(), op.param_resolver, recursive=False
            )
            op_circuit = op_circuit.transform_qubits(
                lambda q: cast(circuit_operation.CircuitOperation, op).qubit_map.get(q, q)
            )
            return self.validate(op_circuit)
        elif isinstance(op, global_phase_op.GlobalPhaseOperation):
            return self._accept_global_phase_op
        else:
            return False

    def _value_equality_values_(self) -> Any:
        return (
            self.gates,
            self.name,
            self._unroll_circuit_op,
            self._accept_global_phase_op,
        )

    def __repr__(self) -> str:
        # Omit the gates segment for an empty Gateset; otherwise the repr would start with
        # `cirq.Gateset(, ...` which is not valid Python.
        gates_str = f'{self._gates_repr_str}, ' if self._gates_repr_str else ''
        # `!r` quotes/escapes the name so any string round-trips through `eval`.
        name_str = f'name = {self.name!r}, ' if self.name is not None else ''
        return (
            f'cirq.Gateset('
            f'{gates_str}'
            f'{name_str}'
            f'unroll_circuit_op = {self._unroll_circuit_op}, '
            f'accept_global_phase_op = {self._accept_global_phase_op})'
        )

    def __str__(self) -> str:
        header = 'Gateset: '
        if self.name:
            header += self.name
        return f'{header}\n' + self._gates_str_str