1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import AbstractSet, Dict, Iterable, Union, TYPE_CHECKING
16
17import sympy
18
19from cirq import value, protocols
20from cirq._compat import proper_repr
21from cirq.ops import (
22    raw_types,
23    common_gates,
24    pauli_string as ps,
25    pauli_gates,
26    op_tree,
27    pauli_string_raw_types,
28)
29
30if TYPE_CHECKING:
31    import cirq
32
33
@value.value_equality(approximate=True)
class PauliStringPhasor(pauli_string_raw_types.PauliStringGateOperation):
    """An operation that phases the eigenstates of a Pauli string.

    The -1 eigenstates of the Pauli string will have their amplitude multiplied
    by e^(i pi exponent_neg) while +1 eigenstates of the Pauli string will have
    their amplitude multiplied by e^(i pi exponent_pos).

    The stored `pauli_string` always has coefficient +1: a -1 coefficient is
    absorbed by swapping the two exponents. Both exponents are canonicalized
    with `value.canonicalize_half_turns`.
    """

    def __init__(
        self,
        pauli_string: ps.PauliString,
        *,
        exponent_neg: Union[int, float, sympy.Basic] = 1,
        exponent_pos: Union[int, float, sympy.Basic] = 0,
    ) -> None:
        """Initializes the operation.

        Args:
            pauli_string: The PauliString defining the positive and negative
                eigenspaces that will be independently phased.
            exponent_neg: How much to phase vectors in the negative eigenspace,
                in the form of the t in (-1)**t = exp(i pi t).
            exponent_pos: How much to phase vectors in the positive eigenspace,
                in the form of the t in (-1)**t = exp(i pi t).

        Raises:
            ValueError: If `pauli_string.coefficient` is neither 1 nor -1, i.e.
                the string does not have +1 and -1 eigenvalues.
        """
        if pauli_string.coefficient == -1:
            # Negating the string swaps its +1 and -1 eigenspaces, so the
            # exponents are swapped too to describe the same operation.
            pauli_string = -pauli_string
            exponent_pos, exponent_neg = exponent_neg, exponent_pos

        if pauli_string.coefficient != 1:
            raise ValueError(
                "Given PauliString doesn't have +1 and -1 eigenvalues. "
                "pauli_string.coefficient must be 1 or -1."
            )

        super().__init__(pauli_string)
        self.exponent_neg = value.canonicalize_half_turns(exponent_neg)
        self.exponent_pos = value.canonicalize_half_turns(exponent_pos)

    @property
    def exponent_relative(self) -> Union[int, float, sympy.Basic]:
        """The canonicalized difference `exponent_neg - exponent_pos`.

        This is the phase of the -1 eigenspace relative to the +1 eigenspace,
        i.e. what remains after discarding the global phase.
        """
        return value.canonicalize_half_turns(self.exponent_neg - self.exponent_pos)

    def _value_equality_values_(self):
        return (
            self.pauli_string,
            self.exponent_neg,
            self.exponent_pos,
        )

    def equal_up_to_global_phase(self, other) -> bool:
        """Returns True if `other` is the same phasor up to a global phase.

        Two phasors match when they share the same Pauli string and the same
        `exponent_relative`. Non-PauliStringPhasor values never match.
        """
        if isinstance(other, PauliStringPhasor):
            rel1 = self.exponent_relative
            rel2 = other.exponent_relative
            return rel1 == rel2 and self.pauli_string == other.pauli_string
        return False

    def map_qubits(self, qubit_map: Dict[raw_types.Qid, raw_types.Qid]) -> 'PauliStringPhasor':
        """Returns a copy acting on qubits relabeled through `qubit_map`."""
        return PauliStringPhasor(
            self.pauli_string.map_qubits(qubit_map),
            exponent_neg=self.exponent_neg,
            exponent_pos=self.exponent_pos,
        )

    def __pow__(self, exponent: Union[float, sympy.Symbol]) -> 'PauliStringPhasor':
        """Scales both exponents by `exponent`.

        Returns NotImplemented if either product cannot be formed via
        `protocols.mul`.
        """
        pn = protocols.mul(self.exponent_neg, exponent, None)
        pp = protocols.mul(self.exponent_pos, exponent, None)
        if pn is None or pp is None:
            return NotImplemented
        return PauliStringPhasor(self.pauli_string, exponent_neg=pn, exponent_pos=pp)

    def can_merge_with(self, op: 'PauliStringPhasor') -> bool:
        """Returns whether `op`'s Pauli string equals ours up to coefficient."""
        return self.pauli_string.equal_up_to_coefficient(op.pauli_string)

    def merged_with(self, op: 'PauliStringPhasor') -> 'PauliStringPhasor':
        """Combines two phasors on the same Pauli string by adding exponents.

        Raises:
            ValueError: If `can_merge_with(op)` is False.
        """
        if not self.can_merge_with(op):
            raise ValueError(f'Cannot merge operations: {self}, {op}')
        pp = self.exponent_pos + op.exponent_pos
        pn = self.exponent_neg + op.exponent_neg
        return PauliStringPhasor(self.pauli_string, exponent_pos=pp, exponent_neg=pn)

    def _has_unitary_(self):
        return not self._is_parameterized_()

    def _decompose_(self) -> 'cirq.OP_TREE':
        """Decomposes into basis changes, CNOT parity computation and Z phases.

        The Pauli string is rotated into the Z basis, the parity of all its
        qubits is XORed onto one of them, that qubit is phased, and then the
        parity computation and the basis change are undone.
        """
        if len(self.pauli_string) <= 0:
            # NOTE(review): an empty Pauli string is the identity, whose only
            # eigenvalue is +1, so this operation should equal the global phase
            # e^(i pi exponent_pos). Yielding nothing drops that phase — confirm
            # this is intended.
            return
        qubits = self.qubits
        any_qubit = qubits[0]
        to_z_ops = op_tree.freeze_op_tree(self.pauli_string.to_z_basis_ops())
        xor_decomp = tuple(xor_nonlocal_decompose(qubits, any_qubit))
        yield to_z_ops
        yield xor_decomp

        # any_qubit now holds the parity: |1> marks the -1 eigenspace and |0>
        # the +1 eigenspace of the original Pauli string.
        if self.exponent_neg:
            yield pauli_gates.Z(any_qubit) ** self.exponent_neg
        if self.exponent_pos:
            # Conjugating by X makes the Z phase act on |0> instead of |1>.
            yield pauli_gates.X(any_qubit)
            yield pauli_gates.Z(any_qubit) ** self.exponent_pos
            yield pauli_gates.X(any_qubit)

        yield protocols.inverse(xor_decomp)
        yield protocols.inverse(to_z_ops)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> 'cirq.CircuitDiagramInfo':
        return self._pauli_string_diagram_info(args, exponent=self.exponent_relative)

    def _trace_distance_bound_(self) -> float:
        # With no qubits the operation is at most a global phase.
        if len(self.qubits) == 0:
            return 0.0
        return protocols.trace_distance_bound(pauli_gates.Z ** self.exponent_relative)

    def _is_parameterized_(self) -> bool:
        return protocols.is_parameterized(self.exponent_neg) or protocols.is_parameterized(
            self.exponent_pos
        )

    def _parameter_names_(self) -> AbstractSet[str]:
        return protocols.parameter_names(self.exponent_neg) | protocols.parameter_names(
            self.exponent_pos
        )

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'PauliStringPhasor':
        return PauliStringPhasor(
            self.pauli_string,
            exponent_neg=resolver.value_of(self.exponent_neg, recursive),
            exponent_pos=resolver.value_of(self.exponent_pos, recursive),
        )

    def pass_operations_over(
        self, ops: Iterable[raw_types.Operation], after_to_before: bool = False
    ) -> 'PauliStringPhasor':
        """Returns a phasor whose Pauli string has been moved past `ops`.

        Delegates to `PauliString.pass_operations_over` (which interprets
        `after_to_before`) and keeps the exponents. If the moved string picks
        up a -1 coefficient, the constructor swaps the exponents accordingly.
        """
        new_pauli_string = self.pauli_string.pass_operations_over(ops, after_to_before)
        pp = self.exponent_pos
        pn = self.exponent_neg
        return PauliStringPhasor(new_pauli_string, exponent_pos=pp, exponent_neg=pn)

    def __repr__(self) -> str:
        return (
            f'cirq.PauliStringPhasor({self.pauli_string!r}, '
            f'exponent_neg={proper_repr(self.exponent_neg)}, '
            f'exponent_pos={proper_repr(self.exponent_pos)})'
        )

    def __str__(self) -> str:
        """Returns a human-readable form of the operation.

        When exponent_pos == -exponent_neg the operation equals
        exp(i pi exponent_pos * P) and is printed that way. Otherwise it is
        shown as the Pauli string raised to `exponent_relative`.
        """
        if self.exponent_pos == -self.exponent_neg:
            if isinstance(self.exponent_pos, sympy.Basic) and self.exponent_pos.free_symbols:
                # A symbolic exponent has no decidable sign in general: `t < 0`
                # yields a sympy Relational whose truth value raises TypeError.
                # Print the exponent verbatim (parenthesized) instead.
                return f'exp(iπ({self.exponent_pos})*{self.pauli_string})'
            sign = '-' if self.exponent_pos < 0 else ''
            exponent = str(abs(self.exponent_pos))
            return f'exp({sign}iπ{exponent}*{self.pauli_string})'
        return f'({self.pauli_string})**{self.exponent_relative}'
189
190
def xor_nonlocal_decompose(
    qubits: Iterable[raw_types.Qid], onto_qubit: 'cirq.Qid'
) -> Iterable[raw_types.Operation]:
    """Yields CNOTs that XOR every qubit in `qubits` onto `onto_qubit`.

    Decomposition ignores connectivity: no check is made that a control is
    adjacent to `onto_qubit` on any device.

    Args:
        qubits: Control qubits, visited in iteration order. Any entry equal
            to `onto_qubit` is skipped.
        onto_qubit: The target of every emitted CNOT.

    Yields:
        `CNOT(control, onto_qubit)` for each control that differs from
        `onto_qubit`.
    """
    controls = (candidate for candidate in qubits if candidate != onto_qubit)
    yield from (common_gates.CNOT(control, onto_qubit) for control in controls)
198