1import numbers
2from typing import AbstractSet, Any, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
3
4import numpy as np
5import sympy
6
7from cirq import value, ops, protocols, linalg
8from cirq.ops import gate_features
9from cirq._compat import proper_repr
10
11if TYPE_CHECKING:
12    import cirq
13
14
@value.value_equality(approximate=True)
class PhasedXZGate(gate_features.SingleQubitGate):
    """A single qubit operation expressed as $Z^z Z^a X^x Z^{-a}$.

    The above expression is a matrix multiplication with time going to the left.
    In quantum circuit notation, this operation decomposes into this circuit:

    ───Z^(-a)──X^x──Z^a────Z^z───

    The axis phase exponent (a) decides which axis in the XY plane to rotate
    around. The amount of rotation around that axis is decided by the x
    exponent (x). Then the z exponent (z) decides how much to phase the qubit.

    Equality is approximate and is computed on a canonicalized form of the
    three exponents (each compared modulo 2), so gates constructed with
    different exponents can compare equal.
    """

    def __init__(
        self,
        *,
        x_exponent: Union[numbers.Real, sympy.Basic],
        z_exponent: Union[numbers.Real, sympy.Basic],
        axis_phase_exponent: Union[numbers.Real, sympy.Basic],
    ) -> None:
        """Inits PhasedXZGate.

        Args:
            x_exponent: Determines how much to rotate during the
                axis-in-XY-plane rotation. The $x$ in $Z^z Z^a X^x Z^{-a}$.
            z_exponent: The amount of phasing to apply after the
                axis-in-XY-plane rotation. The $z$ in $Z^z Z^a X^x Z^{-a}$.
            axis_phase_exponent: Determines which axis to rotate around during
                the axis-in-XY-plane rotation. The $a$ in $Z^z Z^a X^x Z^{-a}$.
        """
        self._x_exponent = x_exponent
        self._z_exponent = z_exponent
        self._axis_phase_exponent = axis_phase_exponent

    def _canonical(self) -> 'cirq.PhasedXZGate':
        """Returns an equivalent gate whose numeric exponents are in canonical ranges.

        Numeric (`numbers.Real`) exponents are wrapped so that x and z lie in
        (-1, +1] and the axis phase exponent lies in (-0.5, +0.5]. Symbolic
        exponents are not wrapped. The rewrites below preserve the gate only
        up to a global phase, so the result may differ from `self` by one.
        """
        x = self.x_exponent
        z = self.z_exponent
        a = self.axis_phase_exponent

        # Canonicalize X exponent into (-1, +1].
        if isinstance(x, numbers.Real):
            x %= 2
            if x > 1:
                x -= 2

        # Axis phase exponent is irrelevant if there is no X exponent.
        if x == 0:
            a = 0
        # For 180 degree X rotations, the axis phase and z exponent overlap:
        # Z^z Z^a X Z^-a equals Z^(a+z/2) X Z^-(a+z/2) up to global phase, so
        # the z exponent can be folded into the axis phase.
        if x == 1 and z != 0:
            a += z / 2
            z = 0

        # Canonicalize Z exponent into (-1, +1].
        if isinstance(z, numbers.Real):
            z %= 2
            if z > 1:
                z -= 2

        # Canonicalize axis phase exponent into (-0.5, +0.5].
        # Shifting a by +-1 conjugates X^x by Z, which negates the rotation
        # (up to global phase); hence x is flipped. For x == 1 the flip is a
        # no-op (X^-1 == X), and it would leave the canonical x range.
        if isinstance(a, numbers.Real):
            a %= 2
            if a > 1:
                a -= 2
            if a <= -0.5:
                a += 1
                if x != 1:
                    x = -x
            elif a > 0.5:
                a -= 1
                if x != 1:
                    x = -x

        return PhasedXZGate(x_exponent=x, z_exponent=z, axis_phase_exponent=a)

    @property
    def x_exponent(self) -> Union[numbers.Real, sympy.Basic]:
        """The $x$ in $Z^z Z^a X^x Z^{-a}$ (amount of XY-plane rotation)."""
        return self._x_exponent

    @property
    def z_exponent(self) -> Union[numbers.Real, sympy.Basic]:
        """The $z$ in $Z^z Z^a X^x Z^{-a}$ (phasing applied last)."""
        return self._z_exponent

    @property
    def axis_phase_exponent(self) -> Union[numbers.Real, sympy.Basic]:
        """The $a$ in $Z^z Z^a X^x Z^{-a}$ (selects the XY-plane rotation axis)."""
        return self._axis_phase_exponent

    def _value_equality_values_(self):
        """Values compared by `value.value_equality`: canonical exponents, each mod 2."""
        c = self._canonical()
        return (
            value.PeriodicValue(c._x_exponent, 2),
            value.PeriodicValue(c._z_exponent, 2),
            value.PeriodicValue(c._axis_phase_exponent, 2),
        )

    @staticmethod
    def from_matrix(mat: np.ndarray) -> 'cirq.PhasedXZGate':
        """Builds a canonical PhasedXZGate from a single-qubit matrix.

        Args:
            mat: The matrix to convert; presumably a 2x2 unitary, since it is
                passed directly to
                `linalg.deconstruct_single_qubit_matrix_into_angles`.

        Returns:
            The `_canonical()` form of the reconstructed gate.
        """
        pre_phase, rotation, post_phase = linalg.deconstruct_single_qubit_matrix_into_angles(mat)
        # Convert the helper's angles (divided by pi here) into this gate's
        # exponent units.
        pre_phase /= np.pi
        post_phase /= np.pi
        rotation /= np.pi
        # Quarter-turn shifts around the middle rotation. Presumably these
        # re-express the helper's middle rotation axis as X. TODO confirm
        # against deconstruct_single_qubit_matrix_into_angles.
        pre_phase -= 0.5
        post_phase += 0.5
        return PhasedXZGate(
            x_exponent=rotation, axis_phase_exponent=-pre_phase, z_exponent=post_phase + pre_phase
        )._canonical()

    def with_z_exponent(self, z_exponent: Union[numbers.Real, sympy.Basic]) -> 'cirq.PhasedXZGate':
        """Returns a copy of this gate with `z_exponent` replaced."""
        return PhasedXZGate(
            axis_phase_exponent=self._axis_phase_exponent,
            x_exponent=self._x_exponent,
            z_exponent=z_exponent,
        )

    def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optional[str]:
        """Emits the gate as a QASM U gate (`QasmUGate(lmda, theta, phi)`)."""
        from cirq.circuits import qasm_output

        qasm_gate = qasm_output.QasmUGate(
            lmda=0.5 - self._axis_phase_exponent,
            theta=self._x_exponent,
            phi=self._z_exponent + self._axis_phase_exponent - 0.5,
        )
        return protocols.qasm(qasm_gate, args=args, qubits=qubits)

    def _has_unitary_(self) -> bool:
        """A unitary exists only when no exponent is symbolic."""
        return not self._is_parameterized_()

    def _decompose_(self, qubits: Sequence['cirq.Qid']) -> 'cirq.OP_TREE':
        """Yields Z^-a, X^x, Z^(a+z) on the single qubit, in time order.

        The trailing Z^a and Z^z of the class formula are merged into one
        Z^(a+z) operation.
        """
        q = qubits[0]
        yield ops.Z(q) ** -self._axis_phase_exponent
        yield ops.X(q) ** self._x_exponent
        yield ops.Z(q) ** (self._axis_phase_exponent + self._z_exponent)

    def __pow__(self, exponent: Union[float, int]) -> 'cirq.PhasedXZGate':
        """Supports only exponents 1 (identity) and -1 (inverse).

        The inverse of Z^z Z^a X^x Z^-a is Z^a X^-x Z^-(a+z), which in this
        gate's form has x -> -x, z -> -z and a -> a + z.

        Returns:
            The resulting gate, or `NotImplemented` for any other exponent.
        """
        if exponent == 1:
            return self
        if exponent == -1:
            return PhasedXZGate(
                x_exponent=-self._x_exponent,
                z_exponent=-self._z_exponent,
                axis_phase_exponent=self._z_exponent + self.axis_phase_exponent,
            )
        return NotImplemented

    def _is_parameterized_(self) -> bool:
        """See `cirq.SupportsParameterization`."""
        return (
            protocols.is_parameterized(self._x_exponent)
            or protocols.is_parameterized(self._z_exponent)
            or protocols.is_parameterized(self._axis_phase_exponent)
        )

    def _parameter_names_(self) -> AbstractSet[str]:
        """See `cirq.SupportsParameterization`."""
        return (
            protocols.parameter_names(self._x_exponent)
            | protocols.parameter_names(self._z_exponent)
            | protocols.parameter_names(self._axis_phase_exponent)
        )

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'cirq.PhasedXZGate':
        """See `cirq.SupportsParameterization`."""
        return PhasedXZGate(
            z_exponent=resolver.value_of(self._z_exponent, recursive),
            x_exponent=resolver.value_of(self._x_exponent, recursive),
            axis_phase_exponent=resolver.value_of(self._axis_phase_exponent, recursive),
        )

    def _phase_by_(self, phase_turns, qubit_index) -> 'cirq.PhasedXZGate':
        """See `cirq.SupportsPhase`.

        Phasing only shifts the axis phase exponent, by `2 * phase_turns`.

        Raises:
            ValueError: If `qubit_index` is not 0 (this is a single-qubit gate).
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if qubit_index != 0:
            raise ValueError(
                f'PhasedXZGate acts on a single qubit; qubit_index must be 0, '
                f'got {qubit_index!r}.'
            )
        return PhasedXZGate(
            x_exponent=self._x_exponent,
            z_exponent=self._z_exponent,
            axis_phase_exponent=self._axis_phase_exponent + phase_turns * 2,
        )

    def _pauli_expansion_(self) -> 'cirq.LinearDict[str]':
        """Returns the gate's coefficients in the Pauli basis {I, X, Y, Z}.

        Returns `NotImplemented` when any exponent is symbolic.
        """
        if protocols.is_parameterized(self):
            return NotImplemented
        x_angle = np.pi * self._x_exponent / 2
        z_angle = np.pi * self._z_exponent / 2
        axis_angle = np.pi * self._axis_phase_exponent
        # Global phase contributed by the X^x and Z^z factors.
        phase = np.exp(1j * (x_angle + z_angle))

        cx = np.cos(x_angle)
        sx = np.sin(x_angle)
        return value.LinearDict(
            {
                'I': phase * cx * np.cos(z_angle),
                'X': -1j * phase * sx * np.cos(z_angle + axis_angle),
                'Y': -1j * phase * sx * np.sin(z_angle + axis_angle),
                'Z': -1j * phase * cx * np.sin(z_angle),
            }
        )  # yapf: disable

    def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> str:
        """See `cirq.SupportsCircuitDiagramInfo`."""
        return (
            f'PhXZ('
            f'a={args.format_real(self._axis_phase_exponent)},'
            f'x={args.format_real(self._x_exponent)},'
            f'z={args.format_real(self._z_exponent)})'
        )

    def __str__(self) -> str:
        """Returns the circuit-diagram label, e.g. `PhXZ(a=...,x=...,z=...)`."""
        return protocols.circuit_diagram_info(self).wire_symbols[0]

    def __repr__(self) -> str:
        return (
            f'cirq.PhasedXZGate('
            f'axis_phase_exponent={proper_repr(self._axis_phase_exponent)},'
            f' x_exponent={proper_repr(self._x_exponent)}, '
            f'z_exponent={proper_repr(self._z_exponent)})'
        )

    def _json_dict_(self) -> Dict[str, Any]:
        """Serializes the three exponents via `protocols.obj_to_dict_helper`."""
        return protocols.obj_to_dict_helper(
            self, ['axis_phase_exponent', 'x_exponent', 'z_exponent']
        )
238