1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import AbstractSet, Dict, Iterable, Union, TYPE_CHECKING
16
17import sympy
18
19from cirq import value, protocols
20from cirq._compat import proper_repr
21from cirq.ops import (
22    raw_types,
23    common_gates,
24    pauli_string as ps,
25    pauli_gates,
26    op_tree,
27    pauli_string_raw_types,
28)
29
30if TYPE_CHECKING:
31    import cirq
32
33
@value.value_equality(approximate=True)
class PauliStringPhasor(pauli_string_raw_types.PauliStringGateOperation):
    """An operation that phases the eigenstates of a Pauli string.

    The -1 eigenstates of the Pauli string will have their amplitude multiplied
    by e^(i pi exponent_neg) while +1 eigenstates of the Pauli string will have
    their amplitude multiplied by e^(i pi exponent_pos).
    """

    def __init__(
        self,
        pauli_string: ps.PauliString,
        *,
        exponent_neg: Union[int, float, sympy.Basic] = 1,
        exponent_pos: Union[int, float, sympy.Basic] = 0,
    ) -> None:
        """Initializes the operation.

        A Pauli string with coefficient -1 is accepted. It is negated to
        coefficient +1 and the two exponents are swapped, because negating the
        string exchanges its +1 and -1 eigenspaces. Both exponents are stored
        after passing through `value.canonicalize_half_turns`.

        Args:
            pauli_string: The PauliString defining the positive and negative
                eigenspaces that will be independently phased.
            exponent_neg: How much to phase vectors in the negative eigenspace,
                in the form of the t in (-1)**t = exp(i pi t).
            exponent_pos: How much to phase vectors in the positive eigenspace,
                in the form of the t in (-1)**t = exp(i pi t).

        Raises:
            ValueError: If `pauli_string.coefficient` is neither 1 nor -1, so
                the string does not have +1 and -1 eigenvalues.
        """
        if pauli_string.coefficient == -1:
            # -P has the same eigenspaces as P with the +1/-1 labels exchanged.
            pauli_string = -pauli_string
            exponent_pos, exponent_neg = exponent_neg, exponent_pos

        if pauli_string.coefficient != 1:
            raise ValueError(
                "Given PauliString doesn't have +1 and -1 eigenvalues. "
                "pauli_string.coefficient must be 1 or -1."
            )

        super().__init__(pauli_string)
        self.exponent_neg = value.canonicalize_half_turns(exponent_neg)
        self.exponent_pos = value.canonicalize_half_turns(exponent_pos)

    @property
    def exponent_relative(self) -> Union[int, float, sympy.Basic]:
        """The -1 eigenspace exponent minus the +1 eigenspace exponent, canonicalized."""
        return value.canonicalize_half_turns(self.exponent_neg - self.exponent_pos)

    def _value_equality_values_(self):
        """Values compared (approximately, per the class decorator) for equality."""
        return (
            self.pauli_string,
            self.exponent_neg,
            self.exponent_pos,
        )

    def equal_up_to_global_phase(self, other):
        """Returns True if `other` differs from this operation only by a global phase.

        `other` must be a PauliStringPhasor on an equal Pauli string whose
        relative exponent compares equal (with exact `==`) to this one's.
        """
        if isinstance(other, PauliStringPhasor):
            rel1 = self.exponent_relative
            rel2 = other.exponent_relative
            return rel1 == rel2 and self.pauli_string == other.pauli_string
        return False

    def map_qubits(self, qubit_map: Dict[raw_types.Qid, raw_types.Qid]):
        """Returns an equivalent phasor whose Pauli string acts on remapped qubits."""
        return PauliStringPhasor(
            self.pauli_string.map_qubits(qubit_map),
            exponent_neg=self.exponent_neg,
            exponent_pos=self.exponent_pos,
        )

    def __pow__(self, exponent: Union[float, sympy.Symbol]) -> 'PauliStringPhasor':
        """Scales both eigenspace exponents by `exponent`.

        Returns NotImplemented if either product cannot be formed by
        `protocols.mul`.
        """
        pn = protocols.mul(self.exponent_neg, exponent, None)
        pp = protocols.mul(self.exponent_pos, exponent, None)
        if pn is None or pp is None:
            return NotImplemented
        return PauliStringPhasor(self.pauli_string, exponent_neg=pn, exponent_pos=pp)

    def can_merge_with(self, op: 'PauliStringPhasor') -> bool:
        """Whether `op` acts on the same Pauli string, ignoring coefficients."""
        return self.pauli_string.equal_up_to_coefficient(op.pauli_string)

    def merged_with(self, op: 'PauliStringPhasor') -> 'PauliStringPhasor':
        """Combines this phasor with `op` by adding the matching exponents.

        Raises:
            ValueError: If `can_merge_with(op)` is False.
        """
        if not self.can_merge_with(op):
            raise ValueError(f'Cannot merge operations: {self}, {op}')
        pp = self.exponent_pos + op.exponent_pos
        pn = self.exponent_neg + op.exponent_neg
        return PauliStringPhasor(self.pauli_string, exponent_pos=pp, exponent_neg=pn)

    def _has_unitary_(self):
        """A unitary is available only when no exponent is parameterized."""
        return not self._is_parameterized_()

    def _decompose_(self) -> 'cirq.OP_TREE':
        """Decomposes into basis changes, a CNOT parity ladder and Z phases.

        The Pauli string is rotated into the Z basis. The parity of its qubits
        is accumulated onto one of them with CNOTs, and that qubit is phased.
        Finally the CNOTs and the basis change are undone.
        """
        # NOTE(review): an empty Pauli string decomposes to nothing, which drops
        # the global phase e^(i pi exponent_pos) the operation applies in that
        # case -- confirm this is intended.
        if len(self.pauli_string) <= 0:
            return
        qubits = self.qubits
        any_qubit = qubits[0]
        # Materialized so the same operations can be yielded and later inverted.
        to_z_ops = op_tree.freeze_op_tree(self.pauli_string.to_z_basis_ops())
        xor_decomp = tuple(xor_nonlocal_decompose(qubits, any_qubit))
        yield to_z_ops
        yield xor_decomp

        # any_qubit now holds the parity: |1> marks the -1 eigenspace and |0>
        # the +1 eigenspace of the Z-basis Pauli string.
        if self.exponent_neg:
            yield pauli_gates.Z(any_qubit) ** self.exponent_neg
        if self.exponent_pos:
            # Conjugating by X moves the Z phase from |1> onto |0>.
            yield pauli_gates.X(any_qubit)
            yield pauli_gates.Z(any_qubit) ** self.exponent_pos
            yield pauli_gates.X(any_qubit)

        yield protocols.inverse(xor_decomp)
        yield protocols.inverse(to_z_ops)

    def _circuit_diagram_info_(
        self, args: 'cirq.CircuitDiagramInfoArgs'
    ) -> 'cirq.CircuitDiagramInfo':
        """Draws the Pauli string annotated with the relative exponent."""
        return self._pauli_string_diagram_info(args, exponent=self.exponent_relative)

    def _trace_distance_bound_(self) -> float:
        """Bound taken from Z raised to the relative exponent; 0 with no qubits."""
        if len(self.qubits) == 0:
            return 0.0
        return protocols.trace_distance_bound(pauli_gates.Z ** self.exponent_relative)

    def _is_parameterized_(self) -> bool:
        """True if either eigenspace exponent is parameterized."""
        return protocols.is_parameterized(self.exponent_neg) or protocols.is_parameterized(
            self.exponent_pos
        )

    def _parameter_names_(self) -> AbstractSet[str]:
        """Union of the parameter names used by the two exponents."""
        return protocols.parameter_names(self.exponent_neg) | protocols.parameter_names(
            self.exponent_pos
        )

    def _resolve_parameters_(
        self, resolver: 'cirq.ParamResolver', recursive: bool
    ) -> 'PauliStringPhasor':
        """Returns a copy with both exponents resolved by `resolver`."""
        return PauliStringPhasor(
            self.pauli_string,
            exponent_neg=resolver.value_of(self.exponent_neg, recursive),
            exponent_pos=resolver.value_of(self.exponent_pos, recursive),
        )

    def pass_operations_over(
        self, ops: Iterable[raw_types.Operation], after_to_before: bool = False
    ) -> 'PauliStringPhasor':
        """Returns the phasor obtained by moving the Pauli string across `ops`.

        The exponents are kept. A resulting -1 coefficient is normalized by the
        constructor, which swaps the exponents.
        """
        new_pauli_string = self.pauli_string.pass_operations_over(ops, after_to_before)
        pp = self.exponent_pos
        pn = self.exponent_neg
        return PauliStringPhasor(new_pauli_string, exponent_pos=pp, exponent_neg=pn)

    def __repr__(self) -> str:
        return (
            f'cirq.PauliStringPhasor({self.pauli_string!r}, '
            f'exponent_neg={proper_repr(self.exponent_neg)}, '
            f'exponent_pos={proper_repr(self.exponent_pos)})'
        )

    def __str__(self) -> str:
        """Human-readable form.

        When the two exponents are opposite, the operation is exactly
        exp(i pi exponent_pos * P) and is printed that way. Otherwise it is
        printed as the Pauli string raised to the relative exponent.
        """
        if self.exponent_pos == -self.exponent_neg:
            if isinstance(self.exponent_pos, sympy.Basic):
                # A symbolic exponent's sign is generally undecidable:
                # `symbol < 0` is a Relational whose bool() raises TypeError.
                # Print the expression unchanged instead of splitting off a sign.
                sign = ''
                exponent = str(self.exponent_pos)
            else:
                sign = '-' if self.exponent_pos < 0 else ''
                exponent = str(abs(self.exponent_pos))
            return f'exp({sign}iπ{exponent}*{self.pauli_string})'
        return f'({self.pauli_string})**{self.exponent_relative}'
192
193
def xor_nonlocal_decompose(
    qubits: Iterable[raw_types.Qid], onto_qubit: 'cirq.Qid'
) -> Iterable[raw_types.Operation]:
    """Yields CNOTs that XOR each qubit other than `onto_qubit` onto it.

    Every qubit in `qubits` that differs from `onto_qubit` becomes the control
    of one CNOT targeting `onto_qubit`, in iteration order. Decomposition
    ignores connectivity: the CNOTs may act on qubits that are not adjacent.

    Args:
        qubits: The qubits whose values are folded into `onto_qubit`. Entries
            equal to `onto_qubit` are skipped.
        onto_qubit: The target of every emitted CNOT.

    Yields:
        `common_gates.CNOT(control, onto_qubit)` for each other qubit.
    """
    yield from (
        common_gates.CNOT(control, onto_qubit) for control in qubits if control != onto_qubit
    )
201