1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Any, FrozenSet, Iterable, Optional, Set, TYPE_CHECKING
16
17from cirq import circuits, value, devices, ops, protocols
18from cirq.ion import convert_to_ion_gates
19
20if TYPE_CHECKING:
21    import cirq
22
23
def get_ion_gateset() -> ops.Gateset:
    """Builds the gateset that ion trap devices accept natively.

    The accepted gate families are XX, measurement, X, Y, Z and PhasedX
    power gates. The gateset is built with `unroll_circuit_op=False` and
    `accept_global_phase_op=False`.

    Returns:
        A freshly constructed `cirq.Gateset` for ion devices.
    """
    native_gate_types = (
        ops.XXPowGate,
        ops.MeasurementGate,
        ops.XPowGate,
        ops.YPowGate,
        ops.ZPowGate,
        ops.PhasedXPowGate,
    )
    return ops.Gateset(
        *native_gate_types,
        unroll_circuit_op=False,
        accept_global_phase_op=False,
    )
35
36
@value.value_equality
class IonDevice(devices.Device):
    """A device with qubits placed on a line.

    Qubits have all-to-all connectivity.
    """

    def __init__(
        self,
        measurement_duration: 'cirq.DURATION_LIKE',
        twoq_gates_duration: 'cirq.DURATION_LIKE',
        oneq_gates_duration: 'cirq.DURATION_LIKE',
        qubits: Iterable[devices.LineQubit],
    ) -> None:
        """Initializes the description of an ion trap device.

        Args:
            measurement_duration: The maximum duration of a measurement.
            twoq_gates_duration: The maximum duration of a two qubit operation.
            oneq_gates_duration: The maximum duration of a single qubit
                operation.
            qubits: Qubits on the device, identified by their position on
                the line.
        """
        self._measurement_duration = value.Duration(measurement_duration)
        self._twoq_gates_duration = value.Duration(twoq_gates_duration)
        self._oneq_gates_duration = value.Duration(oneq_gates_duration)
        self.qubits = frozenset(qubits)
        self.gateset = get_ion_gateset()

    def qubit_set(self) -> FrozenSet['cirq.LineQubit']:
        """Returns the frozen set of qubits on this device."""
        return self.qubits

    def qid_pairs(self) -> FrozenSet['cirq.SymmetricalQidPair']:
        """Qubits have all-to-all connectivity, so returns all pairs.

        Returns:
            All qubit pairs on the device.
        """
        qs = self.qubits
        # `q < q2` emits each unordered pair exactly once and skips self-pairs.
        return frozenset([devices.SymmetricalQidPair(q, q2) for q in qs for q2 in qs if q < q2])

    def decompose_operation(self, operation: ops.Operation) -> ops.OP_TREE:
        """Converts a single operation into ion-native gates."""
        return convert_to_ion_gates.ConvertToIonGates().convert_one(operation)

    def decompose_circuit(self, circuit: circuits.Circuit) -> circuits.Circuit:
        """Converts every operation of `circuit` into ion-native gates."""
        return convert_to_ion_gates.ConvertToIonGates().convert_circuit(circuit)

    def duration_of(self, operation: 'cirq.Operation') -> 'cirq.Duration':
        """Returns the configured duration for the operation's gate type.

        Args:
            operation: The operation whose duration is requested.

        Returns:
            The two-qubit duration for `XXPowGate`, the single-qubit duration
            for X/Y/Z/PhasedX power gates, and the measurement duration for
            `MeasurementGate`.

        Raises:
            ValueError: If the operation's gate is none of the above.
        """
        if isinstance(operation.gate, ops.XXPowGate):
            return self._twoq_gates_duration
        if isinstance(
            operation.gate, (ops.XPowGate, ops.YPowGate, ops.ZPowGate, ops.PhasedXPowGate)
        ):
            return self._oneq_gates_duration
        if isinstance(operation.gate, ops.MeasurementGate):
            return self._measurement_duration
        raise ValueError(f'Unsupported gate type: {operation!r}')

    def validate_gate(self, gate: ops.Gate) -> None:
        """Raises ValueError if `gate` is not in the ion gateset."""
        if gate not in self.gateset:
            raise ValueError(f'Unsupported gate type: {gate!r}')

    def validate_operation(self, operation: 'cirq.Operation') -> None:
        """Checks that an operation can run on this device.

        Raises:
            ValueError: If the operation is not a `GateOperation`, its gate is
                not in the ion gateset, or it acts on a qubit that is not a
                `LineQubit` or is not on this device.
        """
        if not isinstance(operation, ops.GateOperation):
            raise ValueError(f'Unsupported operation: {operation!r}')

        self.validate_gate(operation.gate)

        for q in operation.qubits:
            if not isinstance(q, devices.LineQubit):
                raise ValueError(f'Unsupported qubit type: {q!r}')
            if q not in self.qubits:
                raise ValueError(f'Qubit not on device: {q!r}')

    def validate_circuit(self, circuit: circuits.AbstractCircuit):
        """Runs the base-class validation, then rejects repeated measurement keys."""
        super().validate_circuit(circuit)
        _verify_unique_measurement_keys(circuit.all_operations())

    def at(self, position: int) -> Optional[devices.LineQubit]:
        """Returns the qubit at the given position, if there is one, else None."""
        q = devices.LineQubit(position)
        return q if q in self.qubits else None

    def neighbors_of(self, qubit: devices.LineQubit) -> Iterable[devices.LineQubit]:
        """Returns the device qubits directly adjacent to `qubit` on the line.

        Only positions `qubit.x - 1` and `qubit.x + 1` are considered. This is
        physical adjacency (used by `__str__` to draw the line), not
        connectivity: connectivity on this device is all-to-all.
        """
        possibles = [
            devices.LineQubit(qubit.x + 1),
            devices.LineQubit(qubit.x - 1),
        ]
        return [e for e in possibles if e in self.qubits]

    def __repr__(self) -> str:
        return (
            f'IonDevice(measurement_duration={self._measurement_duration!r}, '
            f'twoq_gates_duration={self._twoq_gates_duration!r}, '
            f'oneq_gates_duration={self._oneq_gates_duration!r}, '
            f'qubits={sorted(self.qubits)!r})'
        )

    def __str__(self) -> str:
        """Draws the qubits along a row, joining adjacent positions with lines."""
        diagram = circuits.TextDiagramDrawer()

        for q in self.qubits:
            diagram.write(q.x, 0, str(q))
            for q2 in self.neighbors_of(q):
                diagram.grid_line(q.x, 0, q2.x, 0)

        return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

    def _value_equality_values_(self) -> Any:
        # Two devices are equal when all three durations and the qubit set match.
        return (
            self._measurement_duration,
            self._twoq_gates_duration,
            self._oneq_gates_duration,
            self.qubits,
        )
153
154
def _verify_unique_measurement_keys(operations: Iterable[ops.Operation]):
    """Ensures no two measurement operations share a measurement key.

    Args:
        operations: The operations to scan. Only those whose gate is a
            `MeasurementGate` are considered.

    Raises:
        ValueError: On the first measurement whose key was already seen.
    """
    keys_so_far: Set[str] = set()
    measurement_gates = (
        op.gate for op in operations if isinstance(op.gate, ops.MeasurementGate)
    )
    for gate in measurement_gates:
        name = protocols.measurement_key_name(gate)
        if name in keys_so_far:
            raise ValueError(f'Measurement key {name} repeated')
        keys_so_far.add(name)
164