1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""An MPS simulator.
15
This simulator is based on the following paper:
17https://arxiv.org/abs/2002.07730
18"""
19
20import dataclasses
21import math
22from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union
23
24import numpy as np
25import quimb.tensor as qtn
26
27from cirq import devices, study, ops, protocols, value
28from cirq.sim import simulator, simulator_base
29from cirq.sim.act_on_args import ActOnArgs
30
31if TYPE_CHECKING:
32    import cirq
33
34
@dataclasses.dataclass(frozen=True)
class MPSOptions:
    """Numerical knobs controlling the accuracy/cost trade-off of the MPS simulation.

    Several of these values are forwarded unchanged to quimb's tensor splitting routine;
    see https://quimb.readthedocs.io/en/latest/_autosummary/
    quimb.tensor.tensor_core.html#quimb.tensor.tensor_core.tensor_split for details.
    """

    # Decomposition used when splitting a tensor (passed to quimb as ``method``).
    method: str = 'svds'
    # When an integer, the maximum number of singular values kept, regardless of ``cutoff``.
    max_bond: Optional[int] = None
    # How the ``cutoff`` threshold is applied (passed to quimb as ``cutoff_mode``).
    cutoff_mode: str = 'rsum2'
    # Singular values below this threshold are discarded (passed to quimb as ``cutoff``).
    cutoff: float = 1e-6
    # The simulation is approximate, so measurement probabilities need not sum to exactly
    # 1.0; this is the largest absolute deviation from 1.0 that is tolerated.
    sum_prob_atol: float = 1e-3
52
53
class MPSSimulator(
    simulator_base.SimulatorBase[
        'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState', 'MPSState'
    ],
):
    """An efficient simulator for MPS circuits."""

    def __init__(
        self,
        noise: 'cirq.NOISE_MODEL_LIKE' = None,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
        simulation_options: MPSOptions = MPSOptions(),
        grouping: Optional[Dict['cirq.Qid', int]] = None,
    ):
        """Builds an `MPSSimulator`.

        Args:
            noise: A noise model to apply while simulating.
            seed: The random seed to use for this simulator.
            simulation_options: Numerical options for the simulation.
            grouping: How to group qubits together, if None all are individual.

        Raises:
            ValueError: If the resolved noise model does not have a mixture
                (i.e. it is neither unitary nor a mixture of unitaries).
        """
        self.init = True
        resolved_noise = devices.NoiseModel.from_noise_model_like(noise)
        # Only unitary / mixture noise can be applied to a pure MPS state.
        if not protocols.has_mixture(resolved_noise):
            raise ValueError(f'noise must be unitary or mixture but was {resolved_noise}')
        self.simulation_options = simulation_options
        self.grouping = grouping
        super().__init__(noise=noise, seed=seed)

    def _create_partial_act_on_args(
        self,
        initial_state: Union[int, 'MPSState'],
        qubits: Sequence['cirq.Qid'],
        logs: Dict[str, Any],
    ) -> 'MPSState':
        """Builds the MPSState used to simulate a circuit.

        Args:
            initial_state: Either an existing MPSState, which is returned
                unchanged, or the initial state in the computational basis,
                represented as a big endian int.
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states.
            logs: The mutable dict into which measurement results are recorded.

        Returns:
            MPSState args for simulating the Circuit.
        """
        return (
            initial_state
            if isinstance(initial_state, MPSState)
            else MPSState(
                qubits=qubits,
                prng=self._prng,
                simulation_options=self.simulation_options,
                grouping=self.grouping,
                initial_state=initial_state,
                log_of_measurement_results=logs,
            )
        )

    def _create_step_result(
        self,
        sim_state: 'cirq.OperationTarget[MPSState]',
    ):
        """Wraps the current simulation state in an `MPSSimulatorStepResult`."""
        return MPSSimulatorStepResult(sim_state)

    def _create_simulator_trial_result(
        self,
        params: study.ParamResolver,
        measurements: Dict[str, np.ndarray],
        final_step_result: 'MPSSimulatorStepResult',
    ) -> 'MPSTrialResult':
        """Packages the outcome of a single trial.

        Args:
            params: A ParamResolver for determining values of Symbols.
            measurements: A dictionary from measurement key (e.g. qubit) to the
                actual measurement array.
            final_step_result: The final step result of the simulation.

        Returns:
            A single result.
        """
        return MPSTrialResult(
            params=params,
            measurements=measurements,
            final_step_result=final_step_result,
        )
149
150
class MPSTrialResult(simulator.SimulationTrialResult):
    """A single trial result of an MPS simulation."""

    def __init__(
        self,
        params: study.ParamResolver,
        measurements: Dict[str, np.ndarray],
        final_step_result: 'MPSSimulatorStepResult',
    ) -> None:
        super().__init__(
            params=params,
            measurements=measurements,
            final_step_result=final_step_result,
        )

    @property
    def final_state(self):
        """The simulator state captured from the final step of the trial."""
        return self._final_simulator_state

    def __str__(self) -> str:
        # The base class renders the measurements; append the final MPS state.
        return f'measurements: {super().__str__()}\noutput state: {self._final_simulator_state}'
172
173
class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
    """A `StepResult` that can perform measurements."""

    def __init__(
        self,
        sim_state: 'cirq.OperationTarget[MPSState]',
    ):
        """Results of a step of the simulator.

        Args:
            sim_state: The qubit:ActOnArgs lookup for this step.
        """
        super().__init__(sim_state)

    @property
    def state(self):
        """The merged simulator state for this step."""
        return self._merged_sim_state

    def __str__(self) -> str:
        # Each measurement is rendered as ``key=v0,v1,...``, sorted by key, on one
        # line that precedes the state; with no measurements only the state is shown.
        pairs = sorted(
            (key, ','.join(str(v) for v in val)) for key, val in self.measurements.items()
        )
        header = ' '.join(f'{key}={val}' for key, val in pairs) + '\n' if pairs else ''
        return f'{header}{self.state}'

    def _simulator_state(self):
        return self.state
208
209
@value.value_equality
class MPSState(ActOnArgs):
    """A state of the MPS simulation.

    The state is held in ``self.M``, a list of quimb tensors with one entry per
    qubit group (see ``grouping``). Each qubit contributes a physical index
    named by ``i_str``; tensors of two different groups that have been joined by
    a two-qubit operation share a bond index named by ``mu_str``.
    """

    def __init__(
        self,
        qubits: Sequence['cirq.Qid'],
        prng: np.random.RandomState,
        simulation_options: MPSOptions = MPSOptions(),
        grouping: Optional[Dict['cirq.Qid', int]] = None,
        initial_state: int = 0,
        log_of_measurement_results: Optional[Dict[str, Any]] = None,
    ):
        """Creates an MPSState.

        Args:
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states.
            prng: A random number generator, used to simulate measurements.
            simulation_options: Numerical options for the simulation.
            grouping: How to group qubits together, if None all are individual.
                Maps each qubit to the index of the tensor in ``M`` holding it.
            initial_state: An integer representing the initial state, in big
                endian order with respect to ``qubits``.
            log_of_measurement_results: A mutable object that measurements are
                being recorded into.

        Raises:
            ValueError: If the keys of ``grouping`` are not exactly the qubits.
        """
        super().__init__(prng, qubits, log_of_measurement_results)
        qubit_map = self.qubit_map
        self.grouping = qubit_map if grouping is None else grouping
        if self.grouping.keys() != self.qubit_map.keys():
            raise ValueError('Grouping must cover exactly the qubits.')
        # One tensor per group id in 0..max(group id); each starts as a bare qtn.Tensor()
        # and is filled in below with the basis vectors of its qubits.
        self.M = []
        for _ in range(max(self.grouping.values()) + 1):
            self.M.append(qtn.Tensor())

        # The order of the qubits matters, because the state |01> is different from |10>. Since
        # Quimb uses strings to name tensor indices, we want to be able to sort them too. If we are
        # working with, say, 123 qubits then we want qubit 3 to come before qubit 100, so we
        # write the string '003', which comes before '100' in lexicographic order. The code
        # below is just simple string formatting.
        max_num_digits = len(f'{max(qubit_map.values())}')
        self.format_i = f'i_{{:0{max_num_digits}}}'
        self.format_mu = 'mu_{}_{}'

        # TODO(tonybruguier): Instead of relying on sortable indices could you keep a parallel
        # mapping of e.g. qubit to string-index and do all "logic" on the qubits themselves and
        # only translate to string-indices when calling a quimb API.

        # TODO(tonybruguier): Refactor out so that the code below can also be used by
        # circuit_to_tensors in cirq.contrib.quimb.state_vector.

        # Decode initial_state digit by digit, starting from the last qubit (least
        # significant digit), and attach each qubit's basis vector to its group tensor.
        for qubit in reversed(list(qubit_map.keys())):
            d = qubit.dimension
            x = np.zeros(d)
            x[initial_state % d] = 1.0

            i = qubit_map[qubit]
            n = self.grouping[qubit]
            self.M[n] @= qtn.Tensor(x, inds=(self.i_str(i),))
            initial_state = initial_state // d
        self.simulation_options = simulation_options
        self.estimated_gate_error_list: List[float] = []

    def i_str(self, i: int) -> str:
        """Returns the (zero-padded, sortable) index name for the i'th qid."""
        return self.format_i.format(i)

    def mu_str(self, i: int, j: int) -> str:
        """Returns the bond index name for the pair of the i'th and j'th tensors.

        By convention, the lower index is always the first in the output string,
        so ``mu_str(i, j) == mu_str(j, i)``.
        """
        smallest = min(i, j)
        largest = max(i, j)
        return self.format_mu.format(smallest, largest)

    def __str__(self) -> str:
        return str(qtn.TensorNetwork(self.M))

    def _value_equality_values_(self) -> Any:
        return self.qubit_map, self.M, self.simulation_options, self.grouping

    def _on_copy(self, target: 'MPSState'):
        """Copies the MPS-specific fields of this state into ``target``.

        The tensors and the gate-error history are copied (not aliased) so that
        mutating the copy does not affect this state.
        """
        target.simulation_options = self.simulation_options
        target.grouping = self.grouping
        target.M = [x.copy() for x in self.M]
        target.estimated_gate_error_list = list(self.estimated_gate_error_list)

    def state_vector(self) -> np.ndarray:
        """Returns the full state vector.

        Returns:
            A vector that contains the full state.
        """
        tensor_network = qtn.TensorNetwork(self.M)
        state_vector = tensor_network.contract(inplace=False)

        # Here, we rely on the formatting of the indices, and the fact that we have enough
        # leading zeros so that 003 comes before 100.
        sorted_ind = tuple(sorted(state_vector.inds))
        return state_vector.fuse({'i': sorted_ind}).data

    def partial_trace(self, keep_qubits: Set[ops.Qid]) -> np.ndarray:
        """Traces out all qubits except keep_qubits.

        Args:
            keep_qubits: The set of qubits that are left after computing the
                partial trace. For example, if we have a circuit for 3 qubits
                and this parameter only has one qubit, the entire density matrix
                would be 8x8, but this function returns a 2x2 matrix.

        Returns:
            An array that contains the partial trace.
        """

        contracted_inds = {
            self.i_str(i) for qubit, i in self.qubit_map.items() if qubit not in keep_qubits
        }

        conj_pfx = "conj_"

        tensor_network = qtn.TensorNetwork(self.M)

        # Rename the internal indices to avoid collisions. Also rename the qubit
        # indices that are kept. We do not rename the qubit indices that are
        # traced out.
        conj_tensor_network = tensor_network.conj()
        reindex_mapping = {}
        for M in conj_tensor_network.tensors:
            for ind in M.inds:
                if ind not in contracted_inds:
                    reindex_mapping[ind] = conj_pfx + ind
        conj_tensor_network.reindex(reindex_mapping, inplace=True)
        partial_trace = conj_tensor_network @ tensor_network

        forward_inds = [self.i_str(self.qubit_map[keep_qubit]) for keep_qubit in keep_qubits]
        backward_inds = [conj_pfx + forward_ind for forward_ind in forward_inds]
        return partial_trace.to_dense(forward_inds, backward_inds)

    def to_numpy(self) -> np.ndarray:
        """An alias for the state vector."""
        return self.state_vector()

    def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):
        """Applies a unitary (or mixture) op, mutating this state in place.

        Args:
            op: The operation that mutates the object. Note that currently, only
                1- and 2- qubit operations are supported.
            prng: Random number generator used to pick one of the unitaries when
                ``op`` has no unitary but is a mixture.

        Returns:
            True, once the operation has been applied.

        Raises:
            ValueError: If ``op`` acts on more than two qubits.
        """

        old_inds = tuple([self.i_str(self.qubit_map[qubit]) for qubit in op.qubits])
        new_inds = tuple(['new_' + old_ind for old_ind in old_inds])

        if protocols.has_unitary(op):
            U = protocols.unitary(op)
        else:
            mixtures = protocols.mixture(op)
            mixture_idx = int(prng.choice(len(mixtures), p=[mixture[0] for mixture in mixtures]))
            U = mixtures[mixture_idx][1]
        U = qtn.Tensor(
            U.reshape([qubit.dimension for qubit in op.qubits] * 2), inds=(new_inds + old_inds)
        )

        # TODO(tonybruguier): Explore using the Quimb's tensor network natively.

        if len(op.qubits) == 1:
            n = self.grouping[op.qubits[0]]

            self.M[n] = (U @ self.M[n]).reindex({new_inds[0]: old_inds[0]})
        elif len(op.qubits) == 2:
            n, p = [self.grouping[qubit] for qubit in op.qubits]

            if n == p:
                # Both qubits live in the same tensor: no split is needed.
                self.M[n] = (U @ self.M[n]).reindex(
                    {new_inds[0]: old_inds[0], new_inds[1]: old_inds[1]}
                )
            else:
                # This is the index on which we do the contraction. We need to add it iff it's
                # the first time that we do the joining for that specific pair.
                mu_ind = self.mu_str(n, p)
                if mu_ind not in self.M[n].inds:
                    self.M[n].new_ind(mu_ind)
                if mu_ind not in self.M[p].inds:
                    self.M[p].new_ind(mu_ind)

                T = U @ self.M[n] @ self.M[p]

                left_inds = tuple(set(T.inds) & set(self.M[n].inds)) + (new_inds[0],)
                X, Y = T.split(
                    left_inds,
                    method=self.simulation_options.method,
                    max_bond=self.simulation_options.max_bond,
                    cutoff=self.simulation_options.cutoff,
                    cutoff_mode=self.simulation_options.cutoff_mode,
                    get='tensors',
                    absorb='both',
                    bond_ind=mu_ind,
                )

                # Equations (13), (14), and (15):
                # TODO(tonybruguier): When Quimb 2.0.0 is released, the split()
                # function should have a 'renorm' that, when set to None, will
                # allow to compute e_n exactly as:
                # np.sum(abs((X @ Y).data) ** 2).real / np.sum(abs(T) ** 2).real
                #
                # The renormalization would then have to be done manually.
                #
                # However, for now, e_n are just the estimated value.
                e_n = self.simulation_options.cutoff
                self.estimated_gate_error_list.append(e_n)

                self.M[n] = X.reindex({new_inds[0]: old_inds[0]})
                self.M[p] = Y.reindex({new_inds[1]: old_inds[1]})
        else:
            # NOTE(tonybruguier): There could be a way to handle higher orders. I think this could
            # involve HOSVDs:
            # https://en.wikipedia.org/wiki/Higher-order_singular_value_decomposition
            #
            # TODO(tonybruguier): Evaluate whether it's even useful to implement and learn more
            # about HOSVDs.
            raise ValueError('Can only handle 1 and 2 qubit operations')
        return True

    def _act_on_fallback_(
        self,
        action: Union['cirq.Operation', 'cirq.Gate'],
        qubits: Sequence['cirq.Qid'],
        allow_decompose: bool = True,
    ) -> bool:
        """Delegates the action to self.apply_op"""
        if isinstance(action, ops.Gate):
            action = ops.GateOperation(action, qubits)
        return self.apply_op(action, self.prng)

    def estimation_stats(self):
        """Returns some statistics about the memory usage and quality of the approximation."""

        num_coefs_used = sum([Mi.data.size for Mi in self.M])
        memory_bytes = sum([Mi.data.nbytes for Mi in self.M])

        # The computation below is done for numerical stability, instead of directly using the
        # formula:
        # estimated_fidelity = \prod_i (1 - estimated_gate_error_list_i)
        estimated_fidelity = 1.0 + np.expm1(
            sum(np.log1p(-x) for x in self.estimated_gate_error_list)
        )
        estimated_fidelity = round(estimated_fidelity, ndigits=3)

        return {
            "num_coefs_used": num_coefs_used,
            "memory_bytes": memory_bytes,
            "estimated_fidelity": estimated_fidelity,
        }

    def perform_measurement(
        self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True
    ) -> List[int]:
        """Performs a measurement over one or more qubits.

        Args:
            qubits: The sequence of qids to measure, in that order.
            prng: A random number generator, used to simulate measurements.
            collapse_state_vector: A Boolean specifying whether we should mutate
                the state after the measurement.

        Returns:
            The measured value of each qubit, in the order of ``qubits``.

        Raises:
            ValueError: If the (approximate) outcome probabilities of a qubit
                deviate from 1.0 by more than ``sum_prob_atol``.
        """
        results: List[int] = []

        if collapse_state_vector:
            state = self
        else:
            state = self.copy()

        for qubit in qubits:
            # The physical index name comes from the qubit's position, but the
            # tensor that owns it is selected by its group; the two differ as
            # soon as several qubits share a tensor.
            i = state.qubit_map[qubit]
            n = state.grouping[qubit]

            # Trace out other qubits
            M = state.partial_trace(keep_qubits={qubit})
            probs = np.diag(M).real
            sum_probs = sum(probs)

            # Because the computation is approximate, the probabilities do not
            # necessarily add up to 1.0, and thus we re-normalize them.
            if abs(sum_probs - 1.0) > self.simulation_options.sum_prob_atol:
                raise ValueError(f'Sum of probabilities exceeds tolerance: {sum_probs}')
            norm_probs = [x / sum_probs for x in probs]

            d = qubit.dimension
            result: int = int(prng.choice(d, p=norm_probs))

            # Project onto the observed outcome and rescale so the state has unit norm.
            collapser = np.zeros((d, d))
            collapser[result][result] = 1.0 / math.sqrt(probs[result])

            old_n = state.i_str(i)
            new_n = 'new_' + old_n

            collapser = qtn.Tensor(collapser, inds=(new_n, old_n))

            state.M[n] = (collapser @ state.M[n]).reindex({new_n: old_n})

            results.append(result)

        return results

    def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Measures the axes specified by the simulator."""
        return self.perform_measurement(qubits, self.prng)

    def sample(
        self,
        qubits: Sequence[ops.Qid],
        repetitions: int = 1,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
    ) -> np.ndarray:
        """Samples the given qubits without collapsing this state.

        Args:
            qubits: The qubits to measure, in that order.
            repetitions: The number of independent samples to draw.
            seed: Randomness source; parsed once so that repetitions draw from
                a single generator instead of restarting from the same seed.

        Returns:
            An integer array of shape ``(repetitions, len(qubits))``.
        """
        prng = value.parse_random_state(seed)
        measurements: List[List[int]] = [
            self.perform_measurement(qubits, prng, collapse_state_vector=False)
            for _ in range(repetitions)
        ]
        # reshape keeps the 2-D shape even when repetitions == 0.
        return np.array(measurements, dtype=int).reshape((repetitions, len(qubits)))