1# Copyright 2020 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Callable, Dict, Set, Tuple, Union
16import numpy as np
17import cirq
18from cirq import protocols, value, ops
19
20
def to_quil_complex_format(num) -> str:
    """Render a number as a QUIL complex literal of the form ``<real>+<imag>i``.

    Args:
        num: Any value whose ``str()`` form is accepted by ``complex()``.

    Returns:
        The real and imaginary parts (as floats) joined by ``+`` and
        terminated by ``i``. The ``+`` is always emitted, even when the
        imaginary part is negative.
    """
    # Round-trip through str() so the value is parsed by complex() from text.
    parsed = complex(str(num))
    real_part, imag_part = parsed.real, parsed.imag
    return f"{real_part}+{imag_part}i"
25
26
@value.value_equality(approximate=True)
class QuilOneQubitGate(ops.SingleQubitGate):
    """A single qubit gate written to QUIL as a DEFGATE with an explicit
    2x2 matrix.

    The definition and its application are both emitted under the
    placeholder name ``USERGATE``.
    """

    def __init__(self, matrix: np.ndarray) -> None:
        """Inits QuilOneQubitGate.

        Args:
            matrix: The 2x2 unitary matrix for this gate. It is stored as
                given; no validation is performed here.
        """
        self.matrix = matrix

    def _quil_(self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter') -> str:
        """Returns the DEFGATE definition followed by its application.

        Args:
            qubits: The qubits the gate acts on; only ``qubits[0]`` is used.
            formatter: Formatter used to render the qubit in the gate
                application line.

        Returns:
            A header line, two indented comma-separated matrix rows, and a
            ``USERGATE <qubit>`` line, each terminated by a newline.
        """
        dim = 2
        matrix_rows = [
            ', '.join(to_quil_complex_format(self.matrix[row, col]) for col in range(dim))
            for row in range(dim)
        ]
        # Rows are separated by a newline plus the four-space indent.
        matrix_text = '\n    '.join(matrix_rows)
        application = formatter.format("USERGATE {0}", qubits[0])
        return f'DEFGATE USERGATE:\n    {matrix_text}\n{application}\n'

    def __repr__(self) -> str:
        """Returns a multi-line description embedding the matrix's str form."""
        return f'cirq.circuits.quil_output.QuilOneQubitGate(matrix=\n{self.matrix}\n)'

    def _value_equality_values_(self):
        """Returns the matrix as the value used for equality comparison."""
        # NOTE(review): this hands the raw ndarray to the value_equality
        # machinery; confirm that array-valued comparison/hashing behaves
        # as intended there.
        return self.matrix
56
57
@value.value_equality(approximate=True)
class QuilTwoQubitGate(ops.Gate):
    """A two qubit gate written to QUIL as a DEFGATE with its explicit 4x4
    unitary matrix.

    The definition and its application are both emitted under the
    placeholder name ``USERGATE``.
    """

    def __init__(self, matrix: np.ndarray) -> None:
        """Inits QuilTwoQubitGate.

        Args:
            matrix: The 4x4 unitary matrix for this gate. It is stored as
                given; no validation is performed here.
        """
        self.matrix = matrix

    def _num_qubits_(self) -> int:
        """Returns the number of qubits this gate acts on (always 2)."""
        return 2

    def _value_equality_values_(self):
        """Returns the matrix as the value used for equality comparison."""
        # NOTE(review): this hands the raw ndarray to the value_equality
        # machinery; confirm that array-valued comparison/hashing behaves
        # as intended there.
        return self.matrix

    def _quil_(self, qubits: Tuple['cirq.Qid', ...], formatter: 'cirq.QuilFormatter') -> str:
        """Returns the DEFGATE definition followed by its application.

        Args:
            qubits: The qubits the gate acts on; ``qubits[0]`` and
                ``qubits[1]`` are used.
            formatter: Formatter used to render the qubits in the gate
                application line.

        Returns:
            A header line, four indented comma-separated matrix rows, and a
            ``USERGATE <qubit> <qubit>`` line, each terminated by a newline.
        """
        dim = 4
        matrix_rows = [
            ', '.join(to_quil_complex_format(self.matrix[row, col]) for col in range(dim))
            for row in range(dim)
        ]
        # Rows are separated by a newline plus the four-space indent.
        matrix_text = '\n    '.join(matrix_rows)
        application = formatter.format("USERGATE {0} {1}", qubits[0], qubits[1])
        return f'DEFGATE USERGATE:\n    {matrix_text}\n{application}\n'

    def __repr__(self) -> str:
        """Returns a multi-line description embedding the matrix's str form."""
        return f'cirq.circuits.quil_output.QuilTwoQubitGate(matrix=\n{self.matrix}\n)'
102
103
class QuilOutput:
    """An object for passing operations and qubits then outputting them to
    QUIL format. The string representation returns the QUIL output for the
    circuit.
    """

    def __init__(self, operations: 'cirq.OP_TREE', qubits: Tuple['cirq.Qid', ...]) -> None:
        """Inits QuilOutput.

        Args:
            operations: A list or tuple of `cirq.OP_TREE` arguments.
            qubits: The qubits used in the operations.
        """
        self.qubits = qubits
        self.operations = tuple(cirq.ops.flatten_to_ops(operations))
        self.measurements = tuple(
            op for op in self.operations if isinstance(op.gate, ops.MeasurementGate)
        )
        self.qubit_id_map = self._generate_qubit_ids()
        self.measurement_id_map = self._generate_measurement_ids()
        self.formatter = protocols.QuilFormatter(
            qubit_id_map=self.qubit_id_map, measurement_id_map=self.measurement_id_map
        )

    def _generate_qubit_ids(self) -> Dict['cirq.Qid', str]:
        """Maps each qubit to its position in `self.qubits`, as a string."""
        return {qubit: str(i) for i, qubit in enumerate(self.qubits)}

    def _generate_measurement_ids(self) -> Dict[str, str]:
        """Maps each distinct measurement key to a name `m0`, `m1`, ...

        Names are assigned in order of first appearance in `self.operations`;
        repeated keys keep the name they were given first.
        """
        measurement_id_map: Dict[str, str] = {}
        for op in self.operations:
            if not isinstance(op.gate, ops.MeasurementGate):
                continue
            key = protocols.measurement_key_name(op)
            if key not in measurement_id_map:
                # The next free index is the number of keys seen so far.
                measurement_id_map[key] = f'm{len(measurement_id_map)}'
        return measurement_id_map

    def save_to_file(self, path: Union[str, bytes, int]) -> None:
        """Write QUIL output to a file specified by path."""
        with open(path, 'w') as f:
            f.write(str(self))

    def __str__(self) -> str:
        """Returns the full QUIL program with DEFGATE names made unique."""
        output = []
        self._write_quil(output.append)
        return self.rename_defgates(''.join(output))

    def _write_quil(self, output_func: Callable[[str], None]) -> None:
        """Emits the QUIL program piece by piece through `output_func`.

        A header comment is written first, then one DECLARE per distinct
        measurement key, then the QUIL for every operation after
        decomposition.

        Args:
            output_func: Called with each fragment of QUIL text, in order.

        Raises:
            ValueError: Via `protocols.decompose` when an operation cannot be
                reduced to operations that have a QUIL representation.
        """
        output_func('# Created using Cirq.\n\n')
        if self.measurements:
            measurements_declared: Set[str] = set()
            for m in self.measurements:
                key = protocols.measurement_key_name(m)
                if key in measurements_declared:
                    continue
                measurements_declared.add(key)
                output_func(f'DECLARE {self.measurement_id_map[key]} BIT[{len(m.qubits)}]\n')
            output_func('\n')

        def keep(op: 'cirq.Operation') -> bool:
            # An operation needs no further decomposition once it has QUIL.
            return protocols.quil(op, formatter=self.formatter) is not None

        def fallback(op):
            # Only one- and two-qubit unitaries can be written as a DEFGATE.
            if len(op.qubits) not in [1, 2]:
                return NotImplemented

            mat = protocols.unitary(op, None)
            if mat is None:
                return NotImplemented

            # Following code is a safety measure
            # Could not find a gate that doesn't decompose into a gate
            # with a _quil_ implementation
            # coverage: ignore
            if len(op.qubits) == 1:
                return QuilOneQubitGate(mat).on(*op.qubits)
            return QuilTwoQubitGate(mat).on(*op.qubits)

        def on_stuck(bad_op):
            return ValueError(f'Cannot output operation as QUIL: {bad_op!r}')

        for main_op in self.operations:
            decomposed = protocols.decompose(
                main_op, keep=keep, fallback_decomposer=fallback, on_stuck_raise=on_stuck
            )

            for decomposed_op in decomposed:
                output_func(protocols.quil(decomposed_op, formatter=self.formatter))

    def rename_defgates(self, output: str) -> str:
        """Renames the DEFGATEs within the QUIL output so each is unique.

        Every occurrence of the placeholder name `USERGATE` gets a number
        appended: the count of `DEFGATE` keywords that appear before it in
        the text. The first definition and its uses become `USERGATE1`, the
        second `USERGATE2`, and so on; a `USERGATE` that precedes every
        `DEFGATE` becomes `USERGATE0`.

        Args:
            output: The QUIL text containing placeholder gate names.

        Returns:
            The QUIL text with every `USERGATE` occurrence numbered.
        """
        def_keyword = 'DEFGATE'
        placeholder = 'USERGATE'
        # Splitting on the DEFGATE keyword yields segments in which segment
        # k holds exactly the text that follows the k-th DEFGATE, so the
        # segment index is the gate number for every placeholder inside it.
        # Working per segment covers the whole text, including names that
        # sit at the very end after earlier names have been lengthened.
        segments = output.split(def_keyword)
        renamed = [
            segment.replace(placeholder, f'{placeholder}{gate_num}')
            for gate_num, segment in enumerate(segments)
        ]
        return def_keyword.join(renamed)
225