1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Optional, List, Iterator, Iterable, TYPE_CHECKING
16
17from cirq.interop.quirk.cells.cell import Cell, CELL_SIZES, CellMaker
18
19if TYPE_CHECKING:
20    import cirq
21
22
class InputCell(Cell):
    """A column modifier that offers a quantum input to the cells sharing its column.

    The constructor stores `qubits` as a tuple and keeps `letter` as the key
    under which those qubits are handed to other cells via `with_input`.
    """

    def __init__(self, qubits: Iterable['cirq.Qid'], letter: str):
        """Records the input qubits (as a tuple) and the input letter."""
        self.qubits = tuple(qubits)
        self.letter = letter

    def gate_count(self) -> int:
        """Reports a gate count of zero for this cell."""
        return 0

    def with_line_qubits_mapped_to(self, qubits: List['cirq.Qid']) -> 'Cell':
        """Returns a new InputCell whose qubits went through `Cell._replace_qubits`."""
        remapped = Cell._replace_qubits(self.qubits, qubits)
        return InputCell(qubits=remapped, letter=self.letter)

    def modify_column(self, column: List[Optional['Cell']]):
        """Replaces, in place, every non-None entry with its `with_input` result.

        Args:
            column: The cells of one column; `None` entries are left untouched.
        """
        for index, occupant in enumerate(column):
            if occupant is None:
                continue
            column[index] = occupant.with_input(self.letter, self.qubits)
41
42
class SetDefaultInputCell(Cell):
    """A persistent modifier that supplies a fallback classical input value.

    The constructor stores `letter` (the input's key) and `value` (the value
    offered for that key) unchanged.
    """

    def __init__(self, letter: str, value: int):
        """Records the input letter and the value to provide for it."""
        self.letter = letter
        self.value = value

    def gate_count(self) -> int:
        """Reports a gate count of zero for this cell."""
        return 0

    def with_line_qubits_mapped_to(self, qubits: List['cirq.Qid']) -> 'Cell':
        """Returns this same cell; it stores no qubits to remap."""
        return self

    def persistent_modifiers(self):
        """Returns a one-entry dict mapping `set_default_<letter>` to a modifier.

        The modifier takes a cell and returns `cell.with_input(letter, value)`.
        """

        def apply_default(cell):
            # Reads the attributes at call time, like the closure it replaces.
            return cell.with_input(self.letter, self.value)

        modifier_key = f'set_default_{self.letter}'
        return {modifier_key: apply_default}
58
59
def generate_all_input_cell_makers() -> Iterator[CellMaker]:
    """Yields the makers for the quantum input cells and the classical default cells.

    The quantum families come first, in the order inputA, inputB, inputR,
    revinputA, revinputB (each expanded by `_input_family`), followed by the
    setA, setB and setR makers.
    """

    def default_setter(letter: str):
        # A factory gives each maker its own `letter`; a lambda written
        # directly inside the loop below would only see the final loop value.
        return lambda args: SetDefaultInputCell(letter, args.value)

    # Quantum inputs: (identifier prefix, input letter, reversed qubit order).
    quantum_families = (
        ("inputA", "a", False),
        ("inputB", "b", False),
        ("inputR", "r", False),
        ("revinputA", "a", True),
        ("revinputB", "b", True),
    )
    for prefix, letter, reverse in quantum_families:
        yield from _input_family(prefix, letter, rev=reverse)

    # Classical inputs.
    classical_defaults = (("setA", 'a'), ("setB", 'b'), ("setR", 'r'))
    for identifier, letter in classical_defaults:
        yield CellMaker(identifier, 2, default_setter(letter))
72
73
def _input_family(identifier_prefix: str, letter: str, rev: bool = False) -> Iterator[CellMaker]:
    """Yields one InputCell maker for each entry of CELL_SIZES.

    Args:
        identifier_prefix: Text placed before each size to form the identifier.
        letter: The input letter given to every InputCell that gets made.
        rev: When true, the made cell receives `args.qubits` in reversed order.
    """

    def build_input_cell(args):
        # Only `rev` and `letter` are captured, and neither changes per size,
        # so a single closure serves every maker in the family.
        if rev:
            chosen = args.qubits[::-1]
        else:
            chosen = args.qubits
        return InputCell(qubits=chosen, letter=letter)

    for size in CELL_SIZES:
        yield CellMaker(
            identifier=identifier_prefix + str(size),
            size=size,
            maker=build_input_cell,
        )
83