# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An MPS simulator.

This is based on this paper:
https://arxiv.org/abs/2002.07730
"""

import dataclasses
import math
from typing import Any, Dict, List, Optional, Sequence, Set, TYPE_CHECKING, Union

import numpy as np
import quimb.tensor as qtn

from cirq import devices, study, ops, protocols, value
from cirq.sim import simulator, simulator_base
from cirq.sim.act_on_args import ActOnArgs

if TYPE_CHECKING:
    import cirq


@dataclasses.dataclass(frozen=True)
class MPSOptions:
    """Numerical options for the MPS simulation."""

    # Some of these parameters are fed directly to Quimb so refer to the documentation for detail:
    # https://quimb.readthedocs.io/en/latest/_autosummary/ \
    #       quimb.tensor.tensor_core.html#quimb.tensor.tensor_core.tensor_split

    # How to split the tensor. Refer to the Quimb documentation for the exact meaning.
    method: str = 'svds'
    # If integer, the maximum number of singular values to keep, regardless of ``cutoff``.
    max_bond: Optional[int] = None
    # Method with which to apply the cutoff threshold. Refer to the Quimb documentation.
    cutoff_mode: str = 'rsum2'
    # The threshold below which to discard singular values. Refer to the Quimb documentation.
    cutoff: float = 1e-6
    # Because the computation is approximate, the sum of the probabilities is not 1.0. This
    # parameter is the absolute deviation from 1.0 that is allowed.
    sum_prob_atol: float = 1e-3


class MPSSimulator(
    simulator_base.SimulatorBase[
        'MPSSimulatorStepResult', 'MPSTrialResult', 'MPSState', 'MPSState'
    ],
):
    """An efficient simulator for MPS circuits."""

    def __init__(
        self,
        noise: 'cirq.NOISE_MODEL_LIKE' = None,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
        simulation_options: MPSOptions = MPSOptions(),
        grouping: Optional[Dict['cirq.Qid', int]] = None,
    ):
        """Creates instance of `MPSSimulator`.

        Args:
            noise: A noise model to apply while simulating.
            seed: The random seed to use for this simulator.
            simulation_options: Numerical options for the simulation.
            grouping: How to group qubits together, if None all are individual.

        Raises:
            ValueError: If the noise model is neither unitary nor a mixture.
        """
        self.init = True
        noise_model = devices.NoiseModel.from_noise_model_like(noise)
        # MPSState.apply_op only knows how to apply a unitary or to sample one
        # unitary out of a mixture, so other kinds of noise are rejected here.
        if not protocols.has_mixture(noise_model):
            raise ValueError(f'noise must be unitary or mixture but was {noise_model}')
        self.simulation_options = simulation_options
        self.grouping = grouping
        super().__init__(
            noise=noise,
            seed=seed,
        )

    def _create_partial_act_on_args(
        self,
        initial_state: Union[int, 'MPSState'],
        qubits: Sequence['cirq.Qid'],
        logs: Dict[str, Any],
    ) -> 'MPSState':
        """Creates MPSState args for simulating the Circuit.

        Args:
            initial_state: The initial state for the simulation in the
                computational basis. Represented as a big endian int. If an
                `MPSState` is given, it is returned unchanged.
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states.
            logs: A mutable dictionary into which measurement results are
                recorded.

        Returns:
            MPSState args for simulating the Circuit.
        """
        if isinstance(initial_state, MPSState):
            return initial_state

        return MPSState(
            qubits=qubits,
            prng=self._prng,
            simulation_options=self.simulation_options,
            grouping=self.grouping,
            initial_state=initial_state,
            log_of_measurement_results=logs,
        )

    def _create_step_result(
        self,
        sim_state: 'cirq.OperationTarget[MPSState]',
    ):
        """Wraps the current simulation state into an `MPSSimulatorStepResult`."""
        return MPSSimulatorStepResult(sim_state)

    def _create_simulator_trial_result(
        self,
        params: study.ParamResolver,
        measurements: Dict[str, np.ndarray],
        final_step_result: 'MPSSimulatorStepResult',
    ) -> 'MPSTrialResult':
        """Creates a single trial results with the measurements.

        Args:
            params: A ParamResolver for determining values of Symbols.
            measurements: A dictionary from measurement key (e.g. qubit) to the
                actual measurement array.
            final_step_result: The final step result of the simulation.

        Returns:
            A single result.
        """
        return MPSTrialResult(
            params=params, measurements=measurements, final_step_result=final_step_result
        )


class MPSTrialResult(simulator.SimulationTrialResult):
    """A single trial result."""

    def __init__(
        self,
        params: study.ParamResolver,
        measurements: Dict[str, np.ndarray],
        final_step_result: 'MPSSimulatorStepResult',
    ) -> None:
        super().__init__(
            params=params, measurements=measurements, final_step_result=final_step_result
        )

    @property
    def final_state(self):
        """The simulator state at the end of the trial."""
        return self._final_simulator_state

    def __str__(self) -> str:
        samples = super().__str__()
        final = self._final_simulator_state
        return f'measurements: {samples}\noutput state: {final}'


class MPSSimulatorStepResult(simulator_base.StepResultBase['MPSState', 'MPSState']):
    """A `StepResult` that can perform measurements."""

    def __init__(
        self,
        sim_state: 'cirq.OperationTarget[MPSState]',
    ):
        """Results of a step of the simulator.

        Args:
            sim_state: The qubit:ActOnArgs lookup for this step.
        """
        super().__init__(sim_state)

    @property
    def state(self):
        """The merged simulation state for this step."""
        return self._merged_sim_state

    def __str__(self) -> str:
        def bitstring(vals):
            return ','.join(str(v) for v in vals)

        results = sorted([(key, bitstring(val)) for key, val in self.measurements.items()])

        if not results:
            measurements = ''
        else:
            measurements = ' '.join([f'{key}={val}' for key, val in results]) + '\n'

        final = self.state

        return f'{measurements}{final}'

    def _simulator_state(self):
        return self.state


@value.value_equality
class MPSState(ActOnArgs):
    """A state of the MPS simulation."""

    def __init__(
        self,
        qubits: Sequence['cirq.Qid'],
        prng: np.random.RandomState,
        simulation_options: MPSOptions = MPSOptions(),
        grouping: Optional[Dict['cirq.Qid', int]] = None,
        initial_state: int = 0,
        log_of_measurement_results: Optional[Dict[str, Any]] = None,
    ):
        """Creates an MPSState.

        Args:
            qubits: Determines the canonical ordering of the qubits. This
                is often used in specifying the initial state, i.e. the
                ordering of the computational basis states.
            prng: A random number generator, used to simulate measurements.
            simulation_options: Numerical options for the simulation.
            grouping: A mapping from each qubit to the index of the tensor
                (group) that holds it. If None, every qubit is in its own group.
            initial_state: An integer representing the initial state, in
                big-endian order over the qubits.
            log_of_measurement_results: A mutable object that measurements are
                being recorded into.

        Raises:
            ValueError: If `grouping` does not cover exactly the qubits.
        """
        super().__init__(prng, qubits, log_of_measurement_results)
        qubit_map = self.qubit_map
        self.grouping = qubit_map if grouping is None else grouping
        if self.grouping.keys() != self.qubit_map.keys():
            raise ValueError('Grouping must cover exactly the qubits.')
        # One tensor per group; self.M[g] holds the indices of all qubits in group g.
        self.M = []
        for _ in range(max(self.grouping.values()) + 1):
            self.M.append(qtn.Tensor())

        # The order of the qubits matters, because the state |01> is different from |10>. Since
        # Quimb uses strings to name tensor indices, we want to be able to sort them too. If we are
        # working with, say, 123 qubits then we want qubit 3 to come before qubit 100, but then
        # we want to write the string '003' which comes before '100' in lexicographic order. The
        # code below is just simple string formatting.
        max_num_digits = len(f'{max(qubit_map.values())}')
        self.format_i = f'i_{{:0{max_num_digits}}}'
        self.format_mu = 'mu_{}_{}'

        # TODO(tonybruguier): Instead of relying on sortable indices could you keep a parallel
        # mapping of e.g. qubit to string-index and do all "logic" on the qubits themselves and
        # only translate to string-indices when calling a quimb API.

        # TODO(tonybruguier): Refactor out so that the code below can also be used by
        # circuit_to_tensors in cirq.contrib.quimb.state_vector.

        # Walk the qubits from last to first so that the last qubit consumes the least
        # significant digit of initial_state (big-endian encoding).
        for qubit in reversed(list(qubit_map.keys())):
            d = qubit.dimension
            x = np.zeros(d)
            x[initial_state % d] = 1.0

            i = qubit_map[qubit]
            n = self.grouping[qubit]
            self.M[n] @= qtn.Tensor(x, inds=(self.i_str(i),))
            initial_state = initial_state // d
        self.simulation_options = simulation_options
        self.estimated_gate_error_list: List[float] = []

    def i_str(self, i: int) -> str:
        """Returns the index name for the i'th qid."""
        return self.format_i.format(i)

    def mu_str(self, i: int, j: int) -> str:
        """Returns the index name for the pair of the i'th and j'th qids.

        By convention, the lower index is always the first in the output string,
        so ``mu_str(i, j) == mu_str(j, i)``.
        """
        smallest = min(i, j)
        largest = max(i, j)
        return self.format_mu.format(smallest, largest)

    def __str__(self) -> str:
        return str(qtn.TensorNetwork(self.M))

    def _value_equality_values_(self) -> Any:
        return self.qubit_map, self.M, self.simulation_options, self.grouping

    def _on_copy(self, target: 'MPSState'):
        """Deep-copies the mutable MPS data into `target`."""
        target.simulation_options = self.simulation_options
        target.grouping = self.grouping
        target.M = [x.copy() for x in self.M]
        # Copy the list so that gates applied to one state do not record
        # errors into the other state's list.
        target.estimated_gate_error_list = list(self.estimated_gate_error_list)

    def state_vector(self) -> np.ndarray:
        """Returns the full state vector.

        Returns:
            A vector that contains the full state.
        """
        tensor_network = qtn.TensorNetwork(self.M)
        state_vector = tensor_network.contract(inplace=False)

        # Here, we rely on the formatting of the indices, and the fact that we have enough
        # leading zeros so that 003 comes before 100.
        sorted_ind = tuple(sorted(state_vector.inds))
        return state_vector.fuse({'i': sorted_ind}).data

    def partial_trace(self, keep_qubits: Set[ops.Qid]) -> np.ndarray:
        """Traces out all qubits except keep_qubits.

        Args:
            keep_qubits: The set of qubits that are left after computing the
                partial trace. For example, if we have a circuit for 3 qubits
                and this parameter only has one qubit, the entire density matrix
                would be 8x8, but this function returns a 2x2 matrix.

        Returns:
            An array that contains the partial trace.
        """

        contracted_inds = {
            self.i_str(i) for qubit, i in self.qubit_map.items() if qubit not in keep_qubits
        }

        conj_pfx = "conj_"

        tensor_network = qtn.TensorNetwork(self.M)

        # Rename the internal indices to avoid collisions. Also rename the qubit
        # indices that are kept. We do not rename the qubit indices that are
        # traced out.
        conj_tensor_network = tensor_network.conj()
        reindex_mapping = {}
        for M in conj_tensor_network.tensors:
            for ind in M.inds:
                if ind not in contracted_inds:
                    reindex_mapping[ind] = conj_pfx + ind
        conj_tensor_network.reindex(reindex_mapping, inplace=True)
        partial_trace = conj_tensor_network @ tensor_network

        # NOTE(review): when several qubits are kept, the axis order follows the
        # iteration order of the `keep_qubits` set.
        forward_inds = [self.i_str(self.qubit_map[keep_qubit]) for keep_qubit in keep_qubits]
        backward_inds = [conj_pfx + forward_ind for forward_ind in forward_inds]
        return partial_trace.to_dense(forward_inds, backward_inds)

    def to_numpy(self) -> np.ndarray:
        """An alias for the state vector."""
        return self.state_vector()

    def apply_op(self, op: 'cirq.Operation', prng: np.random.RandomState):
        """Applies a unitary operation, mutating the object to represent the new state.

        Args:
            op: The operation that mutates the object. Note that currently,
                only 1- and 2- qubit operations are supported.
            prng: A random number generator, used to pick one unitary out of
                the operation's mixture when the operation has no unitary.

        Returns:
            True, once the operation has been applied.

        Raises:
            ValueError: If the operation acts on more than two qubits.
        """

        old_inds = tuple([self.i_str(self.qubit_map[qubit]) for qubit in op.qubits])
        new_inds = tuple(['new_' + old_ind for old_ind in old_inds])

        if protocols.has_unitary(op):
            U = protocols.unitary(op)
        else:
            mixtures = protocols.mixture(op)
            mixture_idx = int(prng.choice(len(mixtures), p=[mixture[0] for mixture in mixtures]))
            U = mixtures[mixture_idx][1]
        U = qtn.Tensor(
            U.reshape([qubit.dimension for qubit in op.qubits] * 2), inds=(new_inds + old_inds)
        )

        # TODO(tonybruguier): Explore using the Quimb's tensor network natively.

        if len(op.qubits) == 1:
            n = self.grouping[op.qubits[0]]

            self.M[n] = (U @ self.M[n]).reindex({new_inds[0]: old_inds[0]})
        elif len(op.qubits) == 2:
            n, p = [self.grouping[qubit] for qubit in op.qubits]

            if n == p:
                self.M[n] = (U @ self.M[n]).reindex(
                    {new_inds[0]: old_inds[0], new_inds[1]: old_inds[1]}
                )
            else:
                # This is the index on which we do the contraction. We need to add it iff it's
                # the first time that we do the joining for that specific pair.
                mu_ind = self.mu_str(n, p)
                if mu_ind not in self.M[n].inds:
                    self.M[n].new_ind(mu_ind)
                if mu_ind not in self.M[p].inds:
                    self.M[p].new_ind(mu_ind)

                T = U @ self.M[n] @ self.M[p]

                left_inds = tuple(set(T.inds) & set(self.M[n].inds)) + (new_inds[0],)
                X, Y = T.split(
                    left_inds,
                    method=self.simulation_options.method,
                    max_bond=self.simulation_options.max_bond,
                    cutoff=self.simulation_options.cutoff,
                    cutoff_mode=self.simulation_options.cutoff_mode,
                    get='tensors',
                    absorb='both',
                    bond_ind=mu_ind,
                )

                # Equations (13), (14), and (15):
                # TODO(tonybruguier): When Quimb 2.0.0 is released, the split()
                # function should have a 'renorm' that, when set to None, will
                # allow to compute e_n exactly as:
                # np.sum(abs((X @ Y).data) ** 2).real / np.sum(abs(T) ** 2).real
                #
                # The renormalization would then have to be done manually.
                #
                # However, for now, e_n are just the estimated value.
                e_n = self.simulation_options.cutoff
                self.estimated_gate_error_list.append(e_n)

                self.M[n] = X.reindex({new_inds[0]: old_inds[0]})
                self.M[p] = Y.reindex({new_inds[1]: old_inds[1]})
        else:
            # NOTE(tonybruguier): There could be a way to handle higher orders. I think this could
            # involve HOSVDs:
            # https://en.wikipedia.org/wiki/Higher-order_singular_value_decomposition
            #
            # TODO(tonybruguier): Evaluate whether it's even useful to implement and learn more
            # about HOSVDs.
            raise ValueError('Can only handle 1 and 2 qubit operations')
        return True

    def _act_on_fallback_(
        self,
        action: Union['cirq.Operation', 'cirq.Gate'],
        qubits: Sequence['cirq.Qid'],
        allow_decompose: bool = True,
    ) -> bool:
        """Delegates the action to self.apply_op.

        Args:
            action: The operation, or a gate to be applied on `qubits`.
            qubits: The qubits a gate `action` acts on.
            allow_decompose: Unused by this implementation.

        Returns:
            The result of `apply_op`.
        """
        if isinstance(action, ops.Gate):
            action = ops.GateOperation(action, qubits)
        return self.apply_op(action, self.prng)

    def estimation_stats(self):
        """Returns some statistics about the memory usage and quality of the approximation.

        Returns:
            A dict with the total number of tensor coefficients (``num_coefs_used``),
            their size in bytes (``memory_bytes``) and the estimated fidelity
            rounded to 3 digits (``estimated_fidelity``).
        """

        num_coefs_used = sum(Mi.data.size for Mi in self.M)
        memory_bytes = sum(Mi.data.nbytes for Mi in self.M)

        # The computation below is done for numerical stability, instead of directly using the
        # formula:
        # estimated_fidelity = \prod_i (1 - estimated_gate_error_list_i)
        estimated_fidelity = 1.0 + np.expm1(
            sum(np.log1p(-x) for x in self.estimated_gate_error_list)
        )
        estimated_fidelity = round(estimated_fidelity, ndigits=3)

        return {
            "num_coefs_used": num_coefs_used,
            "memory_bytes": memory_bytes,
            "estimated_fidelity": estimated_fidelity,
        }

    def perform_measurement(
        self, qubits: Sequence[ops.Qid], prng: np.random.RandomState, collapse_state_vector=True
    ) -> List[int]:
        """Performs a measurement over one or more qubits.

        Args:
            qubits: The sequence of qids to measure, in that order.
            prng: A random number generator, used to simulate measurements.
            collapse_state_vector: A Boolean specifying whether we should mutate
                the state after the measurement.

        Returns:
            The measured value of each qubit, in the order of `qubits`.

        Raises:
            ValueError: If the approximate probabilities deviate from summing
                to 1.0 by more than `sum_prob_atol`.
        """
        results: List[int] = []

        if collapse_state_vector:
            state = self
        else:
            state = self.copy()

        for qubit in qubits:
            # The index *name* is derived from the qubit's position in qubit_map,
            # while the tensor that carries that index is selected by its group.
            i = state.qubit_map[qubit]
            n = state.grouping[qubit]

            # Trace out other qubits
            M = state.partial_trace(keep_qubits={qubit})
            probs = np.diag(M).real
            sum_probs = sum(probs)

            # Because the computation is approximate, the probabilities do not
            # necessarily add up to 1.0, and thus we re-normalize them.
            if abs(sum_probs - 1.0) > self.simulation_options.sum_prob_atol:
                raise ValueError(f'Sum of probabilities exceeds tolerance: {sum_probs}')
            norm_probs = [x / sum_probs for x in probs]

            d = qubit.dimension
            result: int = int(prng.choice(d, p=norm_probs))

            # Projector onto the observed value, scaled so the collapsed state
            # is renormalized.
            collapser = np.zeros((d, d))
            collapser[result][result] = 1.0 / math.sqrt(probs[result])

            old_n = state.i_str(i)
            new_n = 'new_' + old_n

            collapser = qtn.Tensor(collapser, inds=(new_n, old_n))

            state.M[n] = (collapser @ state.M[n]).reindex({new_n: old_n})

            results.append(result)

        return results

    def _perform_measurement(self, qubits: Sequence['cirq.Qid']) -> List[int]:
        """Measures the axes specified by the simulator."""
        return self.perform_measurement(qubits, self.prng)

    def sample(
        self,
        qubits: Sequence[ops.Qid],
        repetitions: int = 1,
        seed: 'cirq.RANDOM_STATE_OR_SEED_LIKE' = None,
    ) -> np.ndarray:
        """Samples the given qubits without collapsing the state.

        Args:
            qubits: The qubits to measure, in that order.
            repetitions: The number of samples to draw.
            seed: The random seed or state used for sampling.

        Returns:
            An integer array with one row of measurement results per repetition.
        """
        # Parse the seed once: re-parsing an integer seed on every repetition
        # would restart the RNG each time and yield identical samples.
        prng = value.parse_random_state(seed)

        measurements: List[List[int]] = [
            self.perform_measurement(qubits, prng, collapse_state_vector=False)
            for _ in range(repetitions)
        ]

        return np.array(measurements, dtype=int)