# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for serializing and deserializing cirq_google.api.v2 protos."""

from itertools import chain
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

import cirq
from cirq._compat import deprecated, deprecated_parameter
from cirq_google.api import v2
from cirq_google.serialization import serializer, op_deserializer, op_serializer, arg_func_langs


class SerializableGateSet(serializer.Serializer):
    """A class for serializing and deserializing programs and operations.

    This class is for cirq_google.api.v2 protos.
    """

    def __init__(
        self,
        gate_set_name: str,
        serializers: Iterable[op_serializer.OpSerializer],
        deserializers: Iterable[op_deserializer.OpDeserializer],
    ):
        """Construct the gate set.

        Args:
            gate_set_name: The name used to identify the gate set.
            serializers: The OpSerializers to use for serialization.
                Multiple serializers for a given gate type are allowed and
                will be checked for a given type in the order specified here.
                This allows for a given gate type to be serialized into
                different serialized form depending on the parameters of the
                gate.
            deserializers: The OpDeserializers to convert serialized
                forms of gates or circuits into Operations. If several
                deserializers share a `serialized_id`, the last one wins.
        """
        self.gate_set_name = gate_set_name
        # Maps each serializer's `internal_type` to its serializers, in the
        # order they were supplied.
        self.serializers: Dict[Type, List[op_serializer.OpSerializer]] = {}
        for s in serializers:
            self.serializers.setdefault(s.internal_type, []).append(s)
        # Maps serialized ids to deserializers; later entries replace earlier
        # ones with the same id.
        self.deserializers = {d.serialized_id: d for d in deserializers}

    def with_added_types(
        self,
        *,
        gate_set_name: Optional[str] = None,
        serializers: Iterable[op_serializer.OpSerializer] = (),
        deserializers: Iterable[op_deserializer.OpDeserializer] = (),
    ) -> 'SerializableGateSet':
        """Creates a new gateset with additional (de)serializers.

        Added serializers are tried after this gateset's existing serializers
        for the same internal type. Added deserializers replace any existing
        deserializer with the same serialized id.

        Args:
            gate_set_name: Optional new name of the gateset. If not given, use
                the same name as this gateset.
            serializers: Serializers to add to those in this gateset.
            deserializers: Deserializers to add to those in this gateset.

        Returns:
            A new SerializableGateSet; this gateset is left unchanged.
        """
        # Flatten the per-type serializer lists, preserving their order.
        curr_serializers = chain.from_iterable(self.serializers.values())
        return SerializableGateSet(
            gate_set_name or self.gate_set_name,
            serializers=[*curr_serializers, *serializers],
            deserializers=[*self.deserializers.values(), *deserializers],
        )

    @deprecated(deadline='v0.13', fix='Use with_added_types instead.')
    def with_added_gates(
        self,
        *,
        gate_set_name: Optional[str] = None,
        serializers: Iterable[op_serializer.OpSerializer] = (),
        deserializers: Iterable[op_deserializer.OpDeserializer] = (),
    ) -> 'SerializableGateSet':
        """Deprecated alias of `with_added_types`."""
        return self.with_added_types(
            gate_set_name=gate_set_name,
            serializers=serializers,
            deserializers=deserializers,
        )

    def supported_internal_types(self) -> Tuple:
        """Returns the internal types that have at least one serializer."""
        return tuple(self.serializers.keys())

    @deprecated(deadline='v0.13', fix='Use supported_internal_types instead.')
    def supported_gate_types(self) -> Tuple:
        """Deprecated alias of `supported_internal_types`."""
        return self.supported_internal_types()

    def is_supported(self, op_tree: cirq.OP_TREE) -> bool:
        """Whether the given object contains only supported operations."""
        return all(self.is_supported_operation(op) for op in cirq.flatten_to_ops(op_tree))

    def is_supported_operation(self, op: cirq.Operation) -> bool:
        """Whether or not the given gate can be serialized by this gate set.

        Operations that carry a `circuit` (e.g. CircuitOperations) are
        supported when every operation in that circuit is supported.
        """
        subcircuit = getattr(op.untagged, 'circuit', None)
        if subcircuit is not None:
            return self.is_supported(subcircuit)
        # Check the gate's type and all of its base classes.
        return any(
            serializer.can_serialize_operation(op)
            for gate_type in type(op.gate).mro()
            for serializer in self.serializers.get(gate_type, [])
        )

    @deprecated_parameter(
        deadline='v0.13',
        fix='Use use_constants instead.',
        parameter_desc='keyword use_constants_table_for_tokens',
        match=lambda args, kwargs: 'use_constants_table_for_tokens' in kwargs,
        rewrite=lambda args, kwargs: (
            args,
            {
                ('use_constants' if k == 'use_constants_table_for_tokens' else k): v
                for k, v in kwargs.items()
            },
        ),
    )
    def serialize(
        self,
        program: cirq.Circuit,
        msg: Optional[v2.program_pb2.Program] = None,
        *,
        arg_function_language: Optional[str] = None,
        use_constants: bool = True,
    ) -> v2.program_pb2.Program:
        """Serialize a Circuit to cirq_google.api.v2.Program proto.

        Args:
            program: The Circuit to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`. If None, it is inferred from the
                serialized circuit.
            use_constants: Whether to use constants in serialization. This is
                required to be True for serializing CircuitOperations.

        Returns:
            The populated Program proto (`msg` if one was given).

        Raises:
            NotImplementedError: If `program` is not a cirq.Circuit.
        """
        if msg is None:
            msg = v2.program_pb2.Program()
        msg.language.gate_set = self.gate_set_name
        if isinstance(program, cirq.Circuit):
            constants: Optional[List[v2.program_pb2.Constant]] = [] if use_constants else None
            raw_constants: Optional[Dict[Any, int]] = {} if use_constants else None
            self._serialize_circuit(
                program,
                msg.circuit,
                arg_function_language=arg_function_language,
                constants=constants,
                raw_constants=raw_constants,
            )
            if constants is not None:
                msg.constants.extend(constants)
            if arg_function_language is None:
                arg_function_language = arg_func_langs._infer_function_language_from_circuit(
                    msg.circuit
                )
        else:
            raise NotImplementedError(f'Unrecognized program type: {type(program)}')
        msg.language.arg_function_language = arg_function_language
        return msg

    def serialize_op(
        self,
        op: cirq.Operation,
        msg: Union[None, v2.program_pb2.Operation, v2.program_pb2.CircuitOperation] = None,
        **kwargs,
    ) -> Union[v2.program_pb2.Operation, v2.program_pb2.CircuitOperation]:
        """Disambiguation for operation serialization.

        When `msg` is given, its proto type selects gate or circuit
        serialization; otherwise the operation itself decides (gate ops
        first, then ops with a `circuit` attribute).
        """
        if msg is None:
            if op.gate is not None:
                return self.serialize_gate_op(op, msg, **kwargs)
            if hasattr(op.untagged, 'circuit'):
                return self.serialize_circuit_op(op, msg, **kwargs)
            raise ValueError(f'Operation is of an unrecognized type: {op!r}')

        if isinstance(msg, v2.program_pb2.Operation):
            return self.serialize_gate_op(op, msg, **kwargs)
        if isinstance(msg, v2.program_pb2.CircuitOperation):
            return self.serialize_circuit_op(op, msg, **kwargs)
        raise ValueError(f'Operation proto is of an unrecognized type: {msg!r}')

    def serialize_gate_op(
        self,
        op: cirq.Operation,
        msg: Optional[v2.program_pb2.Operation] = None,
        *,
        arg_function_language: Optional[str] = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> v2.program_pb2.Operation:
        """Serialize an Operation to cirq_google.api.v2.Operation proto.

        Args:
            op: The operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
            raw_constants: A map from raw objects to their respective indices
                in `constants`.

        Returns:
            The cirq_google.api.v2.Operation proto.

        Raises:
            ValueError: If no serializer for the gate's type (or any of its
                base classes) accepts the operation.
        """
        gate_type = type(op.gate)
        for gate_type_mro in gate_type.mro():
            # Check all super classes in method resolution order.
            if gate_type_mro in self.serializers:
                # Check each serializer in turn, if serializer proto returns
                # None, then skip.
                for serializer in self.serializers[gate_type_mro]:
                    proto_msg = serializer.to_proto(
                        op,
                        msg,
                        arg_function_language=arg_function_language,
                        constants=constants,
                        raw_constants=raw_constants,
                    )
                    if proto_msg is not None:
                        return proto_msg
        raise ValueError(f'Cannot serialize op {op!r} of type {gate_type}')

    def serialize_circuit_op(
        self,
        op: cirq.Operation,
        msg: Optional[v2.program_pb2.CircuitOperation] = None,
        *,
        arg_function_language: Optional[str] = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> Union[v2.program_pb2.Operation, v2.program_pb2.CircuitOperation]:
        """Serialize a CircuitOperation to cirq_google.api.v2.CircuitOperation proto.

        The operation's circuit is serialized once into the constants table
        (deduplicated through `raw_constants`) and referenced by index.

        Args:
            op: The circuit operation to serialize.
            msg: An optional proto object to populate with the serialization
                results.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of previously-serialized Constant protos.
                Appended to when the subcircuit is new.
            raw_constants: A map from raw objects to their respective indices
                in `constants`. Updated when the subcircuit is new.

        Returns:
            The cirq_google.api.v2.CircuitOperation proto.

        Raises:
            ValueError: If `constants`/`raw_constants` are missing, if `op`
                has no circuit, or if no FrozenCircuit serializer accepts it.
        """
        if constants is None or raw_constants is None:
            raise ValueError(
                'CircuitOp serialization requires a constants list and a corresponding '
                'map of pre-serialization values to indices (raw_constants).'
            )
        circuit = getattr(op.untagged, 'circuit', None)
        if circuit is None:
            # Without this check a None circuit would be handed to
            # _serialize_circuit and fail with an opaque TypeError.
            raise ValueError(f'Cannot serialize {op!r} as a CircuitOperation: it has no circuit.')
        if cirq.FrozenCircuit in self.serializers:
            serializer = self.serializers[cirq.FrozenCircuit][0]
            if circuit not in raw_constants:
                # First time this subcircuit is seen: serialize it into the
                # constants table and remember its index.
                subcircuit_msg = v2.program_pb2.Circuit()
                self._serialize_circuit(
                    circuit,
                    subcircuit_msg,
                    arg_function_language=arg_function_language,
                    constants=constants,
                    raw_constants=raw_constants,
                )
                constants.append(v2.program_pb2.Constant(circuit_value=subcircuit_msg))
                raw_constants[circuit] = len(constants) - 1
            proto_msg = serializer.to_proto(
                op,
                msg,
                arg_function_language=arg_function_language,
                constants=constants,
                raw_constants=raw_constants,
            )
            if proto_msg is not None:
                return proto_msg
        raise ValueError(f'Cannot serialize CircuitOperation {op!r}')

    def deserialize(
        self, proto: v2.program_pb2.Program, device: Optional[cirq.Device] = None
    ) -> cirq.Circuit:
        """Deserialize a Circuit from a cirq_google.api.v2.Program.

        Args:
            proto: A cirq_google.api.v2.Program proto.
            device: If the proto is for a schedule, a device is required.
                Otherwise optional.

        Returns:
            The deserialized Circuit, with a device if device was
            not None.

        Raises:
            ValueError: If the gate set is missing or does not match this
                gate set, or a schedule is given without a device.
            NotImplementedError: If the program holds neither a circuit nor
                a schedule.
        """
        if not proto.HasField('language') or not proto.language.gate_set:
            raise ValueError('Missing gate set specification.')
        if proto.language.gate_set != self.gate_set_name:
            raise ValueError(
                'Gate set in proto was {} but expected {}'.format(
                    proto.language.gate_set, self.gate_set_name
                )
            )
        which = proto.WhichOneof('program')
        if which == 'circuit':
            # deserialized_constants[i] must correspond to proto.constants[i],
            # since operations reference constants by index.
            deserialized_constants: List[Any] = []
            for constant in proto.constants:
                which_const = constant.WhichOneof('const_value')
                if which_const == 'string_value':
                    deserialized_constants.append(constant.string_value)
                elif which_const == 'circuit_value':
                    subcircuit = self._deserialize_circuit(
                        constant.circuit_value,
                        arg_function_language=proto.language.arg_function_language,
                        constants=proto.constants,
                        deserialized_constants=deserialized_constants,
                    )
                    deserialized_constants.append(subcircuit.freeze())
                else:
                    # Unrecognized or unset constant: keep a placeholder so the
                    # indices of all later constants stay aligned.
                    deserialized_constants.append(None)
            circuit = self._deserialize_circuit(
                proto.circuit,
                arg_function_language=proto.language.arg_function_language,
                constants=proto.constants,
                deserialized_constants=deserialized_constants,
            )
            return circuit if device is None else circuit.with_device(device)
        if which == 'schedule':
            if device is None:
                raise ValueError('Deserializing schedule requires a device but None was given.')
            return self._deserialize_schedule(
                proto.schedule, device, arg_function_language=proto.language.arg_function_language
            )

        raise NotImplementedError('Program proto does not contain a circuit.')

    def deserialize_op(
        self,
        operation_proto: Union[
            v2.program_pb2.Operation,
            v2.program_pb2.CircuitOperation,
        ],
        **kwargs,
    ) -> cirq.Operation:
        """Disambiguation for operation deserialization.

        Dispatches on the proto type; keyword arguments are forwarded.
        """
        if isinstance(operation_proto, v2.program_pb2.Operation):
            return self.deserialize_gate_op(operation_proto, **kwargs)

        if isinstance(operation_proto, v2.program_pb2.CircuitOperation):
            return self.deserialize_circuit_op(operation_proto, **kwargs)

        raise ValueError(f'Operation proto has unknown type: {type(operation_proto)}.')

    def deserialize_gate_op(
        self,
        operation_proto: v2.program_pb2.Operation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.Operation:
        """Deserialize an Operation from a cirq_google.api.v2.Operation.

        Args:
            operation_proto: A cirq_google.api.v2.Operation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized Operation.

        Raises:
            ValueError: If the proto has no gate id or no deserializer is
                registered for that id.
        """
        if not operation_proto.gate.id:
            raise ValueError('Operation proto does not have a gate.')

        gate_id = operation_proto.gate.id
        deserializer = self.deserializers.get(gate_id, None)
        if deserializer is None:
            raise ValueError(
                f'Unsupported serialized gate with id "{gate_id}".'
                f'\n\noperation_proto:\n{operation_proto}'
            )

        return deserializer.from_proto(
            operation_proto,
            arg_function_language=arg_function_language,
            constants=constants,
            deserialized_constants=deserialized_constants,
        )

    def deserialize_circuit_op(
        self,
        operation_proto: v2.program_pb2.CircuitOperation,
        *,
        arg_function_language: str = '',
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        deserialized_constants: Optional[List[Any]] = None,
    ) -> cirq.CircuitOperation:
        """Deserialize a CircuitOperation from a
        cirq_google.api.v2.CircuitOperation.

        Args:
            operation_proto: A cirq_google.api.v2.CircuitOperation proto.
            arg_function_language: The `arg_function_language` field from
                `Program.Language`.
            constants: The list of Constant protos referenced by constant
                table indices in `proto`.
            deserialized_constants: The deserialized contents of `constants`.

        Returns:
            The deserialized CircuitOperation.

        Raises:
            ValueError: If no deserializer is registered under id "circuit",
                or the one registered is not a CircuitOpDeserializer.
        """
        deserializer = self.deserializers.get('circuit', None)
        if deserializer is None:
            raise ValueError(
                f'Unsupported serialized CircuitOperation.\n\noperation_proto:\n{operation_proto}'
            )

        if not isinstance(deserializer, op_deserializer.CircuitOpDeserializer):
            # Report the offending type; its serialized_id is always
            # "circuit" here because that is the key it was looked up by.
            raise ValueError(
                'Expected CircuitOpDeserializer for id "circuit", '
                f'got {type(deserializer).__name__}.'
            )

        return deserializer.from_proto(
            operation_proto,
            arg_function_language=arg_function_language,
            constants=constants,
            deserialized_constants=deserialized_constants,
        )

    def _serialize_circuit(
        self,
        circuit: cirq.AbstractCircuit,
        msg: v2.program_pb2.Circuit,
        *,
        arg_function_language: Optional[str],
        constants: Optional[List[v2.program_pb2.Constant]] = None,
        raw_constants: Optional[Dict[Any, int]] = None,
    ) -> None:
        """Serializes `circuit` moment by moment into `msg` (in place).

        CircuitOperations go to each moment's `circuit_operations` field;
        all other operations go to its `operations` field.
        """
        msg.scheduling_strategy = v2.program_pb2.Circuit.MOMENT_BY_MOMENT
        for moment in circuit:
            moment_proto = msg.moments.add()
            for op in moment:
                if isinstance(op.untagged, cirq.CircuitOperation):
                    op_pb = moment_proto.circuit_operations.add()
                else:
                    op_pb = moment_proto.operations.add()
                self.serialize_op(
                    op,
                    op_pb,
                    arg_function_language=arg_function_language,
                    constants=constants,
                    raw_constants=raw_constants,
                )

    def _deserialize_circuit(
        self,
        circuit_proto: v2.program_pb2.Circuit,
        *,
        arg_function_language: str,
        constants: List[v2.program_pb2.Constant],
        deserialized_constants: List[Any],
    ) -> cirq.Circuit:
        """Deserializes a Circuit proto into a cirq.Circuit, one moment per
        moment proto.

        Raises:
            ValueError: If any operation fails to deserialize; the original
                error is chained and the failing moment index is reported.
        """
        moments = []
        for i, moment_proto in enumerate(circuit_proto.moments):
            moment_ops = []
            for op in chain(moment_proto.operations, moment_proto.circuit_operations):
                try:
                    moment_ops.append(
                        self.deserialize_op(
                            op,
                            arg_function_language=arg_function_language,
                            constants=constants,
                            deserialized_constants=deserialized_constants,
                        )
                    )
                except ValueError as ex:
                    raise ValueError(
                        f'Failed to deserialize circuit. '
                        f'There was a problem in moment {i} '
                        f'handling an operation with the '
                        f'following proto:\n{op}'
                    ) from ex
            moments.append(cirq.Moment(moment_ops))
        return cirq.Circuit(moments)

    def _deserialize_schedule(
        self,
        schedule_proto: v2.program_pb2.Schedule,
        device: cirq.Device,
        *,
        arg_function_language: str,
    ) -> cirq.Circuit:
        """Deserializes a Schedule proto into a cirq.Circuit on `device`.

        Raises:
            ValueError: If a scheduled op proto has no `operation` field.
        """
        result = []
        for scheduled_op_proto in schedule_proto.scheduled_operations:
            if not scheduled_op_proto.HasField('operation'):
                raise ValueError(f'Scheduled op missing an operation {scheduled_op_proto}')
            result.append(
                self.deserialize_op(
                    scheduled_op_proto.operation, arg_function_language=arg_function_language
                )
            )
        return cirq.Circuit(result, device=device)