1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Support for serializing and deserializing cirq_google.api.v2 protos."""
15
16from typing import Any, Dict, List, Optional
17import sympy
18
19import cirq
20from cirq_google.api import v2
21from cirq_google.ops import PhysicalZTag
22from cirq_google.ops.calibration_tag import CalibrationTag
23from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs
24
25
26class CircuitSerializer(serializer.Serializer):
27    """A class for serializing and deserializing programs and operations.
28
    This class is for serializing cirq_google.api.v2 protos using one
30    message type per gate type.  It serializes qubits by adding a field
31    into the constants table.  Usage is by passing a `cirq.Circuit`
32    to the `serialize()` method of the class, which will produce a
33    `Program` proto.  Likewise, the `deserialize` method will produce
34    a `cirq.Circuit` object from a `Program` proto.
35
    This class is more performant than the previous `SerializableGateSet`
    at the cost of some extensibility.
38    """
39
40    def __init__(
41        self,
42        gate_set_name: str,
43    ):
44        """Construct the circuit serializer object.
45
46        Args:
47            gate_set_name: The name used to identify the gate set.
48        """
49        super().__init__(gate_set_name)
50
51    # TODO(#3388) Add documentation for Raises.
52    # pylint: disable=missing-raises-doc
53    def serialize(
54        self,
55        program: cirq.AbstractCircuit,
56        msg: Optional[v2.program_pb2.Program] = None,
57        *,
58        arg_function_language: Optional[str] = None,
59    ) -> v2.program_pb2.Program:
60        """Serialize a Circuit to cirq_google.api.v2.Program proto.
61
62        Args:
63            program: The Circuit to serialize.
64            msg: An optional proto object to populate with the serialization
65                results.
66            arg_function_language: The `arg_function_language` field from
67                `Program.Language`.
68        """
69        if not isinstance(program, cirq.Circuit):
70            raise NotImplementedError(f'Unrecognized program type: {type(program)}')
71        raw_constants: Dict[Any, int] = {}
72        if msg is None:
73            msg = v2.program_pb2.Program()
74        msg.language.gate_set = self.name
75        msg.language.arg_function_language = (
76            arg_function_language or arg_func_langs.MOST_PERMISSIVE_LANGUAGE
77        )
78        self._serialize_circuit(
79            program,
80            msg.circuit,
81            arg_function_language=arg_function_language,
82            constants=msg.constants,
83            raw_constants=raw_constants,
84        )
85        return msg
86
87    # pylint: enable=missing-raises-doc
88    def _serialize_circuit(
89        self,
90        circuit: cirq.AbstractCircuit,
91        msg: v2.program_pb2.Circuit,
92        *,
93        arg_function_language: Optional[str],
94        constants: List[v2.program_pb2.Constant],
95        raw_constants: Dict[Any, int],
96    ) -> None:
97        msg.scheduling_strategy = v2.program_pb2.Circuit.MOMENT_BY_MOMENT
98        for moment in circuit:
99            moment_proto = msg.moments.add()
100            for op in moment:
101                if isinstance(op.untagged, cirq.CircuitOperation):
102                    op_pb = moment_proto.circuit_operations.add()
103                    self._serialize_circuit_op(
104                        op.untagged,
105                        op_pb,
106                        arg_function_language=arg_function_language,
107                        constants=constants,
108                        raw_constants=raw_constants,
109                    )
110                else:
111                    op_pb = moment_proto.operations.add()
112                    self._serialize_gate_op(
113                        op,
114                        op_pb,
115                        arg_function_language=arg_function_language,
116                        constants=constants,
117                        raw_constants=raw_constants,
118                    )
119
120    # TODO(#3388) Add documentation for Raises.
121    # pylint: disable=missing-raises-doc
122    def _serialize_gate_op(
123        self,
124        op: cirq.Operation,
125        msg: v2.program_pb2.Operation,
126        *,
127        constants: List[v2.program_pb2.Constant],
128        raw_constants: Dict[Any, int],
129        arg_function_language: Optional[str] = '',
130    ) -> v2.program_pb2.Operation:
131        """Serialize an Operation to cirq_google.api.v2.Operation proto.
132
133        Args:
134            op: The operation to serialize.
135            msg: An optional proto object to populate with the serialization
136                results.
137            arg_function_language: The `arg_function_language` field from
138                `Program.Language`.
139            constants: The list of previously-serialized Constant protos.
140            raw_constants: A map raw objects to their respective indices in
141                `constants`.
142
143        Returns:
144            The cirq.google.api.v2.Operation proto.
145        """
146        gate = op.gate
147
148        if isinstance(gate, cirq.XPowGate):
149            arg_func_langs.float_arg_to_proto(
150                gate.exponent,
151                out=msg.xpowgate.exponent,
152                arg_function_language=arg_function_language,
153            )
154        elif isinstance(gate, cirq.YPowGate):
155            arg_func_langs.float_arg_to_proto(
156                gate.exponent,
157                out=msg.ypowgate.exponent,
158                arg_function_language=arg_function_language,
159            )
160        elif isinstance(gate, cirq.ZPowGate):
161            arg_func_langs.float_arg_to_proto(
162                gate.exponent,
163                out=msg.zpowgate.exponent,
164                arg_function_language=arg_function_language,
165            )
166            if any(isinstance(tag, PhysicalZTag) for tag in op.tags):
167                msg.zpowgate.is_physical_z = True
168        elif isinstance(gate, cirq.PhasedXPowGate):
169            arg_func_langs.float_arg_to_proto(
170                gate.phase_exponent,
171                out=msg.phasedxpowgate.phase_exponent,
172                arg_function_language=arg_function_language,
173            )
174            arg_func_langs.float_arg_to_proto(
175                gate.exponent,
176                out=msg.phasedxpowgate.exponent,
177                arg_function_language=arg_function_language,
178            )
179        elif isinstance(gate, cirq.PhasedXZGate):
180            arg_func_langs.float_arg_to_proto(
181                gate.x_exponent,
182                out=msg.phasedxzgate.x_exponent,
183                arg_function_language=arg_function_language,
184            )
185            arg_func_langs.float_arg_to_proto(
186                gate.z_exponent,
187                out=msg.phasedxzgate.z_exponent,
188                arg_function_language=arg_function_language,
189            )
190            arg_func_langs.float_arg_to_proto(
191                gate.axis_phase_exponent,
192                out=msg.phasedxzgate.axis_phase_exponent,
193                arg_function_language=arg_function_language,
194            )
195        elif isinstance(gate, cirq.CZPowGate):
196            arg_func_langs.float_arg_to_proto(
197                gate.exponent,
198                out=msg.czpowgate.exponent,
199                arg_function_language=arg_function_language,
200            )
201        elif isinstance(gate, cirq.ISwapPowGate):
202            arg_func_langs.float_arg_to_proto(
203                gate.exponent,
204                out=msg.iswappowgate.exponent,
205                arg_function_language=arg_function_language,
206            )
207        elif isinstance(gate, cirq.FSimGate):
208            arg_func_langs.float_arg_to_proto(
209                gate.theta,
210                out=msg.fsimgate.theta,
211                arg_function_language=arg_function_language,
212            )
213            arg_func_langs.float_arg_to_proto(
214                gate.phi,
215                out=msg.fsimgate.phi,
216                arg_function_language=arg_function_language,
217            )
218        elif isinstance(gate, cirq.MeasurementGate):
219            arg_func_langs.arg_to_proto(
220                gate.key,
221                out=msg.measurementgate.key,
222                arg_function_language=arg_function_language,
223            )
224            arg_func_langs.arg_to_proto(
225                gate.invert_mask,
226                out=msg.measurementgate.invert_mask,
227                arg_function_language=arg_function_language,
228            )
229        elif isinstance(gate, cirq.WaitGate):
230            arg_func_langs.float_arg_to_proto(
231                gate.duration.total_nanos(),
232                out=msg.waitgate.duration_nanos,
233                arg_function_language=arg_function_language,
234            )
235        else:
236            raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')
237
238        for qubit in op.qubits:
239            if qubit not in raw_constants:
240                constants.append(
241                    v2.program_pb2.Constant(
242                        qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
243                    )
244                )
245                raw_constants[qubit] = len(constants) - 1
246            msg.qubit_constant_index.append(raw_constants[qubit])
247
248        for tag in op.tags:
249            if isinstance(tag, CalibrationTag):
250                constant = v2.program_pb2.Constant()
251                constant.string_value = tag.token
252                if tag.token in raw_constants:
253                    msg.token_constant_index = raw_constants[tag.token]
254                else:
255                    # Token not found, add it to the list
256                    msg.token_constant_index = len(constants)
257                    constants.append(constant)
258                    if raw_constants is not None:
259                        raw_constants[tag.token] = msg.token_constant_index
260        return msg
261
262    # TODO(#3388) Add documentation for Raises.
263    def _serialize_circuit_op(
264        self,
265        op: cirq.CircuitOperation,
266        msg: Optional[v2.program_pb2.CircuitOperation] = None,
267        *,
268        arg_function_language: Optional[str] = '',
269        constants: Optional[List[v2.program_pb2.Constant]] = None,
270        raw_constants: Optional[Dict[Any, int]] = None,
271    ) -> v2.program_pb2.CircuitOperation:
272        """Serialize a CircuitOperation to cirq.google.api.v2.CircuitOperation proto.
273
274        Args:
275            op: The circuit operation to serialize.
276            msg: An optional proto object to populate with the serialization
277                results.
278            arg_function_language: The `arg_function_language` field from
279                `Program.Language`.
280            constants: The list of previously-serialized Constant protos.
281            raw_constants: A map raw objects to their respective indices in
282                `constants`.
283
284        Returns:
285            The cirq.google.api.v2.CircuitOperation proto.
286        """
287        circuit = op.circuit
288        if constants is None or raw_constants is None:
289            raise ValueError(
290                'CircuitOp serialization requires a constants list and a corresponding '
291                'map of pre-serialization values to indices (raw_constants).'
292            )
293        serializer = op_serializer.CircuitOpSerializer()
294        if circuit not in raw_constants:
295            subcircuit_msg = v2.program_pb2.Circuit()
296            self._serialize_circuit(
297                circuit,
298                subcircuit_msg,
299                arg_function_language=arg_function_language,
300                constants=constants,
301                raw_constants=raw_constants,
302            )
303            constants.append(v2.program_pb2.Constant(circuit_value=subcircuit_msg))
304            raw_constants[circuit] = len(constants) - 1
305        return serializer.to_proto(
306            op,
307            msg,
308            arg_function_language=arg_function_language,
309            constants=constants,
310            raw_constants=raw_constants,
311        )
312
313    # TODO(#3388) Add documentation for Raises.
314    def deserialize(
315        self, proto: v2.program_pb2.Program, device: Optional[cirq.Device] = None
316    ) -> cirq.Circuit:
317        """Deserialize a Circuit from a cirq_google.api.v2.Program.
318
319        Args:
320            proto: A dictionary representing a cirq_google.api.v2.Program proto.
321            device: If the proto is for a schedule, a device is required
322                Otherwise optional.
323
324        Returns:
325            The deserialized Circuit, with a device if device was
326            not None.
327        """
328        if not proto.HasField('language') or not proto.language.gate_set:
329            raise ValueError('Missing gate set specification.')
330        if proto.language.gate_set != self.name:
331            raise ValueError(
332                'Gate set in proto was {} but expected {}'.format(
333                    proto.language.gate_set, self.name
334                )
335            )
336        which = proto.WhichOneof('program')
337        arg_func_language = (
338            proto.language.arg_function_language or arg_func_langs.MOST_PERMISSIVE_LANGUAGE
339        )
340
341        if which == 'circuit':
342            deserialized_constants: List[Any] = []
343            for constant in proto.constants:
344                which_const = constant.WhichOneof('const_value')
345                if which_const == 'string_value':
346                    deserialized_constants.append(constant.string_value)
347                elif which_const == 'circuit_value':
348                    circuit = self._deserialize_circuit(
349                        constant.circuit_value,
350                        arg_function_language=arg_func_language,
351                        constants=proto.constants,
352                        deserialized_constants=deserialized_constants,
353                    )
354                    deserialized_constants.append(circuit.freeze())
355                elif which_const == 'qubit':
356                    deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
357            circuit = self._deserialize_circuit(
358                proto.circuit,
359                arg_function_language=arg_func_language,
360                constants=proto.constants,
361                deserialized_constants=deserialized_constants,
362            )
363            return circuit if device is None else circuit.with_device(device)
364        if which == 'schedule':
365            raise ValueError('Deserializing a schedule is no longer supported.')
366
367        raise NotImplementedError('Program proto does not contain a circuit.')
368
369    # pylint: enable=missing-raises-doc
370    def _deserialize_circuit(
371        self,
372        circuit_proto: v2.program_pb2.Circuit,
373        *,
374        arg_function_language: str,
375        constants: List[v2.program_pb2.Constant],
376        deserialized_constants: List[Any],
377    ) -> cirq.Circuit:
378        moments = []
379        for moment_proto in circuit_proto.moments:
380            moment_ops = []
381            for op in moment_proto.operations:
382                moment_ops.append(
383                    self._deserialize_gate_op(
384                        op,
385                        arg_function_language=arg_function_language,
386                        constants=constants,
387                        deserialized_constants=deserialized_constants,
388                    )
389                )
390            for op in moment_proto.circuit_operations:
391                moment_ops.append(
392                    self._deserialize_circuit_op(
393                        op,
394                        arg_function_language=arg_function_language,
395                        constants=constants,
396                        deserialized_constants=deserialized_constants,
397                    )
398                )
399            moments.append(cirq.Moment(moment_ops))
400        return cirq.Circuit(moments)
401
402    # TODO(#3388) Add documentation for Raises.
403    # pylint: disable=missing-raises-doc
404    def _deserialize_gate_op(
405        self,
406        operation_proto: v2.program_pb2.Operation,
407        *,
408        arg_function_language: str = '',
409        constants: Optional[List[v2.program_pb2.Constant]] = None,
410        deserialized_constants: Optional[List[Any]] = None,
411    ) -> cirq.Operation:
412        """Deserialize an Operation from a cirq_google.api.v2.Operation.
413
414        Args:
415            operation_proto: A dictionary representing a
416                cirq.google.api.v2.Operation proto.
417            arg_function_language: The `arg_function_language` field from
418                `Program.Language`.
419            constants: The list of Constant protos referenced by constant
420                table indices in `proto`.
421            deserialized_constants: The deserialized contents of `constants`.
422                cirq_google.api.v2.Operation proto.
423
424        Returns:
425            The deserialized Operation.
426        """
427        if deserialized_constants is not None:
428            qubits = [deserialized_constants[q] for q in operation_proto.qubit_constant_index]
429        else:
430            qubits = []
431        for q in operation_proto.qubits:
432            # Preserve previous functionality in case
433            # constants table was not used
434            qubits.append(v2.qubit_from_proto_id(q.id))
435
436        which_gate_type = operation_proto.WhichOneof('gate_value')
437
438        if which_gate_type == 'xpowgate':
439            op = cirq.XPowGate(
440                exponent=arg_func_langs.float_arg_from_proto(
441                    operation_proto.xpowgate.exponent,
442                    arg_function_language=arg_function_language,
443                    required_arg_name=None,
444                )
445            )(*qubits)
446        elif which_gate_type == 'ypowgate':
447            op = cirq.YPowGate(
448                exponent=arg_func_langs.float_arg_from_proto(
449                    operation_proto.ypowgate.exponent,
450                    arg_function_language=arg_function_language,
451                    required_arg_name=None,
452                )
453            )(*qubits)
454        elif which_gate_type == 'zpowgate':
455            op = cirq.ZPowGate(
456                exponent=arg_func_langs.float_arg_from_proto(
457                    operation_proto.zpowgate.exponent,
458                    arg_function_language=arg_function_language,
459                    required_arg_name=None,
460                )
461            )(*qubits)
462            if operation_proto.zpowgate.is_physical_z:
463                op = op.with_tags(PhysicalZTag())
464        elif which_gate_type == 'phasedxpowgate':
465            exponent = arg_func_langs.float_arg_from_proto(
466                operation_proto.phasedxpowgate.exponent,
467                arg_function_language=arg_function_language,
468                required_arg_name=None,
469            )
470            phase_exponent = arg_func_langs.float_arg_from_proto(
471                operation_proto.phasedxpowgate.phase_exponent,
472                arg_function_language=arg_function_language,
473                required_arg_name=None,
474            )
475            op = cirq.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)(*qubits)
476        elif which_gate_type == 'phasedxzgate':
477            x_exponent = arg_func_langs.float_arg_from_proto(
478                operation_proto.phasedxzgate.x_exponent,
479                arg_function_language=arg_function_language,
480                required_arg_name=None,
481            )
482            z_exponent = arg_func_langs.float_arg_from_proto(
483                operation_proto.phasedxzgate.z_exponent,
484                arg_function_language=arg_function_language,
485                required_arg_name=None,
486            )
487            axis_phase_exponent = arg_func_langs.float_arg_from_proto(
488                operation_proto.phasedxzgate.axis_phase_exponent,
489                arg_function_language=arg_function_language,
490                required_arg_name=None,
491            )
492            op = cirq.PhasedXZGate(
493                x_exponent=x_exponent,
494                z_exponent=z_exponent,
495                axis_phase_exponent=axis_phase_exponent,
496            )(*qubits)
497        elif which_gate_type == 'czpowgate':
498            op = cirq.CZPowGate(
499                exponent=arg_func_langs.float_arg_from_proto(
500                    operation_proto.czpowgate.exponent,
501                    arg_function_language=arg_function_language,
502                    required_arg_name=None,
503                )
504            )(*qubits)
505        elif which_gate_type == 'iswappowgate':
506            op = cirq.ISwapPowGate(
507                exponent=arg_func_langs.float_arg_from_proto(
508                    operation_proto.iswappowgate.exponent,
509                    arg_function_language=arg_function_language,
510                    required_arg_name=None,
511                )
512            )(*qubits)
513        elif which_gate_type == 'fsimgate':
514            theta = arg_func_langs.float_arg_from_proto(
515                operation_proto.fsimgate.theta,
516                arg_function_language=arg_function_language,
517                required_arg_name=None,
518            )
519            phi = arg_func_langs.float_arg_from_proto(
520                operation_proto.fsimgate.phi,
521                arg_function_language=arg_function_language,
522                required_arg_name=None,
523            )
524            if isinstance(theta, (float, sympy.Basic)) and isinstance(phi, (float, sympy.Basic)):
525                op = cirq.FSimGate(theta=theta, phi=phi)(*qubits)
526            else:
527                raise ValueError('theta and phi must be specified for FSimGate')
528        elif which_gate_type == 'measurementgate':
529            key = arg_func_langs.arg_from_proto(
530                operation_proto.measurementgate.key,
531                arg_function_language=arg_function_language,
532                required_arg_name=None,
533            )
534            invert_mask = arg_func_langs.arg_from_proto(
535                operation_proto.measurementgate.invert_mask,
536                arg_function_language=arg_function_language,
537                required_arg_name=None,
538            )
539            if isinstance(invert_mask, list) and isinstance(key, str):
540                op = cirq.MeasurementGate(
541                    num_qubits=len(qubits), key=key, invert_mask=tuple(invert_mask)
542                )(*qubits)
543            else:
544                raise ValueError(f'Incorrect types for measurement gate {invert_mask} {key}')
545
546        elif which_gate_type == 'waitgate':
547            total_nanos = arg_func_langs.float_arg_from_proto(
548                operation_proto.waitgate.duration_nanos,
549                arg_function_language=arg_function_language,
550                required_arg_name=None,
551            )
552            op = cirq.WaitGate(duration=cirq.Duration(nanos=total_nanos))(*qubits)
553        else:
554            raise ValueError(
555                f'Unsupported serialized gate with type "{which_gate_type}".'
556                f'\n\noperation_proto:\n{operation_proto}'
557            )
558
559        which = operation_proto.WhichOneof('token')
560        if which == 'token_constant_index':
561            if not constants:
562                raise ValueError(
563                    'Proto has references to constants table '
564                    'but none was passed in, value ='
565                    f'{operation_proto}'
566                )
567            op = op.with_tags(
568                CalibrationTag(constants[operation_proto.token_constant_index].string_value)
569            )
570        elif which == 'token_value':
571            op = op.with_tags(CalibrationTag(operation_proto.token_value))
572
573        return op
574
575    # pylint: enable=missing-raises-doc
576    def _deserialize_circuit_op(
577        self,
578        operation_proto: v2.program_pb2.CircuitOperation,
579        *,
580        arg_function_language: str = '',
581        constants: Optional[List[v2.program_pb2.Constant]] = None,
582        deserialized_constants: Optional[List[Any]] = None,
583    ) -> cirq.CircuitOperation:
584        """Deserialize a CircuitOperation from a
585            cirq.google.api.v2.CircuitOperation.
586
587        Args:
588            operation_proto: A dictionary representing a
589                cirq.google.api.v2.CircuitOperation proto.
590            arg_function_language: The `arg_function_language` field from
591                `Program.Language`.
592            constants: The list of Constant protos referenced by constant
593                table indices in `proto`.
594            deserialized_constants: The deserialized contents of `constants`.
595
596        Returns:
597            The deserialized CircuitOperation.
598        """
599        return op_deserializer.CircuitOpDeserializer().from_proto(
600            operation_proto,
601            arg_function_language=arg_function_language,
602            constants=constants,
603            deserialized_constants=deserialized_constants,
604        )
605
606
# Module-level CircuitSerializer instance constructed with the gate set
# name 'v2_5'.
CIRCUIT_SERIALIZER = CircuitSerializer('v2_5')
608