1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Batteries-included class for Cirq's built-in simulators."""
16
17import abc
18import collections
19from typing import (
20    Any,
21    Dict,
22    Iterator,
23    List,
24    Tuple,
25    TYPE_CHECKING,
26    cast,
27    Generic,
28    Type,
29    Sequence,
30    Optional,
31    TypeVar,
32)
33
34import numpy as np
35
36from cirq import circuits, ops, protocols, study, value, devices
37from cirq.sim import ActOnArgsContainer
38from cirq.sim.operation_target import OperationTarget
39from cirq.sim.simulator import (
40    TSimulationTrialResult,
41    TSimulatorState,
42    TActOnArgs,
43    SimulatesIntermediateState,
44    SimulatesSamples,
45    StepResult,
46    check_all_resolved,
47    split_into_matching_protocol_then_general,
48)
49
50if TYPE_CHECKING:
51    import cirq
52
53
# Type variable for the concrete step-result class a simulator emits. It is
# bound (via a forward reference) to `StepResultBase`, which is defined later
# in this module.
TStepResultBase = TypeVar('TStepResultBase', bound='StepResultBase')
55
56
class SimulatorBase(
    Generic[TStepResultBase, TSimulationTrialResult, TSimulatorState, TActOnArgs],
    SimulatesIntermediateState[
        TStepResultBase, TSimulationTrialResult, TSimulatorState, TActOnArgs
    ],
    SimulatesSamples,
    metaclass=abc.ABCMeta,
):
    """A base class for the built-in simulators.

    Most implementors of this interface should implement the
    `_create_partial_act_on_args` and `_create_step_result` methods. The first
    one creates the simulator's quantum state representation at the beginning
    of the simulation. The second creates the step result emitted after each
    `Moment` in the simulation.

    Iteration in the subclass is handled by the `_core_iterator` implementation
    here, which handles moment stepping, application of operations, measurement
    collection, and creation of noise. Simulators with more advanced needs can
    override the implementation if necessary.

    Sampling is handled by the implementation of `_run`. This implementation
    iterates the circuit to create a final step result, and samples that
    result when possible. If not possible, due to noise or classical
    probabilities on a state vector, the implementation attempts to fully
    iterate the unitary prefix once, then only repeat the non-unitary
    suffix from copies of the state obtained by the prefix. If more advanced
    functionality is required, then the `_run` method can be overridden.

    Note that state here refers to simulator state, which is not necessarily
    a state vector. The included simulators and corresponding states are state
    vector, density matrix, Clifford, and MPS. Each of these uses the default
    `_core_iterator` and `_run` methods.
    """

    def __init__(
        self,
        *,
        dtype: Type[np.number] = np.complex64,
        noise: 'cirq.NOISE_MODEL_LIKE' = None,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
        ignore_measurement_results: bool = False,
        split_untangled_states: bool = False,
    ):
        """Initializes the simulator.

        Args:
            dtype: The `numpy.dtype` used by the simulation.
            noise: A noise model to apply while simulating.
            seed: The random seed to use for this simulator.
            ignore_measurement_results: If True, then the simulation
                will treat measurement as dephasing instead of collapsing
                process. This is only applicable to simulators that can
                model dephasing.
            split_untangled_states: If True, optimizes simulation by running
                unentangled qubit sets independently and merging those states
                at the end.
        """
        self._dtype = dtype
        self._prng = value.parse_random_state(seed)
        self.noise = devices.NoiseModel.from_noise_model_like(noise)
        self._ignore_measurement_results = ignore_measurement_results
        self._split_untangled_states = split_untangled_states

    @abc.abstractmethod
    def _create_partial_act_on_args(
        self,
        initial_state: Any,
        qubits: Sequence['cirq.Qid'],
        logs: Dict[str, Any],
    ) -> TActOnArgs:
        """Creates an instance of the TActOnArgs class for the simulator.

        It represents the supplied qubits initialized to the provided state.

        Args:
            initial_state: The initial state to represent. An integer state is
                understood to be a pure state. Other state representations are
                simulator-dependent.
            qubits: The sequence of qubits to represent.
            logs: The structure to hold measurement logs. A single instance
                should be shared among all ActOnArgs within the simulation.
        """

    @abc.abstractmethod
    def _create_step_result(
        self,
        sim_state: OperationTarget[TActOnArgs],
    ) -> TStepResultBase:
        """This method should be implemented to create a step result.

        Args:
            sim_state: The OperationTarget for this trial.

        Returns:
            The StepResult.
        """

    def _can_be_in_run_prefix(self, val: Any) -> bool:
        """Determines what should be put in the prefix in `_run`

        The `_run` method has an optimization that reduces repetition by
        splitting the circuit into a prefix that is pure with respect to the
        state representation, and only executing that once per sample set. For
        state vectors, any unitary operation is pure, and we make this the
        default here. For density matrices, any non-measurement operation can
        be represented wholly in the matrix, and thus this method is
        overridden there to enable greater optimization there.

        Custom simulators can override this method appropriately.

        Args:
            val: An operation or noise model to test for purity within the
                state representation.

        Returns:
            A boolean representing whether the value can be added to the
            `_run` prefix."""
        return protocols.has_unitary(val)

    def _core_iterator(
        self,
        circuit: circuits.AbstractCircuit,
        sim_state: OperationTarget[TActOnArgs],
        all_measurements_are_terminal: bool = False,
    ) -> Iterator[TStepResultBase]:
        """Standard iterator over StepResult from Moments of a Circuit.

        Args:
            circuit: The circuit to simulate.
            sim_state: The initial args for the simulation. The form of
                this state depends on the simulation implementation. See
                documentation of the implementing class for details.
            all_measurements_are_terminal: If True, measurement operations are
                not applied to the state, and any later operation acting on
                exactly the same qubit tuple as an already-seen measurement
                is skipped as well. `_run` uses this so that it can sample the
                final state instead of collapsing it.

        Yields:
            StepResults from simulating a Moment of the Circuit.

        Raises:
            TypeError: The simulator encounters an op it does not support.
        """

        # An empty circuit still produces exactly one step result so callers
        # always have a final state to inspect.
        if len(circuit) == 0:
            yield self._create_step_result(sim_state)
            return

        noisy_moments = self.noise.noisy_moments(circuit, sorted(circuit.all_qubits()))
        # Tracks which qubit tuples have been hit by a measurement so far.
        measured: Dict[Tuple['cirq.Qid', ...], bool] = collections.defaultdict(bool)
        for moment in noisy_moments:
            for op in ops.flatten_to_ops(moment):
                try:
                    # TODO: support more general measurements.
                    # Github issue: https://github.com/quantumlib/Cirq/issues/3566

                    # Preprocess measurements
                    if all_measurements_are_terminal and measured[op.qubits]:
                        continue
                    if isinstance(op.gate, ops.MeasurementGate):
                        measured[op.qubits] = True
                        if all_measurements_are_terminal:
                            continue
                        if self._ignore_measurement_results:
                            # Model the measurement as full dephasing rather
                            # than a collapse.
                            op = ops.phase_damp(1).on(*op.qubits)

                    # Simulate the operation
                    protocols.act_on(op, sim_state)
                except TypeError as err:
                    # Chain explicitly so the underlying failure stays visible
                    # as the cause of the user-facing error.
                    raise TypeError(
                        f"{self.__class__.__name__} doesn't support {op!r}"
                    ) from err

            step_result = self._create_step_result(sim_state)
            yield step_result
            # Continue from whatever state the step result now holds.
            sim_state = step_result._sim_state

    def _run(
        self,
        circuit: circuits.AbstractCircuit,
        param_resolver: study.ParamResolver,
        repetitions: int,
    ) -> Dict[str, np.ndarray]:
        """See definition in `cirq.SimulatesSamples`."""
        if self._ignore_measurement_results:
            raise ValueError("run() is not supported when ignore_measurement_results = True")

        param_resolver = param_resolver or study.ParamResolver({})
        resolved_circuit = protocols.resolve_parameters(circuit, param_resolver)
        check_all_resolved(resolved_circuit)
        qubits = tuple(sorted(resolved_circuit.all_qubits()))
        act_on_args = self._create_act_on_args(0, qubits)

        # Split off the leading part of the circuit that only needs to be
        # simulated once. If the noise model itself cannot live in the prefix,
        # the prefix is empty and everything goes into the suffix.
        prefix, general_suffix = (
            split_into_matching_protocol_then_general(resolved_circuit, self._can_be_in_run_prefix)
            if self._can_be_in_run_prefix(self.noise)
            else (resolved_circuit[0:0], resolved_circuit)
        )
        step_result = None
        for step_result in self._core_iterator(
            circuit=prefix,
            sim_state=act_on_args,
        ):
            pass

        general_ops = list(general_suffix.all_operations())
        if all(isinstance(op.gate, ops.MeasurementGate) for op in general_ops):
            # Fast path: the suffix holds nothing but measurements, so the
            # final state is simulated once and sampled `repetitions` times.
            for step_result in self._core_iterator(
                circuit=general_suffix,
                sim_state=act_on_args,
                all_measurements_are_terminal=True,
            ):
                pass
            assert step_result is not None
            measurement_ops = [cast(ops.GateOperation, op) for op in general_ops]
            return step_result.sample_measurement_ops(measurement_ops, repetitions, seed=self._prng)

        # Slow path: replay the suffix once per repetition, each time starting
        # from the state produced by the prefix.
        measurements: Dict[str, List[np.ndarray]] = {}
        for i in range(repetitions):
            all_step_results = self._core_iterator(
                general_suffix,
                # The last repetition may consume the prefix state directly;
                # every earlier one works on a copy so the original survives.
                sim_state=act_on_args.copy() if i < repetitions - 1 else act_on_args,
            )
            for step_result in all_step_results:
                pass
            for k, v in step_result.measurements.items():
                measurements.setdefault(k, []).append(np.array(v, dtype=np.uint8))
        return {k: np.array(v) for k, v in measurements.items()}

    def _create_act_on_args(
        self,
        initial_state: Any,
        qubits: Sequence['cirq.Qid'],
    ) -> OperationTarget[TActOnArgs]:
        """Builds the simulation state object for the given qubits.

        Args:
            initial_state: The initial state. If it is already an
                `OperationTarget` it is returned unchanged. Otherwise it is
                forwarded to `_create_partial_act_on_args`; when
                `split_untangled_states` is set and the value is an `int`, it
                is first decomposed into one digit per qubit (mixed radix,
                with the last qubit as the least-significant digit).
            qubits: The qubits the state should cover.

        Returns:
            Either the object created by `_create_partial_act_on_args` for all
            qubits, or, when `split_untangled_states` is set, an
            `ActOnArgsContainer` holding per-qubit entries.
        """
        if isinstance(initial_state, OperationTarget):
            return initial_state

        # One log instance is shared by every args object that gets created.
        log: Dict[str, Any] = {}
        if self._split_untangled_states:
            args_map: Dict[Optional['cirq.Qid'], TActOnArgs] = {}
            if isinstance(initial_state, int):
                for q in reversed(qubits):
                    # Use exact integer arithmetic: float division would lose
                    # precision (or overflow) once the state exceeds 2**53,
                    # corrupting the digits of the remaining qubits.
                    initial_state, digit = divmod(initial_state, q.dimension)
                    args_map[q] = self._create_partial_act_on_args(
                        initial_state=digit,
                        qubits=[q],
                        logs=log,
                    )
            else:
                # A non-integer state is not decomposed; every qubit maps to
                # the same args object.
                args = self._create_partial_act_on_args(
                    initial_state=initial_state,
                    qubits=qubits,
                    logs=log,
                )
                for q in qubits:
                    args_map[q] = args
            # Extra entry, keyed by None, built over no qubits.
            # NOTE(review): presumably holds state not attached to any qubit;
            # confirm against ActOnArgsContainer.
            args_map[None] = self._create_partial_act_on_args(0, (), log)
            return ActOnArgsContainer(args_map, qubits, self._split_untangled_states, log)
        else:
            return self._create_partial_act_on_args(
                initial_state=initial_state,
                qubits=qubits,
                logs=log,
            )
319
320
class StepResultBase(Generic[TSimulatorState, TActOnArgs], StepResult[TSimulatorState], abc.ABC):
    """Common base for the step results produced by the built-in simulators.

    It wraps the `OperationTarget` of a simulation step and records the
    target's qubits, their positions, and their dimensions.
    """

    def __init__(
        self,
        sim_state: OperationTarget[TActOnArgs],
    ):
        """Initializes the step result.

        Args:
            sim_state: The `OperationTarget` for this step.
        """
        self._sim_state = sim_state
        # Filled in on first access of `_merged_sim_state`.
        self._merged_sim_state_cache: Optional[TActOnArgs] = None
        super().__init__(sim_state.log_of_measurement_results)
        self._qubits = sim_state.qubits
        # Position of every qubit within the target's qubit ordering.
        self._qubit_mapping = {
            qubit: position for position, qubit in enumerate(self._qubits)
        }
        self._qubit_shape = tuple(qubit.dimension for qubit in self._qubits)

    def _qid_shape_(self):
        """Returns the dimensions of the qubits, in the target's order."""
        return self._qubit_shape

    @property
    def _merged_sim_state(self):
        """The merged simulation state, created on first use and then reused."""
        merged = self._merged_sim_state_cache
        if merged is None:
            merged = self._sim_state.create_merged_state()
            self._merged_sim_state_cache = merged
        return merged

    def sample(
        self,
        qubits: List[ops.Qid],
        repetitions: int = 1,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
    ) -> np.ndarray:
        """Samples the given qubits by delegating to the wrapped state.

        Args:
            qubits: The qubits to sample.
            repetitions: The number of samples to take.
            seed: A seed or random state forwarded to the wrapped state's
                `sample` method.

        Returns:
            Whatever the wrapped state's `sample` method returns.
        """
        sampled = self._sim_state.sample(qubits, repetitions, seed)
        return sampled
357