# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for serializing and deserializing cirq_google.api.v2 protos."""

from typing import Any, Dict, List, Optional
import sympy

import cirq
from cirq_google.api import v2
from cirq_google.ops import PhysicalZTag
from cirq_google.ops.calibration_tag import CalibrationTag
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs


class CircuitSerializer(serializer.Serializer):
    """A class for serializing and deserializing programs and operations.

    This class is for serializing cirq_google.api.v2 protos using one
    message type per gate type. It serializes qubits by adding a field
    into the constants table. Usage is by passing a `cirq.Circuit`
    to the `serialize()` method of the class, which will produce a
    `Program` proto. Likewise, the `deserialize` method will produce
    a `cirq.Circuit` object from a `Program` proto.

    This class is more performant than the previous `SerializableGateSet`
    at the cost of some extensibility.
    """

    def __init__(
        self,
        gate_set_name: str,
    ):
        """Construct the circuit serializer object.

        Args:
            gate_set_name: The name used to identify the gate set.
        """
        super().__init__(gate_set_name)

    def serialize(
        self,
        program: cirq.AbstractCircuit,
        msg: Optional[v2.program_pb2.Program] = None,
        *,
        arg_function_language: Optional[str] = None,
    ) -> v2.program_pb2.Program:
        """Serialize a Circuit to cirq_google.api.v2.Program proto.

        Args:
            program: The circuit to serialize. Both mutable (`cirq.Circuit`)
                and frozen (`cirq.FrozenCircuit`) circuits are accepted.
            msg: An optional proto object to populate with the serialization
                results. A new `Program` is created when omitted.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`. When falsy,
                `arg_func_langs.MOST_PERMISSIVE_LANGUAGE` is recorded in
                `msg.language`; the value as given is what is forwarded to
                the per-argument encoders.

        Returns:
            The populated `Program` proto (`msg` itself when it was given).

        Raises:
            NotImplementedError: If `program` is not a circuit.
            ValueError: If an operation in the circuit cannot be serialized.
        """
        if not isinstance(program, cirq.AbstractCircuit):
            raise NotImplementedError(f'Unrecognized program type: {type(program)}')
        # Maps already-serialized objects (qubits, tokens, subcircuits) to their
        # index in `msg.constants` so each one is stored only once.
        raw_constants: Dict[Any, int] = {}
        if msg is None:
            msg = v2.program_pb2.Program()
        msg.language.gate_set = self.name
        msg.language.arg_function_language = (
            arg_function_language or arg_func_langs.MOST_PERMISSIVE_LANGUAGE
        )
        self._serialize_circuit(
            program,
            msg.circuit,
            arg_function_language=arg_function_language,
            constants=msg.constants,
            raw_constants=raw_constants,
        )
        return msg

    def _serialize_circuit(
        self,
        circuit: cirq.AbstractCircuit,
        msg: v2.program_pb2.Circuit,
        *,
        arg_function_language: Optional[str],
        constants: List[v2.program_pb2.Constant],
        raw_constants: Dict[Any, int],
    ) -> None:
        """Serialize every moment of `circuit` into the `msg` Circuit proto.

        Operations whose untagged form is a `cirq.CircuitOperation` go into
        the moment's `circuit_operations` list; every other operation goes
        into its `operations` list.

        Args:
            circuit: The circuit to serialize.
            msg: The Circuit proto to populate (mutated in place).
            arg_function_language: Forwarded to the per-argument encoders.
            constants: The shared constants table (appended to in place).
            raw_constants: Map from raw objects to their index in `constants`
                (updated in place).
        """
        msg.scheduling_strategy = v2.program_pb2.Circuit.MOMENT_BY_MOMENT
        for moment in circuit:
            moment_proto = msg.moments.add()
            for op in moment:
                if isinstance(op.untagged, cirq.CircuitOperation):
                    op_pb = moment_proto.circuit_operations.add()
                    self._serialize_circuit_op(
                        op.untagged,
                        op_pb,
                        arg_function_language=arg_function_language,
                        constants=constants,
                        raw_constants=raw_constants,
                    )
                else:
                    op_pb = moment_proto.operations.add()
                    self._serialize_gate_op(
                        op,
                        op_pb,
                        arg_function_language=arg_function_language,
                        constants=constants,
                        raw_constants=raw_constants,
                    )

    def _serialize_gate_op(
        self,
        op: cirq.Operation,
        msg: v2.program_pb2.Operation,
        *,
        constants: List[v2.program_pb2.Constant],
        raw_constants: Dict[Any, int],
        arg_function_language: Optional[str] = '',
    ) -> v2.program_pb2.Operation:
        """Serialize an Operation to cirq_google.api.v2.Operation proto.

        Args:
            op: The operation to serialize.
            msg: The proto object to populate with the serialization results
                (mutated in place).
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
                Newly seen qubits and calibration tokens are appended.
            raw_constants: A map from raw objects to their respective indices
                in `constants`. Updated for every constant appended.

        Returns:
            The cirq.google.api.v2.Operation proto (`msg`).

        Raises:
            ValueError: If the operation's gate type is not supported.
        """
        gate = op.gate

        if isinstance(gate, cirq.XPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.xpowgate.exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.YPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.ypowgate.exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.ZPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.zpowgate.exponent,
                arg_function_language=arg_function_language,
            )
            if any(isinstance(tag, PhysicalZTag) for tag in op.tags):
                msg.zpowgate.is_physical_z = True
        elif isinstance(gate, cirq.PhasedXPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.phase_exponent,
                out=msg.phasedxpowgate.phase_exponent,
                arg_function_language=arg_function_language,
            )
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.phasedxpowgate.exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.PhasedXZGate):
            arg_func_langs.float_arg_to_proto(
                gate.x_exponent,
                out=msg.phasedxzgate.x_exponent,
                arg_function_language=arg_function_language,
            )
            arg_func_langs.float_arg_to_proto(
                gate.z_exponent,
                out=msg.phasedxzgate.z_exponent,
                arg_function_language=arg_function_language,
            )
            arg_func_langs.float_arg_to_proto(
                gate.axis_phase_exponent,
                out=msg.phasedxzgate.axis_phase_exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.CZPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.czpowgate.exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.ISwapPowGate):
            arg_func_langs.float_arg_to_proto(
                gate.exponent,
                out=msg.iswappowgate.exponent,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.FSimGate):
            arg_func_langs.float_arg_to_proto(
                gate.theta,
                out=msg.fsimgate.theta,
                arg_function_language=arg_function_language,
            )
            arg_func_langs.float_arg_to_proto(
                gate.phi,
                out=msg.fsimgate.phi,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.MeasurementGate):
            arg_func_langs.arg_to_proto(
                gate.key,
                out=msg.measurementgate.key,
                arg_function_language=arg_function_language,
            )
            arg_func_langs.arg_to_proto(
                gate.invert_mask,
                out=msg.measurementgate.invert_mask,
                arg_function_language=arg_function_language,
            )
        elif isinstance(gate, cirq.WaitGate):
            arg_func_langs.float_arg_to_proto(
                gate.duration.total_nanos(),
                out=msg.waitgate.duration_nanos,
                arg_function_language=arg_function_language,
            )
        else:
            raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')

        # Qubits are stored once in the constants table and referenced by index.
        for qubit in op.qubits:
            if qubit not in raw_constants:
                constants.append(
                    v2.program_pb2.Constant(
                        qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
                    )
                )
                raw_constants[qubit] = len(constants) - 1
            msg.qubit_constant_index.append(raw_constants[qubit])

        # Calibration tokens are likewise deduplicated through the constants
        # table; if several CalibrationTags are present, the last one wins.
        for tag in op.tags:
            if isinstance(tag, CalibrationTag):
                if tag.token not in raw_constants:
                    # Token not seen yet: add it to the constants table.
                    constants.append(v2.program_pb2.Constant(string_value=tag.token))
                    raw_constants[tag.token] = len(constants) - 1
                msg.token_constant_index = raw_constants[tag.token]
        return msg

    def _serialize_circuit_op(
        self,
        op: cirq.CircuitOperation,
        msg: Optional[v2.program_pb2.CircuitOperation] = None,
        *,
        arg_function_language: Optional[str] = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> v2.program_pb2.CircuitOperation:
        """Serialize a CircuitOperation to cirq.google.api.v2.CircuitOperation proto.

        The operation's subcircuit is serialized into the constants table
        (once per distinct circuit) before the operation itself is encoded.

        Args:
            op: The circuit operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
            raw_constants: A map from raw objects to their respective indices
                in `constants`.

        Returns:
            The cirq.google.api.v2.CircuitOperation proto.

        Raises:
            ValueError: If `constants` or `raw_constants` is not provided.
        """
        circuit = op.circuit
        if constants is None or raw_constants is None:
            raise ValueError(
                'CircuitOp serialization requires a constants list and a corresponding '
                'map of pre-serialization values to indices (raw_constants).'
            )
        # Named so as not to shadow the imported `serializer` module.
        circuit_op_serializer = op_serializer.CircuitOpSerializer()
        if circuit not in raw_constants:
            subcircuit_msg = v2.program_pb2.Circuit()
            self._serialize_circuit(
                circuit,
                subcircuit_msg,
                arg_function_language=arg_function_language,
                constants=constants,
                raw_constants=raw_constants,
            )
            # Appended after the subcircuit's own qubits/tokens, so those
            # constants always precede the circuit that references them.
            constants.append(v2.program_pb2.Constant(circuit_value=subcircuit_msg))
            raw_constants[circuit] = len(constants) - 1
        return circuit_op_serializer.to_proto(
            op,
            msg,
            arg_function_language=arg_function_language,
            constants=constants,
            raw_constants=raw_constants,
        )

    def deserialize(
        self, proto: v2.program_pb2.Program, device: Optional[cirq.Device] = None
    ) -> cirq.Circuit:
        """Deserialize a Circuit from a cirq_google.api.v2.Program.

        Args:
            proto: A cirq_google.api.v2.Program proto.
            device: An optional device to attach to the deserialized circuit
                via `with_device`.

        Returns:
            The deserialized Circuit, with a device if device was
            not None.

        Raises:
            ValueError: If the gate set is missing or does not match this
                serializer, if the program is a schedule, or if an operation
                cannot be deserialized.
            NotImplementedError: If the program contains neither a circuit
                nor a schedule.
        """
        if not proto.HasField('language') or not proto.language.gate_set:
            raise ValueError('Missing gate set specification.')
        if proto.language.gate_set != self.name:
            raise ValueError(
                f'Gate set in proto was {proto.language.gate_set} but expected {self.name}'
            )
        which = proto.WhichOneof('program')
        arg_func_language = (
            proto.language.arg_function_language or arg_func_langs.MOST_PERMISSIVE_LANGUAGE
        )

        if which == 'circuit':
            # Must stay index-aligned with `proto.constants`, because operations
            # refer to constants by their position in that table.
            deserialized_constants: List[Any] = []
            for constant in proto.constants:
                which_const = constant.WhichOneof('const_value')
                if which_const == 'string_value':
                    deserialized_constants.append(constant.string_value)
                elif which_const == 'circuit_value':
                    subcircuit = self._deserialize_circuit(
                        constant.circuit_value,
                        arg_function_language=arg_func_language,
                        constants=proto.constants,
                        deserialized_constants=deserialized_constants,
                    )
                    deserialized_constants.append(subcircuit.freeze())
                elif which_const == 'qubit':
                    deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
                else:
                    # Unrecognized or unset constant: keep a placeholder so the
                    # indices of all later constants remain correct.
                    deserialized_constants.append(None)
            circuit = self._deserialize_circuit(
                proto.circuit,
                arg_function_language=arg_func_language,
                constants=proto.constants,
                deserialized_constants=deserialized_constants,
            )
            return circuit if device is None else circuit.with_device(device)
        if which == 'schedule':
            raise ValueError('Deserializing a schedule is no longer supported.')

        raise NotImplementedError('Program proto does not contain a circuit.')

    def _deserialize_circuit(
        self,
        circuit_proto: v2.program_pb2.Circuit,
        *,
        arg_function_language: str,
        constants: List[v2.program_pb2.Constant],
        deserialized_constants: List[Any],
    ) -> cirq.Circuit:
        """Deserialize a Circuit proto moment by moment.

        Within each moment, gate operations are emitted before circuit
        operations.

        Args:
            circuit_proto: The Circuit proto to read.
            arg_function_language: The language used to decode arguments.
            constants: The Constant protos referenced by index.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized circuit.
        """
        moments = []
        for moment_proto in circuit_proto.moments:
            moment_ops = []
            for op in moment_proto.operations:
                moment_ops.append(
                    self._deserialize_gate_op(
                        op,
                        arg_function_language=arg_function_language,
                        constants=constants,
                        deserialized_constants=deserialized_constants,
                    )
                )
            for op in moment_proto.circuit_operations:
                moment_ops.append(
                    self._deserialize_circuit_op(
                        op,
                        arg_function_language=arg_function_language,
                        constants=constants,
                        deserialized_constants=deserialized_constants,
                    )
                )
            moments.append(cirq.Moment(moment_ops))
        return cirq.Circuit(moments)

    def _deserialize_gate_op(
        self,
        operation_proto: v2.program_pb2.Operation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.Operation:
        """Deserialize an Operation from a cirq_google.api.v2.Operation.

        Args:
            operation_proto: A cirq.google.api.v2.Operation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.
                When None, qubits are read from the legacy inline
                `operation_proto.qubits` field instead.

        Returns:
            The deserialized Operation.

        Raises:
            ValueError: If the gate type is unsupported, if required FSimGate
                or MeasurementGate arguments are missing or malformed, or if
                the proto references the constants table but none was given.
        """
        if deserialized_constants is not None:
            qubits = [deserialized_constants[q] for q in operation_proto.qubit_constant_index]
        else:
            qubits = []
            for q in operation_proto.qubits:
                # Preserve previous functionality in case
                # constants table was not used
                qubits.append(v2.qubit_from_proto_id(q.id))

        which_gate_type = operation_proto.WhichOneof('gate_value')

        if which_gate_type == 'xpowgate':
            op = cirq.XPowGate(
                exponent=arg_func_langs.float_arg_from_proto(
                    operation_proto.xpowgate.exponent,
                    arg_function_language=arg_function_language,
                    required_arg_name=None,
                )
            )(*qubits)
        elif which_gate_type == 'ypowgate':
            op = cirq.YPowGate(
                exponent=arg_func_langs.float_arg_from_proto(
                    operation_proto.ypowgate.exponent,
                    arg_function_language=arg_function_language,
                    required_arg_name=None,
                )
            )(*qubits)
        elif which_gate_type == 'zpowgate':
            op = cirq.ZPowGate(
                exponent=arg_func_langs.float_arg_from_proto(
                    operation_proto.zpowgate.exponent,
                    arg_function_language=arg_function_language,
                    required_arg_name=None,
                )
            )(*qubits)
            if operation_proto.zpowgate.is_physical_z:
                op = op.with_tags(PhysicalZTag())
        elif which_gate_type == 'phasedxpowgate':
            exponent = arg_func_langs.float_arg_from_proto(
                operation_proto.phasedxpowgate.exponent,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            phase_exponent = arg_func_langs.float_arg_from_proto(
                operation_proto.phasedxpowgate.phase_exponent,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            op = cirq.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)(*qubits)
        elif which_gate_type == 'phasedxzgate':
            x_exponent = arg_func_langs.float_arg_from_proto(
                operation_proto.phasedxzgate.x_exponent,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            z_exponent = arg_func_langs.float_arg_from_proto(
                operation_proto.phasedxzgate.z_exponent,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            axis_phase_exponent = arg_func_langs.float_arg_from_proto(
                operation_proto.phasedxzgate.axis_phase_exponent,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            op = cirq.PhasedXZGate(
                x_exponent=x_exponent,
                z_exponent=z_exponent,
                axis_phase_exponent=axis_phase_exponent,
            )(*qubits)
        elif which_gate_type == 'czpowgate':
            op = cirq.CZPowGate(
                exponent=arg_func_langs.float_arg_from_proto(
                    operation_proto.czpowgate.exponent,
                    arg_function_language=arg_function_language,
                    required_arg_name=None,
                )
            )(*qubits)
        elif which_gate_type == 'iswappowgate':
            op = cirq.ISwapPowGate(
                exponent=arg_func_langs.float_arg_from_proto(
                    operation_proto.iswappowgate.exponent,
                    arg_function_language=arg_function_language,
                    required_arg_name=None,
                )
            )(*qubits)
        elif which_gate_type == 'fsimgate':
            theta = arg_func_langs.float_arg_from_proto(
                operation_proto.fsimgate.theta,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            phi = arg_func_langs.float_arg_from_proto(
                operation_proto.fsimgate.phi,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            # Integral angles (e.g. phi=0) may decode as `int`, which is a
            # perfectly valid angle; only a missing (None) value is an error.
            numeric_or_symbolic = (int, float, sympy.Basic)
            if isinstance(theta, numeric_or_symbolic) and isinstance(phi, numeric_or_symbolic):
                op = cirq.FSimGate(theta=theta, phi=phi)(*qubits)
            else:
                raise ValueError('theta and phi must be specified for FSimGate')
        elif which_gate_type == 'measurementgate':
            key = arg_func_langs.arg_from_proto(
                operation_proto.measurementgate.key,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            invert_mask = arg_func_langs.arg_from_proto(
                operation_proto.measurementgate.invert_mask,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            if isinstance(invert_mask, list) and isinstance(key, str):
                op = cirq.MeasurementGate(
                    num_qubits=len(qubits), key=key, invert_mask=tuple(invert_mask)
                )(*qubits)
            else:
                raise ValueError(f'Incorrect types for measurement gate {invert_mask} {key}')

        elif which_gate_type == 'waitgate':
            total_nanos = arg_func_langs.float_arg_from_proto(
                operation_proto.waitgate.duration_nanos,
                arg_function_language=arg_function_language,
                required_arg_name=None,
            )
            op = cirq.WaitGate(duration=cirq.Duration(nanos=total_nanos))(*qubits)
        else:
            raise ValueError(
                f'Unsupported serialized gate with type "{which_gate_type}".'
                f'\n\noperation_proto:\n{operation_proto}'
            )

        which = operation_proto.WhichOneof('token')
        if which == 'token_constant_index':
            if not constants:
                raise ValueError(
                    'Proto has references to constants table '
                    'but none was passed in, value ='
                    f'{operation_proto}'
                )
            op = op.with_tags(
                CalibrationTag(constants[operation_proto.token_constant_index].string_value)
            )
        elif which == 'token_value':
            op = op.with_tags(CalibrationTag(operation_proto.token_value))

        return op

    def _deserialize_circuit_op(
        self,
        operation_proto: v2.program_pb2.CircuitOperation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.CircuitOperation:
        """Deserialize a CircuitOperation from a
            cirq.google.api.v2.CircuitOperation.

        Delegates to `op_deserializer.CircuitOpDeserializer`.

        Args:
            operation_proto: A cirq.google.api.v2.CircuitOperation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized CircuitOperation.
        """
        return op_deserializer.CircuitOpDeserializer().from_proto(
            operation_proto,
            arg_function_language=arg_function_language,
            constants=constants,
            deserialized_constants=deserialized_constants,
        )


CIRCUIT_SERIALIZER = CircuitSerializer('v2_5')