1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Objects and methods for acting efficiently on a state tensor."""
15import abc
16import copy
17from typing import (
18    Any,
19    Dict,
20    List,
21    TypeVar,
22    TYPE_CHECKING,
23    Sequence,
24    Tuple,
25    cast,
26    Optional,
27    Iterator,
28)
29
30import numpy as np
31
32from cirq import protocols
33from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
34from cirq.sim.operation_target import OperationTarget
35
36TSelf = TypeVar('TSelf', bound='ActOnArgs')
37
38if TYPE_CHECKING:
39    import cirq
40
41
class ActOnArgs(OperationTarget[TSelf]):
    """State and context for an operation acting on a state tensor.

    The base class tracks the ordered qubits of the state, a lookup from each
    qubit to its position in that order, the random number generator used for
    probabilistic effects, and the log that measurement results are written
    to. Subclasses supply the actual state handling by implementing
    `_perform_measurement` and, where supported, the `_on_*` hooks.
    """

    def __init__(
        self,
        prng: Optional[np.random.RandomState] = None,
        qubits: Optional[Sequence['cirq.Qid']] = None,
        log_of_measurement_results: Optional[Dict[str, Any]] = None,
    ):
        """Inits ActOnArgs.

        Args:
            prng: The pseudo random number generator to use for probabilistic
                effects. Defaults to the global `np.random` module.
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states. Defaults to no
                qubits.
            log_of_measurement_results: A mutable object that measurements are
                being recorded into. Defaults to a new empty dictionary.
        """
        if prng is None:
            # The `np.random` module is used in place of a RandomState
            # instance; the cast only satisfies the type checker.
            prng = cast(np.random.RandomState, np.random)
        if qubits is None:
            qubits = ()
        if log_of_measurement_results is None:
            log_of_measurement_results = {}
        self._set_qubits(qubits)
        self.prng = prng
        self._log_of_measurement_results = log_of_measurement_results

    def _set_qubits(self, qubits: Sequence['cirq.Qid']):
        """Stores the qubit order and rebuilds the qubit-to-index lookup."""
        self._qubits = tuple(qubits)
        self.qubit_map = {q: i for i, q in enumerate(self.qubits)}

    def measure(self, qubits: Sequence['cirq.Qid'], key: str, invert_mask: Sequence[bool]):
        """Adds a measurement result to the log.

        Args:
            qubits: The qubits to measure.
            key: The key the measurement result should be logged under. Note
                that operations should only store results under keys they have
                declared in a `_measurement_key_names_` method.
            invert_mask: The invert mask for the measurement. If it is shorter
                than the number of measured values, the values without a mask
                entry are recorded without inversion.

        Raises:
            ValueError: If a result has already been logged under `key`.
        """
        bits = self._perform_measurement(qubits)
        # zip() stops at the shorter input, so a mask shorter than `bits`
        # would silently drop trailing results. Pad the mask with False
        # ("do not invert") so every measured value is recorded.
        mask = list(invert_mask)
        mask.extend([False] * (len(bits) - len(mask)))
        # Only values 0 and 1 are flipped; values of 2 or more pass through
        # unchanged (presumably outcomes of higher-dimensional qids).
        corrected = [bit ^ (bit < 2 and flip) for bit, flip in zip(bits, mask)]
        if key in self._log_of_measurement_results:
            raise ValueError(f"Measurement already logged to key {key!r}")
        self._log_of_measurement_results[key] = corrected

    def get_axes(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Returns the index of each given qubit in the current qubit order.

        Args:
            qubits: The qubits to look up.

        Returns:
            The positions of `qubits` within `self.qubits`, in the order the
            qubits were given.

        Raises:
            KeyError: If any of the qubits is not tracked by this object.
        """
        return [self.qubit_map[q] for q in qubits]

    @abc.abstractmethod
    def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Child classes that perform measurements should implement this with
        the implementation."""

    def copy(self: TSelf) -> TSelf:
        """Creates a copy of the object.

        The object is shallow-copied, `_on_copy` is invoked so that subclasses
        can copy their own state, and the measurement log of the copy is
        replaced by a shallow copy of this object's log.
        """
        args = copy.copy(self)
        self._on_copy(args)
        args._log_of_measurement_results = self.log_of_measurement_results.copy()
        return args

    def _on_copy(self: TSelf, args: TSelf):
        """Subclasses should implement this with any additional state copy
        functionality."""

    def create_merged_state(self: TSelf) -> TSelf:
        """Creates a final merged state.

        This object is returned as is.
        """
        return self

    def kronecker_product(self: TSelf, other: TSelf, *, inplace=False) -> TSelf:
        """Joins two state spaces together.

        Args:
            other: The state space to append to this one.
            inplace: True to update this object, False to update and return a
                shallow copy of it.

        Returns:
            The joined object, whose qubits are this object's qubits followed
            by those of `other`.
        """
        args = self if inplace else copy.copy(self)
        self._on_kronecker_product(other, args)
        args._set_qubits(self.qubits + other.qubits)
        return args

    def _on_kronecker_product(self: TSelf, other: TSelf, target: TSelf):
        """Subclasses should implement this with any additional state product
        functionality, if supported."""

    def factor(
        self: TSelf,
        qubits: Sequence['cirq.Qid'],
        *,
        validate=True,
        atol=1e-07,
        inplace=False,
    ) -> Tuple[TSelf, TSelf]:
        """Splits two state spaces after a measurement or reset.

        Args:
            qubits: The qubits to split off into their own object.
            validate: Forwarded to `_on_factor`.
            atol: Forwarded to `_on_factor`.
            inplace: True to reuse this object as the remainder, False to use
                a shallow copy of it.

        Returns:
            A tuple `(extracted, remainder)`. `extracted` tracks `qubits`;
            `remainder` tracks the other qubits of this object in their
            existing relative order.
        """
        extracted = copy.copy(self)
        remainder = self if inplace else copy.copy(self)
        self._on_factor(qubits, extracted, remainder, validate, atol)
        extracted._set_qubits(qubits)
        # Build the lookup once rather than scanning `qubits` for every qubit.
        removed = set(qubits)
        remainder._set_qubits([q for q in self.qubits if q not in removed])
        return extracted, remainder

    def _on_factor(
        self: TSelf,
        qubits: Sequence['cirq.Qid'],
        extracted: TSelf,
        remainder: TSelf,
        validate=True,
        atol=1e-07,
    ):
        """Subclasses should implement this with any additional state factor
        functionality, if supported."""

    def transpose_to_qubit_order(
        self: TSelf, qubits: Sequence['cirq.Qid'], *, inplace=False
    ) -> TSelf:
        """Physically reindexes the state by the new basis.

        Args:
            qubits: The desired qubit order.
            inplace: True to perform this operation inplace.

        Returns:
            The state with qubit order transposed and underlying representation
            updated.

        Raises:
            ValueError: If the provided qubits do not match the existing ones.
        """
        if len(self.qubits) != len(qubits) or set(qubits) != set(self.qubits):
            raise ValueError(f'Qubits do not match. Existing: {self.qubits}, provided: {qubits}')
        args = self if inplace else copy.copy(self)
        self._on_transpose_to_qubit_order(qubits, args)
        args._set_qubits(qubits)
        return args

    def _on_transpose_to_qubit_order(self: TSelf, qubits: Sequence['cirq.Qid'], target: TSelf):
        """Subclasses should implement this with any additional state transpose
        functionality, if supported."""

    @property
    def log_of_measurement_results(self) -> Dict[str, Any]:
        """The mutable log that measurement results are recorded into."""
        return self._log_of_measurement_results

    @property
    def qubits(self) -> Tuple['cirq.Qid', ...]:
        """The qubits tracked by this object, in their current order."""
        return self._qubits

    def swap(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
        """Swaps two qubits.

        This only affects the index, and does not modify the underlying
        state.

        Args:
            q1: The first qubit to swap.
            q2: The second qubit to swap.
            inplace: True to swap the qubits in the current object, False to
                create a copy with the qubits swapped.

        Returns:
            The original object with the qubits swapped if inplace is
            requested, or a copy of the original object with the qubits swapped
            otherwise.

        Raises:
            ValueError: If the qubits are of different dimensionality, or if
                either qubit is not tracked by this object.
        """
        if q1.dimension != q2.dimension:
            raise ValueError(f'Cannot swap different dimensions: q1={q1}, q2={q2}')

        args = self if inplace else copy.copy(self)
        i1 = self.qubits.index(q1)
        i2 = self.qubits.index(q2)
        qubits = list(args.qubits)
        qubits[i1], qubits[i2] = qubits[i2], qubits[i1]
        # Reuse the shared helper so the order and the lookup stay in sync.
        args._set_qubits(qubits)
        return args

    def rename(self, q1: 'cirq.Qid', q2: 'cirq.Qid', *, inplace=False):
        """Renames `q1` to `q2`.

        Args:
            q1: The qubit to rename.
            q2: The new name.
            inplace: True to rename the qubit in the current object, False to
                create a copy with the qubit renamed.

        Returns:
            The original object with the qubits renamed if inplace is
            requested, or a copy of the original object with the qubits renamed
            otherwise.

        Raises:
            ValueError: If the qubits are of different dimensionality, or if
                `q1` is not tracked by this object.
        """
        if q1.dimension != q2.dimension:
            raise ValueError(f'Cannot rename to different dimensions: q1={q1}, q2={q2}')

        args = self if inplace else copy.copy(self)
        i1 = self.qubits.index(q1)
        qubits = list(args.qubits)
        qubits[i1] = q2
        # Reuse the shared helper so the order and the lookup stay in sync.
        args._set_qubits(qubits)
        return args

    def __getitem__(self: TSelf, item: Optional['cirq.Qid']) -> TSelf:
        """Returns this object for any qubit it tracks.

        Raises:
            IndexError: If `item` is not one of this object's qubits.
        """
        if item not in self.qubit_map:
            raise IndexError(f'{item} not in {self.qubits}')
        return self

    def __len__(self) -> int:
        """Returns the number of qubits tracked by this object."""
        return len(self.qubits)

    def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
        """Iterates over the tracked qubits in their current order."""
        return iter(self.qubits)
261
262
def strat_act_on_from_apply_decompose(
    val: Any,
    args: ActOnArgs,
    qubits: Sequence['cirq.Qid'],
) -> bool:
    """Acts on `args` by applying the decomposition of `val`, if it has one.

    Args:
        val: The value to decompose and apply.
        args: The state that each decomposed operation acts on.
        qubits: The qubits to apply `val` to. The qubits reported by the
            decomposition are mapped onto these positionally.

    Returns:
        True if every operation of the decomposition was applied to `args`,
        or `NotImplemented` if `val` yielded no decomposition.
    """
    operations, qubits1, _ = _try_decompose_into_operations_and_qubits(val)
    # Bail out before looking at `qubits1`: when there is no decomposition the
    # accompanying qubit sequence may be a placeholder whose length differs
    # from `qubits` (TODO confirm against the helper), and that must yield
    # NotImplemented rather than trip the consistency check below.
    if operations is None:
        return NotImplemented
    assert len(qubits1) == len(qubits)
    # Positional mapping from the decomposition's qubits to the target qubits.
    qubit_map = dict(zip(qubits1, qubits))
    for operation in operations:
        operation = operation.with_qubits(*[qubit_map[q] for q in operation.qubits])
        protocols.act_on(operation, args)
    return True
277