from typing import Any, Dict, Iterable, Tuple, Union

import numpy as np

from cirq import linalg, protocols, value
from cirq._compat import proper_repr
from cirq.ops import raw_types


class MixedUnitaryChannel(raw_types.Gate):
    """A generic mixture that can record the index of its selected operator.

    This type of object is also referred to as a mixed-unitary channel.

    Args:
        mixture: a list of (probability, qubit unitary) pairs
        key: an optional measurement key string for this mixture. Simulations
            which select a single unitary to apply will store the index
            of that unitary in the measurement result list with this key.
        validate: if True, validate that `mixture` describes a valid mixture.
            This validation can be slow; prefer pre-validating if possible.

    Raises:
        ValueError: if `mixture` is empty, its probabilities do not sum to 1,
            its first operator is not a non-empty square 2-D matrix whose
            dimension is a power of two, its operators have inconsistent
            shapes, or (only when `validate` is True) an operator is not
            unitary.
    """

    def __init__(
        self,
        mixture: Iterable[Tuple[float, np.ndarray]],
        key: Union[str, value.MeasurementKey, None] = None,
        validate: bool = False,
    ):
        mixture = list(mixture)
        if not mixture:
            raise ValueError('MixedUnitaryChannel must have at least one unitary.')
        if not protocols.approx_eq(sum(p[0] for p in mixture), 1):
            raise ValueError('Unitary probabilities must sum to 1.')
        m0 = mixture[0][1]
        shape_error = (
            f'Input mixture of shape {m0.shape} does not '
            'represent a square operator over qubits.'
        )
        # Check rank and size before indexing `shape[1]` or taking log2: a
        # 1-D/0-D input would otherwise raise IndexError, and a (0, 0) matrix
        # would hit a divide-by-zero warning in np.log2.
        if m0.ndim != 2 or m0.shape[0] == 0 or m0.shape[0] != m0.shape[1]:
            raise ValueError(shape_error)
        num_qubits = np.log2(m0.shape[0])
        if not num_qubits.is_integer():
            raise ValueError(shape_error)
        self._num_qubits = int(num_qubits)
        for i, op in enumerate(p[1] for p in mixture):
            if op.shape != m0.shape:
                raise ValueError(
                    f'Inconsistent unitary shapes: op[0]: {m0.shape}, op[{i}]: {op.shape}'
                )
            if validate and not linalg.is_unitary(op):
                raise ValueError(f'Element {i} of mixture is non-unitary.')
        self._mixture = mixture
        # Normalize plain-string keys to MeasurementKey; None means "no key".
        if key is not None and not isinstance(key, value.MeasurementKey):
            key = value.MeasurementKey(key)
        self._key = key

    @staticmethod
    def from_mixture(
        mixture: 'protocols.SupportsMixture', key: Union[str, value.MeasurementKey, None] = None
    ):
        """Creates a copy of a mixture with the given measurement key."""
        return MixedUnitaryChannel(mixture=list(protocols.mixture(mixture)), key=key)

    def __eq__(self, other) -> bool:
        """Approximate equality: same key, probabilities and unitaries (np.allclose)."""
        if not isinstance(other, MixedUnitaryChannel):
            return NotImplemented
        if self._key != other._key:
            return False
        # np.allclose broadcasts its inputs, so mixtures of different lengths
        # or different operator shapes must be rejected explicitly; otherwise
        # the comparison raises a broadcasting ValueError. All operators in
        # one channel share a (2**n, 2**n) shape, so the qubit count pins it.
        if len(self._mixture) != len(other._mixture):
            return False
        if self._num_qubits != other._num_qubits:
            return False
        if not np.allclose(
            [m[0] for m in self._mixture],
            [m[0] for m in other._mixture],
        ):
            return False
        return bool(
            np.allclose(
                [m[1] for m in self._mixture],
                [m[1] for m in other._mixture],
            )
        )

    def num_qubits(self) -> int:
        """Returns the number of qubits the mixture's unitaries act on."""
        return self._num_qubits

    def _mixture_(self):
        """Returns the stored list of (probability, unitary) pairs."""
        return self._mixture

    def _measurement_key_name_(self) -> str:
        """Returns the key as a string, or NotImplemented if there is no key."""
        if self._key is None:
            return NotImplemented
        return str(self._key)

    def _measurement_key_obj_(self) -> value.MeasurementKey:
        """Returns the MeasurementKey, or NotImplemented if there is no key."""
        if self._key is None:
            return NotImplemented
        return self._key

    def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
        """Returns a copy whose key is renamed via `key_map` (by string form).

        Returns `self` unchanged when the key is absent from `key_map`, and
        NotImplemented when this channel has no key.
        """
        if self._key is None:
            return NotImplemented
        # Use the string form for both the membership test and the lookup so
        # they are guaranteed to agree.
        key_name = str(self._key)
        if key_name not in key_map:
            return self
        return MixedUnitaryChannel(mixture=self._mixture, key=key_map[key_name])

    def _with_key_path_(self, path: Tuple[str, ...]):
        """Returns a copy whose key is prefixed with `path`.

        Returns NotImplemented when this channel has no key; without this
        guard, `protocols.with_key_path(None, path)` yields NotImplemented,
        which the constructor would then try to wrap in a MeasurementKey.
        """
        if self._key is None:
            return NotImplemented
        return MixedUnitaryChannel(
            mixture=self._mixture, key=protocols.with_key_path(self._key, path)
        )

    def __str__(self):
        if self._key is not None:
            return f'MixedUnitaryChannel({self._mixture}, key={self._key})'
        return f'MixedUnitaryChannel({self._mixture})'

    def __repr__(self):
        unitary_tuples = [
            '(' + repr(op[0]) + ', ' + proper_repr(op[1]) + ')' for op in self._mixture
        ]
        args = [f'mixture=[{", ".join(unitary_tuples)}]']
        if self._key is not None:
            args.append(f'key=\'{self._key}\'')
        return f'cirq.MixedUnitaryChannel({", ".join(args)})'

    def _json_dict_(self) -> Dict[str, Any]:
        """Serializes the `_mixture` and `_key` attributes for cirq JSON."""
        return protocols.obj_to_dict_helper(self, ['_mixture', '_key'])

    @classmethod
    def _from_json_dict_(cls, _mixture, _key, **kwargs):
        """Rebuilds a channel from its JSON dict, re-wrapping unitaries as arrays."""
        mix_pairs = [(m[0], np.asarray(m[1])) for m in _mixture]
        return cls(mixture=mix_pairs, key=_key)