1# Copyright 2020 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
import itertools
from typing import FrozenSet, Callable, List, Sequence, Any, Union, Dict

import numpy as np

import cirq
from cirq import GridQubit, LineQubit
from cirq.ops import NamedQubit
from cirq_pasqal import ThreeDQubit, TwoDQubit
22
23
@cirq.value.value_equality
class PasqalDevice(cirq.devices.Device):
    """A generic Pasqal device.

    The most general of Pasqal devices, enforcing only restrictions expected to
    be shared by all future devices. Serves as the parent class of all Pasqal
    devices, but can also be used on its own for hosting a nearly unconstrained
    device. When used as a circuit's device, the qubits have to be of the type
    cirq.NamedQubit and assumed to be all connected, the idea behind it being
    that after submission, all optimization and transpilation necessary for its
    execution on the specified device are handled internally by Pasqal.
    """

    # TODO(#3388) Add documentation for Raises.
    # pylint: disable=missing-raises-doc
    def __init__(self, qubits: Sequence[cirq.ops.Qid]) -> None:
        """Initializes a device with some qubits.

        Args:
            qubits: Qubits on the device. For this class they must be
                `cirq.NamedQubit`s, i.e. exclusively unrelated to a physical
                position; subclasses widen this via `supported_qubit_type`.
                All qubits must have exactly the same type.

        Raises:
            TypeError: If a qubit is not of a supported type, or if the qubits
                are not all of the same type.
            ValueError: If more than `maximum_qubit_number` qubits are given.
        """
        # The concrete type of the first qubit is the reference type; every
        # qubit must match it exactly (subclass instances do not count).
        q_type = type(qubits[0]) if len(qubits) > 0 else None

        for q in qubits:
            if not isinstance(q, self.supported_qubit_type):
                raise TypeError(
                    'Unsupported qubit type: {!r}. This device '
                    'supports qubit types: {}'.format(q, self.supported_qubit_type)
                )
            if type(q) is not q_type:
                raise TypeError("All qubits must be of same type.")

        if len(qubits) > self.maximum_qubit_number:
            raise ValueError(
                'Too many qubits. {} accepts at most {} '
                'qubits.'.format(type(self), self.maximum_qubit_number)
            )

        # Native operations of the device: parallelizable single-qubit
        # rotations, integer powers of the listed multi-qubit gates,
        # identities and measurements.
        self.gateset = cirq.Gateset(
            cirq.ParallelGateFamily(cirq.H),
            cirq.ParallelGateFamily(cirq.PhasedXPowGate),
            cirq.ParallelGateFamily(cirq.XPowGate),
            cirq.ParallelGateFamily(cirq.YPowGate),
            cirq.ParallelGateFamily(cirq.ZPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CNotPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CCNotPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CZPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CCZPowGate),
            cirq.IdentityGate,
            cirq.MeasurementGate,
            unroll_circuit_op=False,
            accept_global_phase_op=False,
        )
        self.qubits = qubits

    # pylint: enable=missing-raises-doc
    @property
    def supported_qubit_type(self):
        """Tuple of qubit classes accepted by this device."""
        return (NamedQubit,)

    @property
    def maximum_qubit_number(self):
        """Maximum number of qubits the device can hold."""
        return 100

    def qubit_set(self) -> FrozenSet[cirq.Qid]:
        """Returns the device qubits as a frozenset."""
        return frozenset(self.qubits)

    def qubit_list(self):
        """Returns a fresh list of the device qubits, in construction order."""
        return list(self.qubits)

    def decompose_operation(self, operation: cirq.ops.Operation) -> 'cirq.OP_TREE':
        """Decomposes an operation into operations native to this device.

        Args:
            operation: The gate operation to decompose.

        Returns:
            `[operation]` if it is already a native device operation, otherwise
            the operations produced by `PasqalConverter().pasqal_convert`, which
            stops decomposing at operations accepted by `is_pasqal_device_op`.

        Raises:
            TypeError: If `operation` is neither a `GateOperation` nor a
                `ParallelGateOperation`. The TypeError built by the converter
                for operations it cannot decompose may also propagate.
        """
        if not isinstance(operation, (cirq.ops.GateOperation, cirq.ParallelGateOperation)):
            raise TypeError(f"{operation!r} is not a gate operation.")

        if self.is_pasqal_device_op(operation):
            return [operation]

        # Try to decompose the operation into elementary device operations.
        return PasqalConverter().pasqal_convert(operation, keep=self.is_pasqal_device_op)

    def is_pasqal_device_op(self, op: cirq.ops.Operation) -> bool:
        """Returns whether `op` belongs to the device's native gateset.

        Args:
            op: The operation to check.

        Raises:
            ValueError: If `op` is not a `cirq.Operation`.
        """
        if not isinstance(op, cirq.ops.Operation):
            raise ValueError(f'Got unknown operation: {op!r}')
        return op in self.gateset

    # TODO(#3388) Add documentation for Raises.
    # pylint: disable=missing-raises-doc
    def validate_operation(self, operation: cirq.ops.Operation):
        """Raises an error if the given operation is invalid on this device.

        Args:
            operation: the operation to validate

        Raises:
            ValueError: If the operation is not a (parallel) gate operation,
                if its gate is not native to the device, or if it acts on a
                qubit of an unsupported type or not part of the device.
            NotImplementedError: If the operation is a measurement with a
                non-empty `invert_mask`.
        """

        if not isinstance(operation, (cirq.GateOperation, cirq.ParallelGateOperation)):
            raise ValueError("Unsupported operation")

        if not self.is_pasqal_device_op(operation):
            raise ValueError(f'{operation.gate!r} is not a supported gate')

        # Build the membership set once rather than once per qubit.
        device_qubits = self.qubit_set()
        for qub in operation.qubits:
            if not isinstance(qub, self.supported_qubit_type):
                raise ValueError(
                    '{} is not a valid qubit for gate {!r}. This '
                    'device accepts gates on qubits of type: '
                    '{}'.format(qub, operation.gate, self.supported_qubit_type)
                )
            if qub not in device_qubits:
                raise ValueError(f'{qub} is not part of the device.')

        if isinstance(operation.gate, cirq.ops.MeasurementGate):
            if operation.gate.invert_mask != ():
                raise NotImplementedError(
                    "Measurements on Pasqal devices don't support invert_mask."
                )

    # pylint: enable=missing-raises-doc
    def validate_circuit(self, circuit: 'cirq.AbstractCircuit') -> None:
        """Raises an error if the given circuit is invalid on this device.

        A circuit is invalid if any of its moments are invalid or if there
        is a non-empty moment after a moment with a measurement.

        Args:
            circuit: The circuit to validate

        Raises:
            ValueError: If the given circuit can't be run on this device
        """
        super().validate_circuit(circuit)

        # Measurements must be in the last non-empty moment.
        has_measurement_occurred = False
        for moment in circuit:
            if has_measurement_occurred:
                if len(moment.operations) > 0:
                    raise ValueError("Non-empty moment after measurement")
            for operation in moment.operations:
                if isinstance(operation.gate, cirq.ops.MeasurementGate):
                    has_measurement_occurred = True

    def can_add_operation_into_moment(
        self, operation: cirq.ops.Operation, moment: cirq.ops.Moment
    ) -> bool:
        """Determines if it's possible to add an operation into a moment.

        An operation can be added if the base-class check passes and the
        moment with the operation added validates on this device.

        Args:
            operation: The operation being added.
            moment: The moment being transformed.

        Returns:
            Whether or not the moment will validate after adding the operation.
            A ValueError raised during validation is reported as False; any
            other exception raised while validating propagates.
        """
        if not super().can_add_operation_into_moment(operation, moment):
            return False
        try:
            self.validate_moment(moment.with_operation(operation))
        except ValueError:
            return False
        return True

    def __repr__(self):
        return f'pasqal.PasqalDevice(qubits={sorted(self.qubits)!r})'

    def _value_equality_values_(self):
        # A tuple keeps the values hashable (a list is not) and lets devices
        # built from a list or a tuple of the same qubits compare equal.
        return tuple(self.qubits)

    def _json_dict_(self):
        return cirq.protocols.obj_to_dict_helper(self, ['qubits'])
211
212
class PasqalVirtualDevice(PasqalDevice):
    """A Pasqal virtual device with qubits in 3d.

    A virtual representation of a Pasqal device, enforcing the constraints
    typically found in a physical device. The qubits can be positioned in 3d
    space, although 2d layouts will be supported sooner and are thus
    recommended. Only accepts qubits with physical placement.
    """

    def __init__(
        self, control_radius: float, qubits: Sequence[Union[ThreeDQubit, GridQubit, LineQubit]]
    ) -> None:
        """Initializes a device with some qubits.

        Args:
            control_radius: the maximum distance between qubits for a controlled
                gate. Distance is measured in units of the coordinates passed
                into the qubit constructor.
            qubits: Qubits on the device, identified by their x, y, z position.
                Must be of type ThreeDQubit, TwoDQubit, LineQubit or GridQubit.

        Raises:
            TypeError: If a qubit has an unsupported type, or if the qubits are
                not all of the same type.
            ValueError: If too many qubits are provided, if `control_radius`
                is negative or NaN, or if it is larger than 3 times the minimal
                distance between qubits.
        """

        super().__init__(qubits)

        # Written as `not >= 0` rather than `< 0` so that NaN is rejected too.
        if not control_radius >= 0:
            raise ValueError('Control_radius needs to be a non-negative float.')

        if len(self.qubits) > 1:
            if control_radius > 3.0 * self.minimal_distance():
                raise ValueError(
                    'Control_radius cannot be larger than 3 times'
                    ' the minimal distance between qubits.'
                )

        self.control_radius = control_radius
        # Multi-qubit gates that is_pasqal_device_op treats as non-native on
        # this device, so they are decomposed rather than kept.
        self.exclude_gateset = cirq.Gateset(
            cirq.AnyIntegerPowerGateFamily(cirq.CNotPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CCNotPowGate),
            cirq.AnyIntegerPowerGateFamily(cirq.CCZPowGate),
        )
        # Gates whose qubits must all lie within control_radius of each other.
        self.controlled_gateset = cirq.Gateset(
            *self.exclude_gateset.gates,
            cirq.AnyIntegerPowerGateFamily(cirq.CZPowGate),
        )

    @property
    def supported_qubit_type(self):
        """Tuple of qubit classes accepted by this device."""
        return (
            ThreeDQubit,
            TwoDQubit,
            GridQubit,
            LineQubit,
        )

    def is_pasqal_device_op(self, op: cirq.ops.Operation) -> bool:
        """Returns whether `op` is native and not in `exclude_gateset`."""
        return super().is_pasqal_device_op(op) and op not in self.exclude_gateset

    def validate_operation(self, operation: cirq.ops.Operation):
        """Raises an error if the given operation is invalid on this device.

        On top of the base-class checks, every pair of qubits acted on by a
        gate in `controlled_gateset` must be within `control_radius`.

        Args:
            operation: the operation to validate

        Raises:
            ValueError: If the operation is not valid
        """
        super().validate_operation(operation)

        # Distance is symmetric and zero on the diagonal, so checking each
        # unordered pair of distinct qubits once is sufficient.
        if operation in self.controlled_gateset:
            for p, q in itertools.combinations(operation.qubits, 2):
                if self.distance(p, q) > self.control_radius:
                    raise ValueError(f"Qubits {p!r}, {q!r} are too far away")

    def validate_moment(self, moment: cirq.ops.Moment):
        """Raises an error if the given moment is invalid on this device.

        Besides the per-operation checks, a moment holding more than one
        operation is only accepted if all of them are measurements.

        Args:
            moment: The moment to validate.

        Raises:
            ValueError: If the given moment is invalid.
        """

        super().validate_moment(moment)
        if len(moment) > 1:
            for operation in moment:
                if not isinstance(operation.gate, cirq.ops.MeasurementGate):
                    raise ValueError("Cannot do simultaneous gates. Use cirq.InsertStrategy.NEW.")

    def minimal_distance(self) -> float:
        """Returns the minimal distance between two distinct device qubits.

        Raises:
            ValueError: If the device has at most one qubit, or if no two
                device qubits are distinct (raised by `min` on an empty input).

        Returns:
            The minimal distance between qubits, in spacial coordinate units.
        """
        if len(self.qubits) <= 1:
            raise ValueError("Two qubits to compute a minimal distance.")

        # Every pair comes from self.qubits, so the membership check performed
        # by distance() is redundant here; each unordered pair is measured once.
        return min(
            self._distance_unchecked(q1, q2)
            for q1, q2 in itertools.combinations(self.qubits, 2)
            if q1 != q2
        )

    def distance(self, p: Any, q: Any) -> float:
        """Returns the distance between two qubits.

        Args:
            p: qubit involved in the distance computation
            q: qubit involved in the distance computation

        Raises:
            ValueError: If p or q not part of the device

        Returns:
            The distance between qubits p and q.
        """
        if p not in self.qubits or q not in self.qubits:
            raise ValueError("Qubit not part of the device.")

        return self._distance_unchecked(p, q)

    @staticmethod
    def _distance_unchecked(p: Any, q: Any) -> float:
        """Euclidean distance between p and q, without a device-membership check.

        Dispatches on the type of `p` only; the constructor guarantees all
        device qubits share one type. Qubits other than GridQubit/LineQubit are
        assumed to expose `x`, `y` and `z` coordinates.
        """
        if isinstance(p, GridQubit):
            return np.sqrt((p.row - q.row) ** 2 + (p.col - q.col) ** 2)

        if isinstance(p, LineQubit):
            return abs(p.x - q.x)

        return np.sqrt((p.x - q.x) ** 2 + (p.y - q.y) ** 2 + (p.z - q.z) ** 2)

    def __repr__(self):
        return ('pasqal.PasqalVirtualDevice(control_radius={!r}, qubits={!r})').format(
            self.control_radius, sorted(self.qubits)
        )

    def _value_equality_values_(self) -> Any:
        # A tuple keeps the values hashable (a list is not) and lets devices
        # built from a list or a tuple of the same qubits compare equal.
        return (self.control_radius, tuple(self.qubits))

    def _json_dict_(self) -> Dict[str, Any]:
        return cirq.protocols.obj_to_dict_helper(self, ['control_radius', 'qubits'])

    def qid_pairs(self) -> FrozenSet['cirq.SymmetricalQidPair']:
        """Returns a list of qubit edges on the device.

        Returns:
            All qubit pairs that are less or equal to the control radius apart.
        """
        qs = self.qubits
        # Both qubits come from the device, so the unchecked distance is safe.
        return frozenset(
            cirq.SymmetricalQidPair(q, q2)
            for q in qs
            for q2 in qs
            if q < q2 and self._distance_unchecked(q, q2) <= self.control_radius
        )
373
374
class PasqalConverter(cirq.neutral_atoms.ConvertToNeutralAtomGates):
    """Gate converter producing operations compatible with Pasqal processors.

    Variant of ConvertToNeutralAtomGates that adds `pasqal_convert`, a
    conversion entry point which receives the `keep` predicate as an argument.
    """

    def pasqal_convert(
        self, op: cirq.ops.Operation, keep: Callable[[cirq.ops.Operation], bool]
    ) -> List[cirq.ops.Operation]:
        """Decomposes `op` via `cirq.protocols.decompose`, keeping what `keep` accepts.

        Args:
            op: The operation to convert.
            keep: Predicate selecting operations that are not decomposed further.

        Returns:
            The operations returned by `cirq.protocols.decompose`, using this
            converter's `_convert_one` as the intercepting decomposer.

        Raises:
            TypeError: Built by the stuck handler for an operation that can be
                neither kept nor decomposed; no handler is installed when
                `self.ignore_failures` is set.
        """

        def _stuck_error(stuck_op):
            return TypeError(
                f"Don't know how to work with {stuck_op!r}. "
                "It isn't a native PasqalDevice operation, "
                "a 1 or 2 qubit gate with a known unitary, "
                "or composite."
            )

        stuck_handler = None if self.ignore_failures else _stuck_error
        return cirq.protocols.decompose(
            op,
            keep=keep,
            intercepting_decomposer=self._convert_one,
            on_stuck_raise=stuck_handler,
        )
399