1# Copyright 2020 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numbers
16from typing import (
17    AbstractSet,
18    Tuple,
19    TYPE_CHECKING,
20    Dict,
21    Any,
22    cast,
23    SupportsFloat,
24    Optional,
25    Sequence,
26)
27
28import numpy as np
29
30from cirq import protocols, value
31from cirq.ops import raw_types
32from cirq._compat import proper_repr
33
34if TYPE_CHECKING:
35    import cirq
36
37
@value.value_equality(approximate=True)
class RandomGateChannel(raw_types.Gate):
    """Applies a sub gate with some probability.

    With probability `probability` the wrapped `sub_gate` is applied; otherwise
    the identity is applied. Equality is approximate and compares the pair
    `(sub_gate, probability)`.
    """

    def __init__(self, *, sub_gate: 'cirq.Gate', probability: value.TParamVal):
        """Creates a channel that applies `sub_gate` with `probability`.

        Args:
            sub_gate: The gate to apply probabilistically. If it is itself a
                `RandomGateChannel`, the two channels are flattened into one
                whose probability is the product of both probabilities.
            probability: The chance of applying `sub_gate`. May be symbolic;
                concrete numbers must lie in [0, 1].

        Raises:
            ValueError: If `probability` is a number outside of [0, 1].
        """
        # Only concrete numbers can be range-checked here. Symbolic values are
        # checked when they get resolved, because `_resolve_parameters_`
        # re-enters this constructor with the resolved probability.
        if (
            isinstance(probability, numbers.Number)
            and not 0 <= float(cast(SupportsFloat, probability)) <= 1
        ):
            raise ValueError("not 0 <= probability <= 1")

        self.sub_gate = sub_gate
        self.probability = probability

        # Auto flatten. A nested channel was itself flattened when it was
        # constructed, so unwrapping a single level is enough.
        if isinstance(self.sub_gate, RandomGateChannel):
            self.probability *= self.sub_gate.probability
            self.sub_gate = self.sub_gate.sub_gate

    def _qid_shape_(self) -> Tuple[int, ...]:
        return protocols.qid_shape(self.sub_gate)

    def _value_equality_values_(self):
        return self.sub_gate, self.probability

    def _has_unitary_(self):
        # Sampling between sub_gate and identity is not described by a
        # single unitary in general.
        return False

    def _has_mixture_(self):
        return not self._is_parameterized_() and protocols.has_mixture(self.sub_gate)

    def _has_kraus_(self):
        return not self._is_parameterized_() and protocols.has_kraus(self.sub_gate)

    def _is_parameterized_(self) -> bool:
        return protocols.is_parameterized(self.probability) or protocols.is_parameterized(
            self.sub_gate
        )

    def _parameter_names_(self) -> AbstractSet[str]:
        return protocols.parameter_names(self.probability) | protocols.parameter_names(
            self.sub_gate
        )

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'RandomGateChannel':
        # Goes through the constructor, so a resolved probability is
        # range-checked and nested channels are flattened again.
        return RandomGateChannel(
            sub_gate=protocols.resolve_parameters(self.sub_gate, resolver, recursive),
            probability=protocols.resolve_parameters(self.probability, resolver, recursive),
        )

    def _identity_matrix(self) -> np.ndarray:
        """Returns the identity on the space acted on by the sub gate.

        The dimension is the product of the sub gate's qid shape.
        """
        dim = int(np.prod(protocols.qid_shape(self.sub_gate), dtype=np.int64))
        return np.eye(dim, dtype=np.float64)

    def _mixture_(self):
        """Returns the sub gate's mixture scaled by `probability`, plus identity.

        Returns NotImplemented when the channel is parameterized or the sub
        gate has no mixture.
        """
        if self._is_parameterized_():
            return NotImplemented

        mixture = protocols.mixture(self.sub_gate, None)
        if mixture is None:
            # NotImplemented is the protocol's "not supported" sentinel; this
            # also matches what `_kraus_` returns in the same situation.
            return NotImplemented

        p = float(self.probability)
        result = [(weight * p, unitary) for weight, unitary in mixture]
        result.append((1 - p, self._identity_matrix()))
        return result

    def _kraus_(self):
        """Returns the sub gate's Kraus operators scaled by sqrt(probability).

        An extra sqrt(1 - probability) * identity operator covers the branch
        where the sub gate is not applied. Returns NotImplemented when the
        channel is parameterized or the sub gate has no Kraus representation.
        """
        if self._is_parameterized_():
            return NotImplemented

        channel = protocols.kraus(self.sub_gate, None)
        if channel is None:
            return NotImplemented

        # Convert once, consistently with `_mixture_`.
        p = float(self.probability)
        result = [np.sqrt(p) * op for op in channel]
        result.append(np.sqrt(1 - p) * self._identity_matrix())
        return result

    def _trace_distance_bound_(self) -> float:
        result = protocols.trace_distance_bound(self.sub_gate)
        # With a symbolic probability the sub gate's bound is returned
        # unscaled; presumably still an upper bound since probability <= 1
        # once resolved -- confirm if tighter bounds are needed.
        if not self._is_parameterized_():
            result *= float(self.probability)
        return result

    def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
        from cirq.sim import clifford

        if self._is_parameterized_():
            return NotImplemented
        if isinstance(args, clifford.ActOnCliffordTableauArgs):
            if args.prng.random() < self.probability:
                # Note: because we're doing this probabilistically, it's not
                # safe to fallback to other strategies if act_on fails. Those
                # strategies could double-count the probability.
                protocols.act_on(self.sub_gate, args, qubits)
            return True
        return NotImplemented

    def _json_dict_(self) -> Dict[str, Any]:
        return protocols.obj_to_dict_helper(self, ['sub_gate', 'probability'])

    @classmethod
    def _from_json_dict_(cls, sub_gate, probability, **kwargs):
        return cls(sub_gate=sub_gate, probability=probability)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> Optional['cirq.CircuitDiagramInfo']:
        """Reuses the sub gate's diagram, tagging the first wire with the probability."""
        result = protocols.circuit_diagram_info(self.sub_gate, args, None)
        if result is None:
            return None
        wires = list(result.wire_symbols)
        if wires:
            wires[0] += f'[prob={args.format_real(self.probability)}]'
        return result.with_wire_symbols(wires)

    def __str__(self):
        return f'{self.sub_gate}[prob={self.probability}]'

    def __repr__(self):
        # NOTE(review): the explicit constructor form for probability == 1 is
        # presumably because `with_probability(1)` does not yield a
        # RandomGateChannel -- confirm against `Gate.with_probability`.
        if self.probability == 1:
            return f'cirq.RandomGateChannel(sub_gate={self.sub_gate!r}, probability=1)'
        return f'{self.sub_gate!r}.with_probability({proper_repr(self.probability)})'
165