# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects and methods for acting efficiently on a state tensor."""
import abc
import copy
from typing import (
    Any,
    Dict,
    List,
    TypeVar,
    TYPE_CHECKING,
    Sequence,
    Tuple,
    cast,
    Optional,
    Iterator,
)

import numpy as np

from cirq import protocols
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
from cirq.sim.operation_target import OperationTarget

TSelf = TypeVar('TSelf', bound='ActOnArgs')

if TYPE_CHECKING:
    import cirq


class ActOnArgs(OperationTarget[TSelf]):
    """State and context for an operation acting on a state tensor.

    The class tracks an ordered tuple of qubits, a qubit -> index map derived
    from that order, a pseudo random number generator, and a log of
    measurement results. Subclasses supply the actual state representation
    through the `_perform_measurement` and `_on_*` hooks.
    """

    def __init__(
        self,
        prng: Optional[np.random.RandomState] = None,
        qubits: Optional[Sequence['cirq.Qid']] = None,
        log_of_measurement_results: Optional[Dict[str, Any]] = None,
    ):
        """Inits ActOnArgs.

        Args:
            prng: The pseudo random number generator to use for probabilistic
                effects. Defaults to the `np.random` module.
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states. Defaults to no
                qubits.
            log_of_measurement_results: A mutable object that measurements are
                being recorded into. Defaults to a new empty dict.
        """
        if prng is None:
            prng = cast(np.random.RandomState, np.random)
        if qubits is None:
            qubits = ()
        if log_of_measurement_results is None:
            log_of_measurement_results = {}
        self._set_qubits(qubits)
        self.prng = prng
        self._log_of_measurement_results = log_of_measurement_results

    def _set_qubits(self, qubits: Sequence['cirq.Qid']) -> None:
        """Stores `qubits` as a tuple and rebuilds the qubit -> index map."""
        self._qubits = tuple(qubits)
        self.qubit_map = {q: i for i, q in enumerate(self.qubits)}

    def measure(
        self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]
    ) -> None:
        """Adds a measurement result to the log.

        Args:
            qubits: The qubits to measure.
            key: The key the measurement result should be logged under. Note
                that operations should only store results under keys they have
                declared in a `_measurement_key_names_` method.
            invert_mask: The invert mask for the measurement. Entries beyond
                the end of the mask are treated as False, so a mask shorter
                than the number of measured bits leaves the trailing bits
                unflipped. Mask entries beyond the number of measured bits
                are ignored.

        Raises:
            ValueError: If a measurement was already logged under `key`.
        """
        # Reject a duplicate key before measuring, so that a failed call does
        # not run `_perform_measurement` at all.
        if key in self._log_of_measurement_results:
            raise ValueError(f"Measurement already logged to key {key!r}")
        bits = self._perform_measurement(qubits)
        mask = list(invert_mask)
        mask_len = len(mask)
        # Only values below 2 are flipped by the mask. Indexing by position
        # (instead of zipping) keeps every measured bit even when the mask is
        # shorter than `bits`.
        corrected = [
            bit ^ (bit < 2 and (mask[i] if i < mask_len else False))
            for i, bit in enumerate(bits)
        ]
        self._log_of_measurement_results[key] = corrected

    def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Returns the index of each of `qubits` in the current qubit order.

        Raises:
            KeyError: If any qubit is not in `qubit_map`.
        """
        return [self.qubit_map[q] for q in qubits]

    @abc.abstractmethod
    def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Child classes that perform measurements should implement this with
        the implementation."""

    def copy(self: TSelf) -> TSelf:
        """Creates a copy of the object.

        The object is shallow-copied, `_on_copy` is invoked so subclasses can
        copy their own state, and the measurement log is replaced by a
        (shallow) copy so that the two objects do not share one log dict.
        """
        args = copy.copy(self)
        self._on_copy(args)
        args._log_of_measurement_results = self.log_of_measurement_results.copy()
        return args

    def _on_copy(self: TSelf, args: TSelf):
        """Subclasses should implement this with any additional state copy
        functionality."""

    def create_merged_state(self: TSelf) -> TSelf:
        """Creates a final merged state."""
        return self

    def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
        """Joins two state spaces together.

        Args:
            other: The state whose qubits are appended after this one's.
            inplace: True to update this object, False to update a shallow
                copy of it.

        Returns:
            The joined state, whose qubits are `self.qubits + other.qubits`.
        """
        args = self if inplace else copy.copy(self)
        self._on_kronecker_product(other, args)
        args._set_qubits(self.qubits + other.qubits)
        return args

    def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf):
        """Subclasses should implement this with any additional state product
        functionality, if supported."""

    def factor(
        self: TSelf,
        qubits: Sequence['cirq.Qid'],
        *,
        validate=True,
        atol=1e-07,
        inplace=False,
    ) -> Tuple[TSelf, TSelf]:
        """Splits two state spaces after a measurement or reset.

        Args:
            qubits: The qubits to extract into their own state.
            validate: Forwarded unchanged to `_on_factor`.
            atol: Forwarded unchanged to `_on_factor`.
            inplace: True to reuse this object as the remainder, False to use
                a shallow copy of it.

        Returns:
            A tuple `(extracted, remainder)`. `extracted` holds `qubits`;
            `remainder` holds the other qubits in their existing order.
        """
        extracted = copy.copy(self)
        remainder = self if inplace else copy.copy(self)
        self._on_factor(qubits, extracted, remainder, validate, atol)
        extracted._set_qubits(qubits)
        # Set membership keeps the filter linear in the number of qubits.
        extracted_qubits = set(qubits)
        remainder._set_qubits([q for q in self.qubits if q not in extracted_qubits])
        return extracted, remainder

    def _on_factor(
        self: TSelf,
        qubits: Sequence['cirq.Qid'],
        extracted: TSelf,
        remainder: TSelf,
        validate=True,
        atol=1e-07,
    ):
        """Subclasses should implement this with any additional state factor
        functionality, if supported."""

    def transpose_to_qubit_order(
        self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False
    ) -> TSelf:
        """Physically reindexes the state by the new basis.

        Args:
            qubits: The desired qubit order.
            inplace: True to perform this operation inplace.

        Returns:
            The state with qubit order transposed and underlying representation
            updated.

        Raises:
            ValueError: If the provided qubits do not match the existing ones.
        """
        if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
            raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
        args = self if inplace else copy.copy(self)
        self._on_transpose_to_qubit_order(qubits, args)
        args._set_qubits(qubits)
        return args

    def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf):
        """Subclasses should implement this with any additional state transpose
        functionality, if supported."""

    @property
    def log_of_measurement_results(self) -> Dict[str, Any]:
        """The dict that measurement results are recorded into."""
        return self._log_of_measurement_results

    @property
    def qubits(self) -> Tuple['cirq.Qid', ...]:
        """The qubits of this state, in their current order."""
        return self._qubits

    def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
        """Swaps two qubits.

        This only affects the index, and does not modify the underlying
        state.

        Args:
            q1: The first qubit to swap.
            q2: The second qubit to swap.
            inplace: True to swap the qubits in the current object, False to
                create a copy with the qubits swapped.

        Returns:
            The original object with the qubits swapped if inplace is
            requested, or a copy of the original object with the qubits swapped
            otherwise.

        Raises:
            ValueError: If the qubits are of different dimensionality.
        """
        if q1.dimension != q2.dimension:
            raise ValueError(f'Cannot swap different dimensions: q1={q1}, q2={q2}')

        args = self if inplace else copy.copy(self)
        i1 = self.qubits.index(q1)
        i2 = self.qubits.index(q2)
        qubits = list(args.qubits)
        qubits[i1], qubits[i2] = qubits[i2], qubits[i1]
        # Same bookkeeping as every other reordering method in this class.
        args._set_qubits(qubits)
        return args

    def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
        """Renames `q1` to `q2`.

        Args:
            q1: The qubit to rename.
            q2: The new name.
            inplace: True to rename the qubit in the current object, False to
                create a copy with the qubit renamed.

        Returns:
            The original object with the qubits renamed if inplace is
            requested, or a copy of the original object with the qubits renamed
            otherwise.

        Raises:
            ValueError: If the qubits are of different dimensionality.
        """
        if q1.dimension != q2.dimension:
            raise ValueError(f'Cannot rename to different dimensions: q1={q1}, q2={q2}')

        args = self if inplace else copy.copy(self)
        i1 = self.qubits.index(q1)
        qubits = list(args.qubits)
        qubits[i1] = q2
        # Same bookkeeping as every other reordering method in this class.
        args._set_qubits(qubits)
        return args

    def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf:
        """Returns this object itself for any qubit it contains.

        Raises:
            IndexError: If `item` is not one of this object's qubits.
        """
        if item not in self.qubit_map:
            raise IndexError(f'{item} not in {self.qubits}')
        return self

    def __len__(self) -> int:
        """Returns the number of qubits."""
        return len(self.qubits)

    def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
        """Iterates over the qubits in their current order."""
        return iter(self.qubits)


def strat_act_on_from_apply_decompose(
    val: Any,
    args: ActOnArgs,
    qubits: Sequence['cirq.Qid'],
) -> bool:
    """Acts on `args` by decomposing `val` and applying each piece in turn.

    Args:
        val: The value to decompose.
        args: The state to act on.
        qubits: The qubits the decomposed operations should be applied to.
            They are matched positionally against the qubits reported by the
            decomposition.

    Returns:
        True if every decomposed operation was applied to `args`, or
        `NotImplemented` if `val` has no decomposition.
    """
    operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
    # Bail out before touching the qubit lists: the qubits reported alongside
    # a failed decomposition may not line up with `qubits`, and that must not
    # trip the length assertion below.
    if operations is None:
        return NotImplemented
    assert len(qubits1) == len(qubits)
    qubit_map = {q: qubits[i] for i, q in enumerate(qubits1)}
    for operation in operations:
        operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
        protocols.act_on(operation, args)
    return True