# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterable, Iterator, List, Optional, Sequence, Tuple, Union

import numpy as np

import cirq
from cirq import ops, linalg, value
from cirq.interop.quirk.cells.cell import Cell, CellMaker


@value.value_equality
class InputRotationCell(Cell):
    """Applies an operation that depends on an input gate.

    The cell wraps a ``base_operation`` whose exponent is determined by the
    value of the Quirk input register ``a``. Until that register is supplied
    through ``with_input``, ``register`` is ``None`` and ``operations()``
    raises.

    Args:
        identifier: The Quirk gate identifier (e.g. ``"X^(A/2^n)"``).
        register: Qubits of input ``a``, or ``None`` if not yet provided.
        base_operation: The operation raised to an input-dependent power.
        exponent_sign: Sign of the exponent. It is forwarded to
            ``QuirkInputRotationOperation``, which only accepts -1 or +1.
    """

    def __init__(
        self,
        identifier: str,
        register: Optional[Sequence['cirq.Qid']],
        base_operation: 'cirq.Operation',
        exponent_sign: int,
    ):
        self.identifier = identifier
        # Stored as a tuple so the value-equality tuple below is hashable.
        self.register = None if register is None else tuple(register)
        self.base_operation = base_operation
        self.exponent_sign = exponent_sign

    def _value_equality_values_(self) -> Any:
        """Fields compared/hashed by ``@value.value_equality``."""
        return (
            self.identifier,
            self.register,
            self.base_operation,
            self.exponent_sign,
        )

    def __repr__(self) -> str:
        return (
            f'cirq.interop.quirk.cells.input_rotation_cells.InputRotationCell('
            f'\n    {self.identifier!r},'
            f'\n    {self.register!r},'
            f'\n    {self.base_operation!r},'
            f'\n    {self.exponent_sign!r})'
        )

    def gate_count(self) -> int:
        """This cell always contributes exactly one gate."""
        return 1

    def with_line_qubits_mapped_to(self, qubits: List['cirq.Qid']) -> 'Cell':
        """Returns a copy with the register and target qubits remapped.

        Both the input register (if present) and the base operation's qubits
        are translated through ``Cell._replace_qubits`` using ``qubits``.
        """
        return InputRotationCell(
            self.identifier,
            None if self.register is None else Cell._replace_qubits(self.register, qubits),
            self.base_operation.with_qubits(
                *Cell._replace_qubits(self.base_operation.qubits, qubits)
            ),
            exponent_sign=self.exponent_sign,
        )

    def with_input(self, letter: str, register: Union[Sequence['cirq.Qid'], int]) -> 'Cell':
        """Binds input ``a`` as the rotation parameter if not already bound.

        Args:
            letter: The Quirk input letter being offered.
            register: The qubits of that input, or a classical constant.

        Returns:
            A new cell with ``register`` bound when ``letter == 'a'`` and no
            register was bound yet; otherwise ``self`` unchanged.

        Raises:
            ValueError: Input ``a`` is offered as a classical int constant.
                A qubit register of known length is required.
        """
        # Parameterized rotations use input A as their parameter.
        if self.register is None and letter == 'a':
            if isinstance(register, int):
                raise ValueError(
                    'Dependent operation requires known length '
                    'input; classical constant not allowed.'
                )
            return InputRotationCell(
                self.identifier, register, self.base_operation, self.exponent_sign
            )
        return self

    def controlled_by(self, qubit: 'cirq.Qid') -> 'Cell':
        """Returns a copy whose base operation is controlled by ``qubit``.

        The input register is carried over unchanged.
        """
        return InputRotationCell(
            self.identifier,
            self.register,
            self.base_operation.controlled_by(qubit),
            self.exponent_sign,
        )

    def operations(self) -> 'cirq.OP_TREE':
        """Builds the concrete input-dependent rotation operation.

        Raises:
            ValueError: Input ``a`` has not been bound via ``with_input``.
        """
        if self.register is None:
            raise ValueError("Missing input 'a'")
        return QuirkInputRotationOperation(
            self.identifier, self.register, self.base_operation, self.exponent_sign
        )


@value.value_equality
class QuirkInputRotationOperation(ops.Operation):
    """Operates on target qubits in a way that varies based on an input qureg.

    For each basis value ``i`` of ``register`` (read big-endian), this applies
    ``base_operation ** (exponent_sign * i / control_max)`` to the target
    qubits. ``control_max`` is the product of the register's qid dimensions,
    which is ``2**len(register)`` for qubits.

    Args:
        identifier: The Quirk gate identifier this operation came from.
        register: The input register whose value sets the exponent.
        base_operation: The operation being exponentiated.
        exponent_sign: Either -1 or +1.

    Raises:
        ValueError: ``exponent_sign`` is not -1 or +1.
    """

    def __init__(
        self,
        identifier: str,
        register: Iterable['cirq.Qid'],
        base_operation: 'cirq.Operation',
        exponent_sign: int,
    ):
        if exponent_sign not in [-1, +1]:
            raise ValueError('exponent_sign not in [-1, +1]')
        self.identifier = identifier
        self.register = tuple(register)
        self.base_operation = base_operation
        self.exponent_sign = exponent_sign

    def _value_equality_values_(self) -> Any:
        """Fields compared/hashed by ``@value.value_equality``."""
        return (
            self.identifier,
            self.register,
            self.base_operation,
            self.exponent_sign,
        )

    @property
    def qubits(self) -> Tuple['cirq.Qid', ...]:
        """The base operation's qubits followed by the register qubits."""
        return tuple(self.base_operation.qubits) + self.register

    def with_qubits(self, *new_qubits) -> 'QuirkInputRotationOperation':
        """Returns a copy acting on ``new_qubits``.

        The qubit order matches ``qubits``. The first
        ``len(base_operation.qubits)`` entries replace the target qubits, and
        the rest become the new register.
        """
        k = len(self.base_operation.qubits)
        new_op_qubits = new_qubits[:k]
        new_register = new_qubits[k:]
        return QuirkInputRotationOperation(
            self.identifier,
            new_register,
            self.base_operation.with_qubits(*new_op_qubits),
            self.exponent_sign,
        )

    def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs'):
        """Draws the base operation's symbols plus ``A0, A1, ...`` on the register.

        The exponent is rendered as ``(A/2^n)`` or ``(-A/2^n)``, where ``n`` is
        the register length.
        """
        sub_result = cirq.circuit_diagram_info(self.base_operation)
        sign_char = '-' if self.exponent_sign == -1 else ''
        symbols = list(sub_result.wire_symbols)
        symbols.extend(f'A{i}' for i in range(len(self.register)))
        return cirq.CircuitDiagramInfo(
            tuple(symbols),
            exponent=f'({sign_char}A/2^{len(self.register)})',
            exponent_qubit_index=sub_result.exponent_qubit_index or 0,
            auto_exponent_parens=False,
        )

    def _has_unitary_(self) -> bool:
        return True

    def _apply_unitary_(self, args: 'cirq.ApplyUnitaryArgs'):
        """Applies the appropriate power of ``base_operation`` to each register slice.

        The result is written back into ``args.target_tensor``, which is
        returned.
        """
        # Move this operation's axes to the front, in the same order as
        # ``self.qubits``: base-operation axes first, then register axes.
        transposed_args = args.with_axes_transposed_to_start()

        num_targets = len(self.base_operation.qubits)
        target_axes = transposed_args.axes[:num_targets]
        control_axes = transposed_args.axes[num_targets:]
        control_max = np.prod([q.dimension for q in self.register], dtype=np.int64).item()
        # Passing the full tensor shape makes the slice decode ``i`` using the
        # register's actual qid dimensions, consistent with ``control_max``.
        # Without it, every axis is assumed to have dimension 2, so registers
        # containing higher-dimensional qids would not be decoded correctly.
        tensor_shape = transposed_args.target_tensor.shape

        for i in range(control_max):
            operation = self.base_operation ** (self.exponent_sign * i / control_max)
            control_index = linalg.slice_for_qubits_equal_to(
                control_axes, big_endian_qureg_value=i, qid_shape=tensor_shape
            )
            # Integer indices drop the control axes, so the target axes keep
            # their leading positions in the sub-tensor.
            sub_args = cirq.ApplyUnitaryArgs(
                transposed_args.target_tensor[control_index],
                transposed_args.available_buffer[control_index],
                target_axes,
            )
            sub_result = cirq.apply_unitary(operation, sub_args)

            # The result may have been produced in the buffer; copy it back
            # into the (view of the) target tensor.
            if sub_result is not sub_args.target_tensor:
                sub_args.target_tensor[...] = sub_result

        # The transposed tensor is a view, so the edits are visible here.
        return args.target_tensor

    def __repr__(self) -> str:
        return (
            f'cirq.interop.quirk.QuirkInputRotationOperation('
            f'identifier={self.identifier!r}, '
            f'register={self.register!r}, '
            f'base_operation={self.base_operation!r}, '
            f'exponent_sign={self.exponent_sign!r})'
        )


def generate_all_input_rotation_cell_makers() -> Iterator[CellMaker]:
    """Yields cell makers for Quirk's X/Y/Z rotations by ``±A/2^n``."""
    yield _input_rotation_gate("X^(A/2^n)", ops.X, +1)
    yield _input_rotation_gate("Y^(A/2^n)", ops.Y, +1)
    yield _input_rotation_gate("Z^(A/2^n)", ops.Z, +1)
    yield _input_rotation_gate("X^(-A/2^n)", ops.X, -1)
    yield _input_rotation_gate("Y^(-A/2^n)", ops.Y, -1)
    yield _input_rotation_gate("Z^(-A/2^n)", ops.Z, -1)


def _input_rotation_gate(identifier: str, gate: 'cirq.Gate', exponent_sign: int) -> CellMaker:
    """Builds a ``CellMaker`` producing an ``InputRotationCell`` for ``gate``.

    The created cell starts with no bound register. It applies ``gate`` to
    the first qubit in the maker's ``args.qubits``.
    """
    return CellMaker(
        identifier,
        gate.num_qubits(),
        lambda args: InputRotationCell(
            identifier=identifier,
            register=None,
            base_operation=gate.on(args.qubits[0]),
            exponent_sign=exponent_sign,
        ),
    )