1from typing import Any, Dict, Iterable, Tuple, Union
2import numpy as np
3
4from cirq import linalg, protocols, value
5from cirq._compat import proper_repr
6from cirq.ops import raw_types
7
8
class MixedUnitaryChannel(raw_types.Gate):
    """A generic mixture that can record the index of its selected operator.

    This type of object is also referred to as a mixed-unitary channel.

    Args:
        mixture: an iterable of (probability, qubit unitary) pairs. It is
            materialized into a list, so a one-shot generator is accepted.
            Every unitary must be a square matrix of the same shape whose
            side length is a power of two.
        key: an optional measurement key (string or `cirq.MeasurementKey`)
            for this mixture. Simulations which select a single unitary to
            apply will store the index of that unitary in the measurement
            result list with this key.
        validate: if True, additionally check that every operator in
            `mixture` is unitary. This validation can be slow; prefer
            pre-validating if possible. The probability sum and operator
            shapes are always checked regardless of this flag.

    Raises:
        ValueError: if `mixture` is empty, its probabilities do not sum to
            (approximately) 1, its first operator is not a square matrix over
            an integer number of qubits, its operators have inconsistent
            shapes, or (with `validate=True`) an operator is non-unitary.
    """

    def __init__(
        self,
        mixture: Iterable[Tuple[float, np.ndarray]],
        key: Union[str, value.MeasurementKey, None] = None,
        validate: bool = False,
    ):
        mixture = list(mixture)
        if not mixture:
            raise ValueError('MixedUnitaryChannel must have at least one unitary.')
        if not protocols.approx_eq(sum(p[0] for p in mixture), 1):
            raise ValueError('Unitary probabilities must sum to 1.')
        m0 = mixture[0][1]
        # Reject non-matrices up front; otherwise `m0.shape[1]` below would
        # surface as an IndexError rather than the documented ValueError.
        if len(m0.shape) != 2:
            raise ValueError(
                f'Input mixture of shape {m0.shape} does not '
                'represent a square operator over qubits.'
            )
        num_qubits = np.log2(m0.shape[0])
        if not num_qubits.is_integer() or m0.shape[1] != m0.shape[0]:
            raise ValueError(
                f'Input mixture of shape {m0.shape} does not '
                'represent a square operator over qubits.'
            )
        self._num_qubits = int(num_qubits)
        for i, op in enumerate(p[1] for p in mixture):
            if not op.shape == m0.shape:
                raise ValueError(
                    f'Inconsistent unitary shapes: op[0]: {m0.shape}, op[{i}]: {op.shape}'
                )
            if validate and not linalg.is_unitary(op):
                raise ValueError(f'Element {i} of mixture is non-unitary.')
        self._mixture = mixture
        # Normalize plain strings to MeasurementKey; None means "no key".
        if not isinstance(key, value.MeasurementKey) and key is not None:
            key = value.MeasurementKey(key)
        self._key = key

    @staticmethod
    def from_mixture(
        mixture: 'protocols.SupportsMixture', key: Union[str, value.MeasurementKey, None] = None
    ):
        """Creates a copy of a mixture with the given measurement key.

        Args:
            mixture: any object supporting the mixture protocol; its
                (probability, unitary) pairs are read via `protocols.mixture`.
            key: optional measurement key for the new channel.

        Returns:
            A new `MixedUnitaryChannel` built from those pairs. Unitarity is
            not re-validated.
        """
        return MixedUnitaryChannel(mixture=list(protocols.mixture(mixture)), key=key)

    def __eq__(self, other) -> bool:
        """Approximate equality: same key, and element-wise close probabilities
        and unitaries (via `np.allclose`)."""
        if not isinstance(other, MixedUnitaryChannel):
            return NotImplemented
        if self._key != other._key:
            return False
        # Compare structure before values: np.allclose broadcasts, so a
        # different element count or operator shape would otherwise raise a
        # broadcasting ValueError instead of simply comparing unequal.
        if len(self._mixture) != len(other._mixture):
            return False
        if any(a[1].shape != b[1].shape for a, b in zip(self._mixture, other._mixture)):
            return False
        if not np.allclose(
            [m[0] for m in self._mixture],
            [m[0] for m in other._mixture],
        ):
            return False
        return bool(
            np.allclose(
                [m[1] for m in self._mixture],
                [m[1] for m in other._mixture],
            )
        )

    def num_qubits(self) -> int:
        """Returns the qubit count derived from the operator dimension."""
        return self._num_qubits

    def _mixture_(self):
        """Returns the stored list of (probability, unitary) pairs."""
        return self._mixture

    def _measurement_key_name_(self) -> str:
        """Returns the key as a string, or NotImplemented if there is no key."""
        if self._key is None:
            return NotImplemented
        return str(self._key)

    def _measurement_key_obj_(self) -> value.MeasurementKey:
        """Returns the MeasurementKey, or NotImplemented if there is no key."""
        if self._key is None:
            return NotImplemented
        return self._key

    def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
        """Returns a copy whose key is renamed per `key_map`.

        Returns `self` when the key is not in `key_map`, and NotImplemented
        when this channel has no key.
        """
        if self._key is None:
            return NotImplemented
        # `key_map` is keyed by string names; use the same string form for
        # both the membership test and the lookup.
        key_name = str(self._key)
        if key_name not in key_map:
            return self
        return MixedUnitaryChannel(mixture=self._mixture, key=key_map[key_name])

    def _with_key_path_(self, path: Tuple[str, ...]):
        """Returns a copy whose key is prefixed with `path`.

        Returns NotImplemented when this channel has no key, matching the
        other key protocol methods; otherwise `with_key_path(None, ...)`
        would produce NotImplemented and be passed on as a key value.
        """
        if self._key is None:
            return NotImplemented
        return MixedUnitaryChannel(
            mixture=self._mixture, key=protocols.with_key_path(self._key, path)
        )

    def __str__(self):
        if self._key is not None:
            return f'MixedUnitaryChannel({self._mixture}, key={self._key})'
        return f'MixedUnitaryChannel({self._mixture})'

    def __repr__(self):
        unitary_tuples = [
            '(' + repr(op[0]) + ', ' + proper_repr(op[1]) + ')' for op in self._mixture
        ]
        args = [f'mixture=[{", ".join(unitary_tuples)}]']
        if self._key is not None:
            args.append(f'key=\'{self._key}\'')
        return f'cirq.MixedUnitaryChannel({", ".join(args)})'

    def _json_dict_(self) -> Dict[str, Any]:
        """Serializes the private `_mixture` and `_key` fields."""
        return protocols.obj_to_dict_helper(self, ['_mixture', '_key'])

    @classmethod
    def _from_json_dict_(cls, _mixture, _key, **kwargs):
        """Rebuilds a channel from `_json_dict_` output.

        Each unitary is converted back to an ndarray; extra keyword arguments
        from the JSON payload are ignored.
        """
        mix_pairs = [(m[0], np.asarray(m[1])) for m in _mixture]
        return cls(mixture=mix_pairs, key=_key)
125