# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union

import numpy as np

import cirq
from cirq import protocols, value
from cirq.ops import raw_types, controlled_operation as cop
from cirq.type_workarounds import NotImplementedType


@value.value_equality
class ControlledGate(raw_types.Gate):
    """Augments existing gates to have one or more control qubits.

    This object is typically created via `gate.controlled()`. The control
    qids come first (in the order of `control_values`), followed by the qids
    of the sub gate. The sub gate is applied only when every control is in
    one of its enabled values.
    """

    def __init__(
        self,
        sub_gate: 'cirq.Gate',
        num_controls: Optional[int] = None,
        control_values: Optional[Sequence[Union[int, Collection[int]]]] = None,
        control_qid_shape: Optional[Sequence[int]] = None,
    ) -> None:
        """Initializes the controlled gate. If no arguments are specified for
        the controls, defaults to a single qubit control.

        Args:
            sub_gate: The gate to add a control qubit to. If this is itself a
                `ControlledGate`, it is flattened: its sub gate becomes this
                gate's sub gate and its controls are appended after the new
                controls.
            num_controls: Number of control qubits added by this gate (not
                counting controls already carried by a `ControlledGate`
                sub gate). If omitted, it is taken from
                `len(control_values)`, else from `len(control_qid_shape)`,
                else defaults to 1.
            control_values: For which control qubit values to apply the sub
                gate. A sequence of length `num_controls` where each
                entry is an integer (or set of integers) corresponding to the
                qubit value (or set of possible values) where that control is
                enabled. Python ints and numpy integer scalars are both
                accepted. When all controls are enabled, the sub gate is
                applied. If unspecified, control values default to 1.
            control_qid_shape: The qid shape of the controls. A tuple of the
                expected dimension of each control qid. Defaults to
                `(2,) * num_controls`. Specify this argument when using qudits.

        Raises:
            ValueError: If `len(control_values)` or `len(control_qid_shape)`
                differs from `num_controls`, or if a control value lies
                outside `[0, dimension)` for its control.
        """
        if num_controls is None:
            if control_values is not None:
                num_controls = len(control_values)
            elif control_qid_shape is not None:
                num_controls = len(control_qid_shape)
            else:
                num_controls = 1
        if control_values is None:
            control_values = ((1,),) * num_controls
        if num_controls != len(control_values):
            raise ValueError('len(control_values) != num_controls')

        if control_qid_shape is None:
            control_qid_shape = (2,) * num_controls
        if num_controls != len(control_qid_shape):
            raise ValueError('len(control_qid_shape) != num_controls')
        self.control_qid_shape = tuple(control_qid_shape)

        # Normalize every entry to a sorted tuple of enabled values. Scalars
        # become 1-tuples; numpy integer scalars are neither `int` instances
        # nor iterable, so they must be recognized here explicitly.
        self.control_values = cast(
            Tuple[Tuple[int, ...], ...],
            tuple(
                (int(val),) if isinstance(val, (int, np.integer)) else tuple(sorted(val))
                for val in control_values
            ),
        )
        # Verify control values not out of bounds
        for i, (val, dimension) in enumerate(zip(self.control_values, self.control_qid_shape)):
            if not all(0 <= v < dimension for v in val):
                raise ValueError(
                    'Control values <{!r}> outside of range for control qubit '
                    'number <{!r}>.'.format(val, i)
                )

        # Flatten nested ControlledGates.
        if isinstance(sub_gate, ControlledGate):
            self.sub_gate = sub_gate.sub_gate  # type: ignore
            self.control_values += sub_gate.control_values
            self.control_qid_shape += sub_gate.control_qid_shape
        else:
            self.sub_gate = sub_gate

    def num_controls(self) -> int:
        """Returns the total number of control qids (after flattening)."""
        return len(self.control_qid_shape)

    def _qid_shape_(self) -> Tuple[int, ...]:
        """Control dimensions followed by the sub gate's qid shape."""
        return self.control_qid_shape + cirq.qid_shape(self.sub_gate)

    def _line_qid_controlled_op(self) -> cop.ControlledOperation:
        """Builds this gate as a ControlledOperation on `cirq.LineQid`s of matching shape."""
        qubits = cirq.LineQid.for_gate(self)
        op = self.sub_gate.on(*qubits[self.num_controls() :])
        return cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)

    def _decompose_(self, qubits):
        """Decomposes the sub gate once and controls each resulting operation.

        Every operation from the sub gate's decomposition is wrapped in a
        ControlledOperation with the same control qubits and values. Returns
        NotImplemented if the sub gate has no decomposition.
        """
        result = protocols.decompose_once_with_qubits(
            self.sub_gate, qubits[self.num_controls() :], NotImplemented
        )
        if result is NotImplemented:
            return NotImplemented

        controls = qubits[: self.num_controls()]
        return [cop.ControlledOperation(controls, op, self.control_values) for op in result]

    def on(self, *qubits: 'cirq.Qid') -> cop.ControlledOperation:
        """Applies the gate: the first `num_controls()` qids are the controls.

        Raises:
            ValueError: If no qubits are given, or if `validate_args` rejects them.
        """
        if not qubits:
            raise ValueError(f"Applied a gate to an empty set of qubits. Gate: {self!r}")
        self.validate_args(qubits)
        return cop.ControlledOperation(
            qubits[: self.num_controls()],
            self.sub_gate.on(*qubits[self.num_controls() :]),
            self.control_values,
        )

    def _value_equality_values_(self):
        # Controls are compared as a set of (values, dimension) pairs, so the
        # order in which controls are listed does not affect equality.
        return (
            self.sub_gate,
            self.num_controls(),
            frozenset(zip(self.control_values, self.control_qid_shape)),
        )

    def _apply_unitary_(
        self, args: 'protocols.ApplyUnitaryArgs'
    ) -> Union[np.ndarray, NotImplementedType]:
        """Delegates to the equivalent ControlledOperation on LineQids."""
        return protocols.apply_unitary(
            self._line_qid_controlled_op(), args, default=NotImplemented
        )

    def _has_unitary_(self) -> bool:
        """A controlled gate has a unitary exactly when its sub gate does."""
        return protocols.has_unitary(self.sub_gate)

    def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
        """Unitary of the equivalent ControlledOperation, or NotImplemented."""
        return protocols.unitary(self._line_qid_controlled_op(), default=NotImplemented)

    def _has_mixture_(self) -> bool:
        """A controlled gate has a mixture exactly when its sub gate does."""
        return protocols.has_mixture(self.sub_gate)

    def _mixture_(self) -> Union[Sequence[Tuple[float, np.ndarray]], NotImplementedType]:
        """Mixture of the equivalent ControlledOperation, or NotImplemented."""
        return protocols.mixture(self._line_qid_controlled_op(), default=NotImplemented)

    def __pow__(self, exponent: Any) -> 'ControlledGate':
        """Raises the sub gate to `exponent`, keeping the same controls."""
        new_sub_gate = protocols.pow(self.sub_gate, exponent, NotImplemented)
        if new_sub_gate is NotImplemented:
            return NotImplemented
        return ControlledGate(
            new_sub_gate,
            self.num_controls(),
            control_values=self.control_values,
            control_qid_shape=self.control_qid_shape,
        )

    def _is_parameterized_(self) -> bool:
        return protocols.is_parameterized(self.sub_gate)

    def _parameter_names_(self) -> AbstractSet[str]:
        return protocols.parameter_names(self.sub_gate)

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'ControlledGate':
        """Resolves parameters in the sub gate, keeping the same controls."""
        new_sub_gate = protocols.resolve_parameters(self.sub_gate, resolver, recursive)
        return ControlledGate(
            new_sub_gate,
            self.num_controls(),
            control_values=self.control_values,
            control_qid_shape=self.control_qid_shape,
        )

    def _trace_distance_bound_(self) -> Optional[float]:
        """Bound computed from the eigenvalue angles of the sub gate's unitary.

        Returns None when parameterized and NotImplemented when the sub gate
        has no unitary.
        """
        if self._is_parameterized_():
            return None
        u = protocols.unitary(self.sub_gate, default=None)
        if u is None:
            return NotImplemented
        # The extra 0 angle accounts for the identity acting when the controls
        # are not all enabled.
        angle_list = np.append(np.angle(np.linalg.eigvals(u)), 0)
        return protocols.trace_distance_from_angle_list(angle_list)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> 'cirq.CircuitDiagramInfo':
        """Control wires get '@' (value 1) or '(v1,v2,...)'; the rest come from the sub gate."""
        sub_args = protocols.CircuitDiagramInfoArgs(
            known_qubit_count=(
                args.known_qubit_count - self.num_controls()
                if args.known_qubit_count is not None
                else None
            ),
            known_qubits=(
                args.known_qubits[self.num_controls() :] if args.known_qubits is not None else None
            ),
            use_unicode_characters=args.use_unicode_characters,
            precision=args.precision,
            qubit_map=args.qubit_map,
        )
        sub_info = protocols.circuit_diagram_info(self.sub_gate, sub_args, None)
        if sub_info is None:
            return NotImplemented

        def get_symbol(vals):
            if tuple(vals) == (1,):
                return '@'
            return f"({','.join(map(str, vals))})"

        return protocols.CircuitDiagramInfo(
            wire_symbols=(
                *(get_symbol(vals) for vals in self.control_values),
                *sub_info.wire_symbols,
            ),
            exponent=sub_info.exponent,
        )

    def __str__(self) -> str:
        # When every control uses the default value 1, each control is just
        # 'C'; otherwise every control spells out its values, e.g. 'C01'.
        all_default = set(self.control_values) == {(1,)}
        prefix = ''.join(
            'C' if all_default else 'C' + ''.join(map(str, sorted(vals)))
            for vals in self.control_values
        )
        return prefix + str(self.sub_gate)

    def __repr__(self) -> str:
        # The short forms rely on the constructor defaults (control value 1,
        # qubit controls), so they are only used when both defaults hold;
        # otherwise qudit controls would not survive an eval() round trip.
        uses_defaults = all(vals == (1,) for vals in self.control_values) and set(
            self.control_qid_shape
        ) == {2}
        if uses_defaults:
            if self.num_controls() == 1:
                return f'cirq.ControlledGate(sub_gate={self.sub_gate!r})'
            return (
                f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
                f'num_controls={self.num_controls()!r})'
            )
        return (
            f'cirq.ControlledGate(sub_gate={self.sub_gate!r}, '
            f'control_values={self.control_values!r}, '
            f'control_qid_shape={self.control_qid_shape!r})'
        )

    def _json_dict_(self) -> Dict[str, Any]:
        return {
            'cirq_type': self.__class__.__name__,
            'control_values': self.control_values,
            'control_qid_shape': self.control_qid_shape,
            'sub_gate': self.sub_gate,
        }