1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Support for serializing and deserializing cirq_google.api.v2 protos."""
15
16from itertools import chain
17from typing import (
18    Any,
19    Dict,
20    Iterable,
21    List,
22    Optional,
23    Tuple,
24    Type,
25    Union,
26)
27
28import cirq
29from cirq._compat import deprecated, deprecated_parameter
30from cirq_google.api import v2
31from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs
32
33
class SerializableGateSet(serializer.Serializer):
    """A class for serializing and deserializing programs and operations.

    This class is for cirq_google.api.v2 protos.
    """

    def __init__(
        self,
        gate_set_name: str,
        serializers: Iterable[op_serializer.OpSerializer],
        deserializers: Iterable[op_deserializer.OpDeserializer],
    ):
        """Construct the gate set.

        Args:
            gate_set_name: The name used to identify the gate set.
            serializers: The OpSerializers to use for serialization.
                Multiple serializers for a given gate type are allowed and
                will be checked for a given type in the order specified here.
                This allows for a given gate type to be serialized into
                different serialized form depending on the parameters of the
                gate.
            deserializers: The OpDeserializers to convert serialized
                forms of gates or circuits into Operations.
        """
        self.gate_set_name = gate_set_name
        # Group serializers by the internal type they handle, keeping the
        # order in which they were supplied within each group.
        self.serializers: Dict[Type, List[op_serializer.OpSerializer]] = {}
        for s in serializers:
            self.serializers.setdefault(s.internal_type, []).append(s)
        # Deserializers are keyed by serialized id; if two share an id, the
        # later one in `deserializers` replaces the earlier one.
        self.deserializers = {d.serialized_id: d for d in deserializers}

    def with_added_types(
        self,
        *,
        gate_set_name: Optional[str] = None,
        serializers: Iterable[op_serializer.OpSerializer] = (),
        deserializers: Iterable[op_deserializer.OpDeserializer] = (),
    ) -> 'SerializableGateSet':
        """Creates a new gateset with additional (de)serializers.

        Args:
            gate_set_name: Optional new name of the gateset. If not given, use
                the same name as this gateset.
            serializers: Serializers to add to those in this gateset.
            deserializers: Deserializers to add to those in this gateset.

        Returns:
            A new SerializableGateSet holding this gateset's (de)serializers
            followed by the added ones. This gateset is not modified.
        """
        # Flatten the per-type serializer lists back into a single sequence.
        existing_serializers = chain.from_iterable(self.serializers.values())
        return SerializableGateSet(
            gate_set_name or self.gate_set_name,
            serializers=[*existing_serializers, *serializers],
            deserializers=[*self.deserializers.values(), *deserializers],
        )

    @deprecated(deadline='v0.13', fix='Use with_added_types instead.')
    def with_added_gates(
        self,
        *,
        gate_set_name: Optional[str] = None,
        serializers: Iterable[op_serializer.OpSerializer] = (),
        deserializers: Iterable[op_deserializer.OpDeserializer] = (),
    ) -> 'SerializableGateSet':
        """Deprecated alias that forwards to `with_added_types`."""
        return self.with_added_types(
            gate_set_name=gate_set_name,
            serializers=serializers,
            deserializers=deserializers,
        )

    def supported_internal_types(self) -> Tuple:
        """Returns the internal types that have at least one serializer."""
        return tuple(self.serializers.keys())

    @deprecated(deadline='v0.13', fix='Use supported_internal_types instead.')
    def supported_gate_types(self) -> Tuple:
        """Deprecated alias that forwards to `supported_internal_types`."""
        return self.supported_internal_types()

    def is_supported(self, op_tree: cirq.OP_TREE) -> bool:
        """Whether the given object contains only supported operations."""
        return all(self.is_supported_operation(op) for op in cirq.flatten_to_ops(op_tree))

    def is_supported_operation(self, op: cirq.Operation) -> bool:
        """Whether or not the given operation can be serialized by this gate set.

        If the untagged operation has a non-None `circuit` attribute, the
        answer is whether every operation in that circuit is supported.
        Otherwise the serializers registered for the gate's type and each of
        its base classes are asked in turn.
        """
        subcircuit = getattr(op.untagged, 'circuit', None)
        if subcircuit is not None:
            return self.is_supported(subcircuit)
        return any(
            candidate.can_serialize_operation(op)
            for gate_type in type(op.gate).mro()
            for candidate in self.serializers.get(gate_type, [])
        )

    # Still accept the old keyword `use_constants_table_for_tokens` by
    # rewriting it to `use_constants` before the call.
    @deprecated_parameter(
        deadline='v0.13',
        fix='Use use_constants instead.',
        parameter_desc='keyword use_constants_table_for_tokens',
        match=lambda args, kwargs: 'use_constants_table_for_tokens' in kwargs,
        rewrite=lambda args, kwargs: (
            args,
            {
                ('use_constants' if k == 'use_constants_table_for_tokens' else k): v
                for k, v in kwargs.items()
            },
        ),
    )
    def serialize(
        self,
        program: cirq.Circuit,
        msg: Optional[v2.program_pb2.Program] = None,
        *,
        arg_function_language: Optional[str] = None,
        use_constants: bool = True,
    ) -> v2.program_pb2.Program:
        """Serialize a Circuit to cirq_google.api.v2.Program proto.

        Args:
            program: The Circuit to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`. If None, it is inferred from the
                serialized circuit.
            use_constants: Whether to use constants in serialization. This is
                required to be True for serializing CircuitOperations.

        Returns:
            The populated Program proto (`msg` if one was given).

        Raises:
            NotImplementedError: If `program` is not a `cirq.Circuit`.
        """
        if msg is None:
            msg = v2.program_pb2.Program()
        msg.language.gate_set = self.gate_set_name
        if isinstance(program, cirq.Circuit):
            # The constants table and its reverse lookup are only created when
            # constants are enabled; None tells the op serializers not to use
            # them.
            constants: Optional[List[v2.program_pb2.Constant]] = [] if use_constants else None
            raw_constants: Optional[Dict[Any, int]] = {} if use_constants else None
            self._serialize_circuit(
                program,
                msg.circuit,
                arg_function_language=arg_function_language,
                constants=constants,
                raw_constants=raw_constants,
            )
            if constants is not None:
                msg.constants.extend(constants)
            if arg_function_language is None:
                arg_function_language = arg_func_langs._infer_function_language_from_circuit(
                    msg.circuit
                )
        else:
            raise NotImplementedError(f'Unrecognized program type: {type(program)}')
        msg.language.arg_function_language = arg_function_language
        return msg

    def serialize_op(
        self,
        op: cirq.Operation,
        msg: Union[None, v2.program_pb2.Operation, v2.program_pb2.CircuitOperation] = None,
        **kwargs,
    ) -> Union[v2.program_pb2.Operation, v2.program_pb2.CircuitOperation]:
        """Disambiguation for operation serialization.

        When `msg` is None the choice is made from `op`: an operation with a
        gate is serialized as a gate operation, and an operation whose
        untagged form has a `circuit` attribute as a circuit operation. When
        `msg` is given, its proto type decides.

        Args:
            op: The operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            **kwargs: Forwarded to `serialize_gate_op` or
                `serialize_circuit_op`.

        Returns:
            The serialized Operation or CircuitOperation proto.

        Raises:
            ValueError: If neither `op` (when `msg` is None) nor `msg` is of a
                recognized type.
        """
        if msg is None:
            if op.gate is not None:
                return self.serialize_gate_op(op, msg, **kwargs)
            if hasattr(op.untagged, 'circuit'):
                return self.serialize_circuit_op(op, msg, **kwargs)
            raise ValueError(f'Operation is of an unrecognized type: {op!r}')

        if isinstance(msg, v2.program_pb2.Operation):
            return self.serialize_gate_op(op, msg, **kwargs)
        if isinstance(msg, v2.program_pb2.CircuitOperation):
            return self.serialize_circuit_op(op, msg, **kwargs)
        raise ValueError(f'Operation proto is of an unrecognized type: {msg!r}')

    def serialize_gate_op(
        self,
        op: cirq.Operation,
        msg: Optional[v2.program_pb2.Operation] = None,
        *,
        arg_function_language: Optional[str] = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> v2.program_pb2.Operation:
        """Serialize an Operation to cirq_google.api.v2.Operation proto.

        Args:
            op: The operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
            raw_constants: A map of raw objects to their respective indices in
                `constants`.

        Returns:
            The cirq_google.api.v2.Operation proto.

        Raises:
            ValueError: If no registered serializer produces a proto for `op`.
        """
        gate_type = type(op.gate)
        # Check the gate's type and then its super classes in method
        # resolution order, so a serializer registered for a base class also
        # applies to subclasses.
        for gate_type_mro in gate_type.mro():
            # Try each serializer for this type in registration order; one
            # that returns None declines the op and the next is tried.
            for candidate in self.serializers.get(gate_type_mro, ()):
                proto_msg = candidate.to_proto(
                    op,
                    msg,
                    arg_function_language=arg_function_language,
                    constants=constants,
                    raw_constants=raw_constants,
                )
                if proto_msg is not None:
                    return proto_msg
        raise ValueError(f'Cannot serialize op {op!r} of type {gate_type}')

    def serialize_circuit_op(
        self,
        op: cirq.Operation,
        msg: Optional[v2.program_pb2.CircuitOperation] = None,
        *,
        arg_function_language: Optional[str] = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> Union[v2.program_pb2.Operation, v2.program_pb2.CircuitOperation]:
        """Serialize a CircuitOperation to cirq_google.api.v2.CircuitOperation proto.

        Args:
            op: The circuit operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
                New entries are appended to it in place.
            raw_constants: A map of raw objects to their respective indices in
                `constants`. New entries are added to it in place.

        Returns:
            The cirq_google.api.v2.CircuitOperation proto.

        Raises:
            ValueError: If `constants` or `raw_constants` is None, if `op` has
                no circuit, or if no serializer for `cirq.FrozenCircuit` is
                registered or it declines the operation.
        """
        circuit = getattr(op.untagged, 'circuit', None)
        if constants is None or raw_constants is None:
            raise ValueError(
                'CircuitOp serialization requires a constants list and a corresponding '
                'map of pre-serialization values to indices (raw_constants).'
            )
        if circuit is None:
            # Fail with a clear error here rather than with a TypeError from
            # iterating over None while serializing the subcircuit.
            raise ValueError(f'Cannot serialize CircuitOperation {op!r}: it has no circuit.')
        if cirq.FrozenCircuit in self.serializers:
            circuit_serializer = self.serializers[cirq.FrozenCircuit][0]
            if circuit not in raw_constants:
                # First occurrence of this circuit: serialize it once and
                # record its index in the constants table so repeated uses
                # share a single entry.
                subcircuit_msg = v2.program_pb2.Circuit()
                self._serialize_circuit(
                    circuit,
                    subcircuit_msg,
                    arg_function_language=arg_function_language,
                    constants=constants,
                    raw_constants=raw_constants,
                )
                constants.append(v2.program_pb2.Constant(circuit_value=subcircuit_msg))
                raw_constants[circuit] = len(constants) - 1
            proto_msg = circuit_serializer.to_proto(
                op,
                msg,
                arg_function_language=arg_function_language,
                constants=constants,
                raw_constants=raw_constants,
            )
            if proto_msg is not None:
                return proto_msg
        raise ValueError(f'Cannot serialize CircuitOperation {op!r}')

    def deserialize(
        self, proto: v2.program_pb2.Program, device: Optional[cirq.Device] = None
    ) -> cirq.Circuit:
        """Deserialize a Circuit from a cirq_google.api.v2.Program.

        Args:
            proto: A cirq_google.api.v2.Program proto.
            device: If the proto is for a schedule, a device is required.
                Otherwise optional.

        Returns:
            The deserialized Circuit, with a device if device was
            not None.

        Raises:
            ValueError: If the proto names no gate set, names a gate set other
                than this one, or holds a schedule while `device` is None.
            NotImplementedError: If the proto holds neither a circuit nor a
                schedule.
        """
        if not proto.HasField('language') or not proto.language.gate_set:
            raise ValueError('Missing gate set specification.')
        if proto.language.gate_set != self.gate_set_name:
            raise ValueError(
                f'Gate set in proto was {proto.language.gate_set} '
                f'but expected {self.gate_set_name}'
            )
        which = proto.WhichOneof('program')
        if which == 'circuit':
            # Deserialize the constants table in order. Each circuit constant
            # is given the entries deserialized before it.
            deserialized_constants: List[Any] = []
            for constant in proto.constants:
                which_const = constant.WhichOneof('const_value')
                if which_const == 'string_value':
                    deserialized_constants.append(constant.string_value)
                elif which_const == 'circuit_value':
                    circuit = self._deserialize_circuit(
                        constant.circuit_value,
                        arg_function_language=proto.language.arg_function_language,
                        constants=proto.constants,
                        deserialized_constants=deserialized_constants,
                    )
                    deserialized_constants.append(circuit.freeze())
                else:
                    # Unset or unrecognized constant: store a placeholder so
                    # `deserialized_constants` stays index-aligned with
                    # `proto.constants`. Skipping the entry would shift the
                    # index of every later constant.
                    deserialized_constants.append(None)
            circuit = self._deserialize_circuit(
                proto.circuit,
                arg_function_language=proto.language.arg_function_language,
                constants=proto.constants,
                deserialized_constants=deserialized_constants,
            )
            return circuit if device is None else circuit.with_device(device)
        if which == 'schedule':
            if device is None:
                raise ValueError('Deserializing schedule requires a device but None was given.')
            return self._deserialize_schedule(
                proto.schedule, device, arg_function_language=proto.language.arg_function_language
            )

        raise NotImplementedError('Program proto does not contain a circuit.')

    def deserialize_op(
        self,
        operation_proto: Union[
            v2.program_pb2.Operation,
            v2.program_pb2.CircuitOperation,
        ],
        **kwargs,
    ) -> cirq.Operation:
        """Disambiguation for operation deserialization.

        Args:
            operation_proto: The Operation or CircuitOperation proto to
                deserialize; its proto type selects the deserialization path.
            **kwargs: Forwarded to `deserialize_gate_op` or
                `deserialize_circuit_op`.

        Returns:
            The deserialized operation.

        Raises:
            ValueError: If `operation_proto` is neither an Operation nor a
                CircuitOperation proto.
        """
        if isinstance(operation_proto, v2.program_pb2.Operation):
            return self.deserialize_gate_op(operation_proto, **kwargs)

        if isinstance(operation_proto, v2.program_pb2.CircuitOperation):
            return self.deserialize_circuit_op(operation_proto, **kwargs)

        raise ValueError(f'Operation proto has unknown type: {type(operation_proto)}.')

    def deserialize_gate_op(
        self,
        operation_proto: v2.program_pb2.Operation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.Operation:
        """Deserialize an Operation from a cirq_google.api.v2.Operation.

        Args:
            operation_proto: A cirq_google.api.v2.Operation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized Operation.

        Raises:
            ValueError: If the proto has no gate id, or no deserializer is
                registered for that id.
        """
        if not operation_proto.gate.id:
            raise ValueError('Operation proto does not have a gate.')

        gate_id = operation_proto.gate.id
        deserializer = self.deserializers.get(gate_id, None)
        if deserializer is None:
            raise ValueError(
                f'Unsupported serialized gate with id "{gate_id}".'
                f'\n\noperation_proto:\n{operation_proto}'
            )

        return deserializer.from_proto(
            operation_proto,
            arg_function_language=arg_function_language,
            constants=constants,
            deserialized_constants=deserialized_constants,
        )

    def deserialize_circuit_op(
        self,
        operation_proto: v2.program_pb2.CircuitOperation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.CircuitOperation:
        """Deserialize a CircuitOperation from a cirq_google.api.v2.CircuitOperation.

        Args:
            operation_proto: A cirq_google.api.v2.CircuitOperation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized CircuitOperation.

        Raises:
            ValueError: If no deserializer is registered under the id
                "circuit", or the registered one is not a
                CircuitOpDeserializer.
        """
        # Circuit operations are always looked up under the id 'circuit'.
        deserializer = self.deserializers.get('circuit', None)
        if deserializer is None:
            raise ValueError(
                f'Unsupported serialized CircuitOperation.\n\noperation_proto:\n{operation_proto}'
            )

        if not isinstance(deserializer, op_deserializer.CircuitOpDeserializer):
            raise ValueError(
                'Expected CircuitOpDeserializer for id "circuit", '
                f'got {deserializer.serialized_id}.'
            )

        return deserializer.from_proto(
            operation_proto,
            arg_function_language=arg_function_language,
            constants=constants,
            deserialized_constants=deserialized_constants,
        )

    def _serialize_circuit(
        self,
        circuit: cirq.AbstractCircuit,
        msg: v2.program_pb2.Circuit,
        *,
        arg_function_language: Optional[str],
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> None:
        """Populates `msg` with one moment proto per moment of `circuit`.

        Args:
            circuit: The circuit to serialize.
            msg: The Circuit proto to populate in place.
            arg_function_language: Forwarded to `serialize_op`.
            constants: Forwarded to `serialize_op`.
            raw_constants: Forwarded to `serialize_op`.
        """
        msg.scheduling_strategy = v2.program_pb2.Circuit.MOMENT_BY_MOMENT
        for moment in circuit:
            moment_proto = msg.moments.add()
            for op in moment:
                # Circuit operations and gate operations live in separate
                # repeated fields of the moment proto.
                if isinstance(op.untagged, cirq.CircuitOperation):
                    op_pb = moment_proto.circuit_operations.add()
                else:
                    op_pb = moment_proto.operations.add()
                self.serialize_op(
                    op,
                    op_pb,
                    arg_function_language=arg_function_language,
                    constants=constants,
                    raw_constants=raw_constants,
                )

    def _deserialize_circuit(
        self,
        circuit_proto: v2.program_pb2.Circuit,
        *,
        arg_function_language: str,
        constants: List[v2.program_pb2.Constant],
        deserialized_constants: List[Any],
    ) -> cirq.Circuit:
        """Builds a Circuit from the moments of `circuit_proto`.

        Within each moment, gate operations are deserialized before circuit
        operations.

        Args:
            circuit_proto: The Circuit proto to deserialize.
            arg_function_language: Forwarded to `deserialize_op`.
            constants: Forwarded to `deserialize_op`.
            deserialized_constants: Forwarded to `deserialize_op`.

        Returns:
            The deserialized Circuit.

        Raises:
            ValueError: If an operation fails to deserialize; the message
                names the moment index and includes the offending proto.
        """
        moments = []
        for i, moment_proto in enumerate(circuit_proto.moments):
            moment_ops = []
            for op in chain(moment_proto.operations, moment_proto.circuit_operations):
                try:
                    moment_ops.append(
                        self.deserialize_op(
                            op,
                            arg_function_language=arg_function_language,
                            constants=constants,
                            deserialized_constants=deserialized_constants,
                        )
                    )
                except ValueError as ex:
                    # Re-raise with the location of the failure; the original
                    # error is kept as the cause.
                    raise ValueError(
                        f'Failed to deserialize circuit. '
                        f'There was a problem in moment {i} '
                        f'handling an operation with the '
                        f'following proto:\n{op}'
                    ) from ex
            moments.append(cirq.Moment(moment_ops))
        return cirq.Circuit(moments)

    def _deserialize_schedule(
        self,
        schedule_proto: v2.program_pb2.Schedule,
        device: cirq.Device,
        *,
        arg_function_language: str,
    ) -> cirq.Circuit:
        """Builds a Circuit on `device` from the operations of a Schedule proto.

        Args:
            schedule_proto: The Schedule proto to deserialize.
            device: The device passed to the resulting Circuit.
            arg_function_language: Forwarded to `deserialize_op`.

        Returns:
            A Circuit built from the deserialized operations, in proto order.

        Raises:
            ValueError: If a scheduled operation has no `operation` field.
        """
        result = []
        for scheduled_op_proto in schedule_proto.scheduled_operations:
            if not scheduled_op_proto.HasField('operation'):
                raise ValueError(f'Scheduled op missing an operation {scheduled_op_proto}')
            result.append(
                self.deserialize_op(
                    scheduled_op_proto.operation, arg_function_language=arg_function_language
                )
            )
        return cirq.Circuit(result, device=device)
524