1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Any, Iterable, Iterator, List, Optional, Sequence, Tuple, Union
16
17import numpy as np
18
19import cirq
20from cirq import ops, linalg, value
21from cirq.interop.quirk.cells.cell import Cell, CellMaker
22
23
@value.value_equality
class InputRotationCell(Cell):
    """Applies an operation that depends on an input gate.

    The cell is created without an input register (``register is None``) and
    acquires one when ``with_input`` is called with letter ``'a'``. Once the
    register is known, ``operations`` produces a
    ``QuirkInputRotationOperation`` combining the register, the base
    operation and the exponent sign.
    """

    def __init__(
        self,
        identifier: str,
        register: Optional[Sequence['cirq.Qid']],
        base_operation: 'cirq.Operation',
        exponent_sign: int,
    ):
        """Initializes the cell.

        Args:
            identifier: Quirk gate identifier, e.g. ``"X^(A/2^n)"``.
            register: Qubits supplying input A, or None while that input is
                still unknown. Stored as a tuple.
            base_operation: The operation whose exponent depends on input A.
            exponent_sign: Sign applied to the input-dependent exponent.
                Not validated here; ``QuirkInputRotationOperation`` rejects
                values other than -1 and +1.
        """
        self.identifier = identifier
        self.register = None if register is None else tuple(register)
        self.base_operation = base_operation
        self.exponent_sign = exponent_sign

    def _value_equality_values_(self) -> Any:
        return (
            self.identifier,
            self.register,
            self.base_operation,
            self.exponent_sign,
        )

    def __repr__(self) -> str:
        return (
            f'cirq.interop.quirk.cells.input_rotation_cells.InputRotationCell('
            f'\n    {self.identifier!r},'
            f'\n    {self.register!r},'
            f'\n    {self.base_operation!r},'
            f'\n    {self.exponent_sign!r})'
        )

    def gate_count(self) -> int:
        """This cell always contributes exactly one gate."""
        return 1

    def with_line_qubits_mapped_to(self, qubits: List['cirq.Qid']) -> 'Cell':
        """Returns a copy whose register and base-operation qubits are remapped.

        Both the (optional) input register and the base operation's qubits are
        translated through ``Cell._replace_qubits`` using ``qubits``.
        """
        return InputRotationCell(
            self.identifier,
            None if self.register is None else Cell._replace_qubits(self.register, qubits),
            self.base_operation.with_qubits(
                *Cell._replace_qubits(self.base_operation.qubits, qubits)
            ),
            exponent_sign=self.exponent_sign,
        )

    def with_input(self, letter: str, register: Union[Sequence['cirq.Qid'], int]) -> 'Cell':
        """Binds input A to this cell if it has not been bound yet.

        Args:
            letter: The input letter being offered.
            register: The qubits of that input, or an int for a classical
                constant input.

        Returns:
            A new cell using ``register`` when ``letter == 'a'`` and no
            register is set yet; otherwise ``self`` unchanged.

        Raises:
            ValueError: If input A is offered as a classical constant (int),
                since the operation needs a qubit register of known length.
        """
        # Parameterized rotations use input A as their parameter.
        if self.register is None and letter == 'a':
            if isinstance(register, int):
                raise ValueError(
                    'Dependent operation requires known length '
                    'input; classical constant not allowed.'
                )
            return InputRotationCell(
                self.identifier, register, self.base_operation, self.exponent_sign
            )
        return self

    def controlled_by(self, qubit: 'cirq.Qid') -> 'Cell':
        """Returns a copy whose base operation is additionally controlled by ``qubit``."""
        return InputRotationCell(
            self.identifier,
            self.register,
            self.base_operation.controlled_by(qubit),
            self.exponent_sign,
        )

    def operations(self) -> 'cirq.OP_TREE':
        """Returns the ``QuirkInputRotationOperation`` for this cell.

        Raises:
            ValueError: If input A was never supplied via ``with_input``.
        """
        if self.register is None:
            # Plain literal: the original f-string had no placeholders (F541).
            raise ValueError("Missing input 'a'")
        return QuirkInputRotationOperation(
            self.identifier, self.register, self.base_operation, self.exponent_sign
        )
97
98
@value.value_equality
class QuirkInputRotationOperation(ops.Operation):
    """Operates on target qubits in a way that varies based on an input qureg.

    For each computational-basis value ``v`` of the input register (read
    big-endian), the target qubits receive
    ``base_operation ** (exponent_sign * v / N)``, where ``N`` is the product
    of the register qids' dimensions.
    """

    def __init__(
        self,
        identifier: str,
        register: Iterable['cirq.Qid'],
        base_operation: 'cirq.Operation',
        exponent_sign: int,
    ):
        """Initializes the operation.

        Args:
            identifier: Quirk gate identifier this operation came from.
            register: Input qubits whose value controls the exponent.
            base_operation: Operation applied (raised to a power) on targets.
            exponent_sign: Either -1 or +1.

        Raises:
            ValueError: If ``exponent_sign`` is not -1 or +1.
        """
        if exponent_sign not in (-1, +1):
            raise ValueError('exponent_sign not in [-1, +1]')
        self.identifier = identifier
        self.register = tuple(register)
        self.base_operation = base_operation
        self.exponent_sign = exponent_sign

    def _value_equality_values_(self) -> Any:
        return self.identifier, self.register, self.base_operation, self.exponent_sign

    @property
    def qubits(self) -> Tuple['cirq.Qid', ...]:
        """Target qubits of the base operation, followed by the input register."""
        return (*self.base_operation.qubits, *self.register)

    def with_qubits(self, *new_qubits):
        """Rebuilds the operation on ``new_qubits`` (targets first, then register)."""
        split = len(self.base_operation.qubits)
        return QuirkInputRotationOperation(
            self.identifier,
            new_qubits[split:],
            self.base_operation.with_qubits(*new_qubits[:split]),
            self.exponent_sign,
        )

    def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
        """Shows the base operation's symbols, then A0..A{n-1} on the register."""
        base_info = cirq.circuit_diagram_info(self.base_operation)
        register_size = len(self.register)
        prefix = '-' if self.exponent_sign == -1 else ''
        wire_symbols = tuple(base_info.wire_symbols) + tuple(
            f'A{j}' for j in range(register_size)
        )
        return cirq.CircuitDiagramInfo(
            wire_symbols,
            exponent=f'({prefix}A/2^{register_size})',
            exponent_qubit_index=base_info.exponent_qubit_index or 0,
            auto_exponent_parens=False,
        )

    def _has_unitary_(self) -> bool:
        return True

    def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
        """Applies the register-dependent power of the base operation in place.

        Axes are moved to the front, then each slice of the tensor where the
        register holds a fixed value gets the matching power of the base
        operation applied to the target axes. Results are written back through
        the transposed tensors, and ``args.target_tensor`` is returned (this
        presumably relies on ``with_axes_transposed_to_start`` returning views).
        """
        moved = args.with_axes_transposed_to_start()
        num_targets = len(self.base_operation.qubits)
        target_axes = moved.axes[:num_targets]
        register_axes = moved.axes[num_targets:]
        num_values = np.prod([q.dimension for q in self.register], dtype=np.int64).item()

        for register_value in range(num_values):
            power = self.exponent_sign * register_value / num_values
            selector = linalg.slice_for_qubits_equal_to(
                register_axes, big_endian_qureg_value=register_value
            )
            sliced = cirq.ApplyUnitaryArgs(
                moved.target_tensor[selector],
                moved.available_buffer[selector],
                target_axes,
            )
            applied = cirq.apply_unitary(self.base_operation**power, sliced)
            # apply_unitary may answer in the scratch buffer; copy it back.
            if applied is not sliced.target_tensor:
                sliced.target_tensor[...] = applied

        return args.target_tensor

    def __repr__(self) -> str:
        fields = (
            f'identifier={self.identifier!r}',
            f'register={self.register!r}',
            f'base_operation={self.base_operation!r}',
            f'exponent_sign={self.exponent_sign!r}',
        )
        return f"cirq.interop.quirk.QuirkInputRotationOperation({', '.join(fields)})"
185
186
def generate_all_input_rotation_cell_makers() -> Iterator[CellMaker]:
    """Yields cell makers for the X/Y/Z input-parameterized rotations.

    The three positive-exponent variants come first, then the three
    negative-exponent variants, each in X, Y, Z order.
    """
    specs = (
        ("X^(A/2^n)", ops.X, +1),
        ("Y^(A/2^n)", ops.Y, +1),
        ("Z^(A/2^n)", ops.Z, +1),
        ("X^(-A/2^n)", ops.X, -1),
        ("Y^(-A/2^n)", ops.Y, -1),
        ("Z^(-A/2^n)", ops.Z, -1),
    )
    for name, pauli, sign in specs:
        yield _input_rotation_gate(name, pauli, sign)
194
195
def _input_rotation_gate(identifier: str, gate: 'cirq.Gate', exponent_sign: int) -> CellMaker:
    """Builds a ``CellMaker`` producing an ``InputRotationCell`` for ``gate``.

    The produced cell starts with no input register and applies ``gate`` to
    the first qubit offered by the maker arguments.
    """

    def make_cell(maker_args):
        return InputRotationCell(
            identifier=identifier,
            register=None,
            base_operation=gate.on(maker_args.qubits[0]),
            exponent_sign=exponent_sign,
        )

    return CellMaker(identifier, gate.num_qubits(), make_cell)
207