1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Device object for converting from device specification protos"""
15
16from typing import (
17    Any,
18    Callable,
19    cast,
20    Dict,
21    Iterable,
22    Optional,
23    List,
24    Set,
25    Tuple,
26    Type,
27    FrozenSet,
28)
29
30import cirq
31from cirq_google.serialization import serializable_gate_set
32from cirq_google.api import v2
33
34
class _GateDefinition:
    """Record of the device properties of one gate in a SerializableDevice.

    Stores the gate duration, the qubit tuples the gate may target, the
    number of qubits it acts on, whether its targets form a permutation set,
    and a predicate deciding which operations this definition applies to.
    """

    def __init__(
        self,
        duration: cirq.DURATION_LIKE,
        target_set: Set[Tuple[cirq.Qid, ...]],
        number_of_qubits: int,
        is_permutation: bool,
        can_serialize_predicate: Callable[[cirq.Operation], bool] = lambda x: True,
    ):
        """Initializes the definition.

        Args:
            duration: Duration of the gate; normalized to a cirq.Duration.
            target_set: Set of qubit tuples the gate is allowed to act on.
            number_of_qubits: Number of qubits the gate acts on.
            is_permutation: Whether the target set is a permutation set.
            can_serialize_predicate: Called with an operation; returns True
                if this definition applies to it.  Accepts everything by
                default.
        """
        self.duration = cirq.Duration(duration)
        self.target_set = target_set
        self.number_of_qubits = number_of_qubits
        self.is_permutation = is_permutation
        self.can_serialize_predicate = can_serialize_predicate

        # Union of every qubit mentioned in any tuple of the target set.
        all_qubits = set()
        for qubit_tuple in target_set:
            all_qubits.update(qubit_tuple)
        self.flattened_qubits = all_qubits

    def with_can_serialize_predicate(
        self, can_serialize_predicate: Callable[[cirq.Operation], bool]
    ) -> '_GateDefinition':
        """Returns a copy of this definition with a new can_serialize_predicate.

        This is useful if multiple definitions exist for the same gate, but
        with different conditions, e.g. if gates at certain angles take
        longer or are not allowed.

        Args:
            can_serialize_predicate: The predicate for the new definition.

        Returns:
            A new _GateDefinition sharing this one's duration, target set,
            qubit count and permutation flag.
        """
        return _GateDefinition(
            duration=self.duration,
            target_set=self.target_set,
            number_of_qubits=self.number_of_qubits,
            is_permutation=self.is_permutation,
            can_serialize_predicate=can_serialize_predicate,
        )

    def __eq__(self, other):
        """Compares all instance attributes of two definitions of this class."""
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented
77
78
class SerializableDevice(cirq.Device):
    """Device object generated from a device specification proto.

    Given a device specification proto and a gate_set to translate the
    serialized gate_ids to cirq Gates, this will generate a Device that can
    verify operations and circuits for the hardware specified by the device.

    Expected usage is through constructing this class through a proto using
    the class method from_proto().

    This class only supports GridQubits and NamedQubits.  NamedQubits with names
    that conflict (such as "4_3") may be converted to GridQubits on
    deserialization.
    """

    def __init__(
        self,
        qubits: List[cirq.Qid],
        gate_definitions: Dict[Type[cirq.Gate], List[_GateDefinition]],
    ):
        """Constructor for SerializableDevice using python objects.

        Note that the preferred method of constructing this object is through
        the from_proto() class method.

        Args:
            qubits: A list of valid Qid for the device.
            gate_definitions: Maps cirq gate types to the list of device
                definitions (duration, targets, ...) available for that type.
        """
        self.qubits = qubits
        self.gate_definitions = gate_definitions

    def qubit_set(self) -> FrozenSet[cirq.Qid]:
        """Returns the qubits of this device as a frozenset."""
        return frozenset(self.qubits)

    @classmethod
    def from_proto(
        cls,
        proto: v2.device_pb2.DeviceSpecification,
        gate_sets: Iterable[serializable_gate_set.SerializableGateSet],
    ) -> 'SerializableDevice':
        """Creates a SerializableDevice from a device specification proto.

        Args:
            proto: A proto describing the qubits on the device, as well as the
                supported gates and timing information.
            gate_sets: SerializableGateSets that can translate the gate_ids
                into cirq Gates.

        Returns:
            A SerializableDevice with the proto's qubits and, for every gate
            type handled by the gate sets, the matching gate definitions.

        Raises:
            NotImplementedError: If a gate lists both SUBSET_PERMUTATION
                target sets and target sets of another ordering.
            ValueError: If a serializer refers to a gate id that the device
                specification does not define.
        """

        # Store target sets, since they are referred to by name later
        allowed_targets: Dict[str, Set[Tuple[cirq.Qid, ...]]] = {}
        permutation_ids: Set[str] = set()
        for ts in proto.valid_targets:
            allowed_targets[ts.name] = cls._create_target_set(ts)
            if ts.target_ordering == v2.device_pb2.TargetSet.SUBSET_PERMUTATION:
                permutation_ids.add(ts.name)

        # Store gate definitions from proto
        gate_definitions: Dict[str, _GateDefinition] = {}
        for gs in proto.valid_gate_sets:
            for gate_def in gs.valid_gates:
                # Combine all valid targets in the gate's listed target sets
                gate_target_set = {
                    target
                    for ts_name in gate_def.valid_targets
                    for target in allowed_targets[ts_name]
                }
                which_are_permutations = [t in permutation_ids for t in gate_def.valid_targets]
                is_permutation = any(which_are_permutations)
                if is_permutation and not all(which_are_permutations):
                    raise NotImplementedError(
                        f'Id {gate_def.id} in {gs.name} mixes '
                        'SUBSET_PERMUTATION with other types which is not '
                        'currently allowed.'
                    )
                gate_definitions[gate_def.id] = _GateDefinition(
                    duration=cirq.Duration(picos=gate_def.gate_duration_picos),
                    target_set=gate_target_set,
                    is_permutation=is_permutation,
                    number_of_qubits=gate_def.number_of_qubits,
                )

        # Loop through serializers and map gate_definitions to type
        gates_by_type: Dict[Type[cirq.Gate], List[_GateDefinition]] = {}
        for gate_set in gate_sets:
            for internal_type in gate_set.supported_internal_types():
                for serializer in gate_set.serializers[internal_type]:
                    serialized_id = serializer.serialized_id
                    if serialized_id not in gate_definitions:
                        raise ValueError(
                            f'Serializer has {serialized_id} which is not supported '
                            'by the device specification'
                        )
                    # Each serializer gets its own copy of the definition,
                    # restricted by that serializer's predicate.
                    gate_def = gate_definitions[serialized_id].with_can_serialize_predicate(
                        serializer.can_serialize_predicate
                    )
                    gates_by_type.setdefault(internal_type, []).append(gate_def)

        return SerializableDevice(
            qubits=[_qid_from_str(q) for q in proto.valid_qubits],
            gate_definitions=gates_by_type,
        )

    @classmethod
    def _create_target_set(cls, ts: v2.device_pb2.TargetSet) -> Set[Tuple[cirq.Qid, ...]]:
        """Transform a TargetSet proto into a set of qubit tuples.

        For SYMMETRIC target sets the reversed tuple of every target is
        added as well.
        """
        target_set: Set[Tuple[cirq.Qid, ...]] = set()
        is_symmetric = ts.target_ordering == v2.device_pb2.TargetSet.SYMMETRIC
        for target in ts.targets:
            qid_tuple = tuple(_qid_from_str(q) for q in target.ids)
            target_set.add(qid_tuple)
            if is_symmetric:
                target_set.add(qid_tuple[::-1])
        return target_set

    def __str__(self) -> str:
        """Renders a text diagram when every qubit is a GridQubit.

        Devices with no qubits, or with any non-grid qubit, fall back to the
        inherited string representation.
        """
        # all() is vacuously true for an empty list, and min() below would
        # raise on it, so an empty device must take the fallback path too.
        if not self.qubits or not all(isinstance(q, cirq.GridQubit) for q in self.qubits):
            return super().__str__()

        diagram = cirq.TextDiagramDrawer()
        qubits = cast(List[cirq.GridQubit], self.qubits)

        # Don't print out extras newlines if the row/col doesn't start at 0
        min_col = min(q.col for q in qubits)
        min_row = min(q.row for q in qubits)

        for q in qubits:
            diagram.write(q.col - min_col, q.row - min_row, str(q))

        # Find pairs that are connected by two-qubit gates.
        Pair = Tuple[cirq.GridQubit, cirq.GridQubit]
        pairs = {
            cast(Pair, pair)
            for gate_defs in self.gate_definitions.values()
            for gate_def in gate_defs
            if gate_def.number_of_qubits == 2
            for pair in gate_def.target_set
            if len(pair) == 2
        }

        # Draw lines between connected pairs. Limit to horizontal/vertical
        # lines since that is all the diagram drawer can handle.
        for q1, q2 in sorted(pairs):
            if q1.row == q2.row or q1.col == q2.col:
                diagram.grid_line(
                    q1.col - min_col, q1.row - min_row, q2.col - min_col, q2.row - min_row
                )

        return diagram.render(horizontal_spacing=3, vertical_spacing=2, use_unicode_characters=True)

    def qid_pairs(self) -> FrozenSet['cirq.SymmetricalQidPair']:
        """Returns the set of qubit edges on the device, defined by the gate
        definitions.

        Only two-qubit gate definitions contribute, and of each target pair
        only the ordering with pair[0] < pair[1] is used.

        Returns:
            A frozenset of the qubit edges on the device.
        """
        return frozenset(
            cirq.SymmetricalQidPair(pair[0], pair[1])
            for gate_defs in self.gate_definitions.values()
            for gate_def in gate_defs
            if gate_def.number_of_qubits == 2
            for pair in gate_def.target_set
            if len(pair) == 2 and pair[0] < pair[1]
        )

    def _repr_pretty_(self, p: Any, cycle: bool) -> None:
        """Creates ASCII diagram for Jupyter, IPython, etc."""
        # There should never be a cycle, but just in case use the default repr.
        p.text(repr(self) if cycle else str(self))

    def _find_operation_type(self, op: cirq.Operation) -> Optional[_GateDefinition]:
        """Finds the gate definition that applies to an operation.

        A gate type matches when the operation's gate is an instance of it,
        or when the type is cirq.FrozenCircuit and the untagged operation is
        a cirq.CircuitOperation.  Among the definitions of a matching type,
        the first whose predicate accepts the operation is returned.

        Returns:
            The matching gate definition, or None if no definition matches.
        """
        for type_key, gate_defs in self.gate_definitions.items():
            is_circuit_op = type_key == cirq.FrozenCircuit and isinstance(
                op.untagged, cirq.CircuitOperation
            )
            if not (is_circuit_op or isinstance(op.gate, type_key)):
                continue
            for gate_def in gate_defs:
                if gate_def.can_serialize_predicate(op):
                    return gate_def
        return None

    def duration_of(self, operation: cirq.Operation) -> cirq.Duration:
        """Returns the duration of the gate definition matching the operation.

        Raises:
            ValueError: If no gate definition matches the operation.
        """
        gate_def = self._find_operation_type(operation)
        if gate_def is None:
            raise ValueError(f'Operation {operation} does not have a known duration')
        return gate_def.duration

    def validate_operation(self, operation: cirq.Operation) -> None:
        """Checks that an operation is supported by this device.

        Args:
            operation: The operation to validate.

        Raises:
            ValueError: If the operation uses a qubit that is not on the
                device, has no matching gate definition, has the wrong number
                of qubits, or acts on qubits outside the gate's target set.
        """
        for q in operation.qubits:
            if q not in self.qubits:
                raise ValueError(f'Qubit not on device: {q!r}')

        gate_def = self._find_operation_type(operation)
        if gate_def is None:
            raise ValueError(f'{operation} is not a supported gate')

        # A qubit count of zero (or less) in the definition skips this check.
        req_num_qubits = gate_def.number_of_qubits
        if req_num_qubits > 0 and len(operation.qubits) != req_num_qubits:
            raise ValueError(
                f'{operation} has {len(operation.qubits)} '
                f'qubits but expected {req_num_qubits}'
            )

        if gate_def.is_permutation:
            # A permutation gate can have any combination of qubits

            if not gate_def.target_set:
                # All qubits are valid
                return

            if not all(q in gate_def.flattened_qubits for q in operation.qubits):
                raise ValueError(f'Operation does not use valid qubits: {operation}.')

            return

        if len(operation.qubits) > 1:
            # TODO: verify args.
            # Github issue: https://github.com/quantumlib/Cirq/issues/2964

            if not gate_def.target_set:
                # All qubit combinations are valid
                return

            qubit_tuple = tuple(operation.qubits)

            if qubit_tuple not in gate_def.target_set:
                # Target is not within the target sets specified by the gate.
                raise ValueError(f'Operation does not use valid qubit target: {operation}.')
329
330
def _qid_from_str(id_str: str) -> cirq.Qid:
    """Translates a qubit id string into a cirq.Qid object.

    The id is first parsed as a GridQubit id (e.g. '4_3'); if that parse
    raises ValueError, the id is parsed as a NamedQubit id instead.
    """
    try:
        grid_qubit = v2.grid_qubit_from_proto_id(id_str)
    except ValueError:
        return v2.named_qubit_from_proto_id(id_str)
    return grid_qubit
341