# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `cg.GateOpSerializer` and `cg.CircuitOpSerializer`."""

from typing import Dict, List

import copy
import numpy as np
import pytest
import sympy

from google.protobuf import json_format

import cirq
import cirq_google as cg
from cirq_google.api import v2


# Tag attached to one operation of `default_circuit()` and stored at index 0 of the
# constants tables built by the circuit-op tests.
DEFAULT_TOKEN = 'test_tag'


def op_proto(json: Dict) -> v2.program_pb2.Operation:
    """Builds a `v2.program_pb2.Operation` from its dict (JSON-style) form.

    Args:
        json: Dictionary parsed into the message with `json_format.ParseDict`.

    Returns:
        The populated Operation proto.
    """
    op = v2.program_pb2.Operation()
    json_format.ParseDict(json, op)
    return op


class GateWithAttribute(cirq.SingleQubitGate):
    """Test gate exposing its value as a plain instance attribute `val`."""

    def __init__(self, val):
        self.val = val


class GateWithProperty(cirq.SingleQubitGate):
    """Test gate exposing its value through a read-only `val` property.

    `not_req` is stored privately as `_not_req`; the class defines no public
    `not_req` attribute or property.
    """

    def __init__(self, val, not_req=None):
        self._val = val
        self._not_req = not_req

    @property
    def val(self):
        return self._val


class GateWithMethod(cirq.SingleQubitGate):
    """Test gate exposing its value only through the `get_val()` method."""

    def __init__(self, val):
        self._val = val

    def get_val(self):
        return self._val


class SubclassGate(GateWithAttribute):
    """Subclass of `GateWithAttribute`, used to test serializers against subclasses."""

    pass


def get_val(op):
    """Callable `op_getter`: returns `op.gate.get_val()` for the given operation."""
    return op.gate.get_val()


# Each entry is (serialized_type, gate value, expected Arg proto in dict form).
TEST_CASES = (
    (float, 1.0, {'arg_value': {'float_value': 1.0}}),
    (str, 'abc', {'arg_value': {'string_value': 'abc'}}),
    (float, 1, {'arg_value': {'float_value': 1.0}}),
    (List[bool], [True, False], {'arg_value': {'bool_values': {'values': [True, False]}}}),
    (List[bool], (True, False), {'arg_value': {'bool_values': {'values': [True, False]}}}),
    (
        List[bool],
        np.array([True, False], dtype=bool),
        {'arg_value': {'bool_values': {'values': [True, False]}}},
    ),
    (sympy.Symbol, sympy.Symbol('x'), {'symbol': 'x'}),
    (float, sympy.Symbol('x'), {'symbol': 'x'}),
    (
        float,
        sympy.Symbol('x') - sympy.Symbol('y'),
        {
            'func': {
                'type': 'add',
                'args': [
                    {'symbol': 'x'},
                    {
                        'func': {
                            'type': 'mul',
                            'args': [{'arg_value': {'float_value': -1.0}}, {'symbol': 'y'}],
                        }
                    },
                ],
            }
        },
    ),
)


@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_attribute(val_type, val, arg_value):
    """Checks `to_proto` when the arg is read from a plain attribute (`op_getter='val'`)."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(serialized_name='my_val', serialized_type=val_type, op_getter='val')
        ],
    )
    q = cirq.GridQubit(1, 2)
    result = serializer.to_proto(GateWithAttribute(val)(q), arg_function_language='linear')
    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert result == expected


@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_property(val_type, val, arg_value):
    """Checks `to_proto` when the arg is read from a property (`op_getter='val'`)."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(serialized_name='my_val', serialized_type=val_type, op_getter='val')
        ],
    )
    q = cirq.GridQubit(1, 2)
    result = serializer.to_proto(GateWithProperty(val)(q), arg_function_language='linear')
    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert result == expected


@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_callable(val_type, val, arg_value):
    """Checks `to_proto` when the arg is produced by a callable `op_getter`."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithMethod,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(serialized_name='my_val', serialized_type=val_type, op_getter=get_val)
        ],
    )
    q = cirq.GridQubit(1, 2)
    result = serializer.to_proto(GateWithMethod(val)(q), arg_function_language='linear')
    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert result == expected


def test_to_proto_gate_predicate():
    """Checks that `can_serialize_predicate` gates both `to_proto` and
    `can_serialize_operation`."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
        can_serialize_predicate=lambda x: x.gate.val == 1,
    )
    q = cirq.GridQubit(1, 2)
    # The predicate only accepts gates whose value equals 1.
    assert serializer.to_proto(GateWithAttribute(0)(q)) is None
    assert serializer.to_proto(GateWithAttribute(1)(q)) is not None
    assert not serializer.can_serialize_operation(GateWithAttribute(0)(q))
    assert serializer.can_serialize_operation(GateWithAttribute(1)(q))


def test_to_proto_gate_mismatch():
    """Checks that serializing a gate of a different type raises a ValueError
    naming both gate types."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    q = cirq.GridQubit(1, 2)
    with pytest.raises(ValueError, match='GateWithAttribute.*GateWithProperty'):
        serializer.to_proto(GateWithAttribute(1.0)(q))


def test_to_proto_unsupported_type():
    """Checks that an unsupported `serialized_type` (bytes) raises a ValueError."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=bytes, op_getter='val')],
    )
    q = cirq.GridQubit(1, 2)
    with pytest.raises(ValueError, match='bytes'):
        serializer.to_proto(GateWithProperty(b's')(q))


def test_to_proto_named_qubit_supported():
    """Checks that an operation on a `cirq.NamedQubit` serializes with the name as id."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    q = cirq.NamedQubit('a')
    arg_value = 1.0
    result = serializer.to_proto(GateWithProperty(arg_value)(q))

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': arg_value}}},
            'qubits': [{'id': 'a'}],
        }
    )
    assert result == expected


def test_to_proto_line_qubit_supported():
    """Checks that an operation on a `cirq.LineQubit` serializes with id '10'."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    # LineQubit takes an integer position; the original passed the string '10'.
    q = cirq.LineQubit(10)
    arg_value = 1.0
    result = serializer.to_proto(GateWithProperty(arg_value)(q))

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': arg_value}}},
            'qubits': [{'id': '10'}],
        }
    )
    assert result == expected


def test_to_proto_required_but_not_present():
    """Checks that a required arg whose getter returns None raises a ValueError."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(
                serialized_name='my_val', serialized_type=float, op_getter=lambda x: None
            )
        ],
    )
    q = cirq.GridQubit(1, 2)
    with pytest.raises(ValueError, match='required'):
        serializer.to_proto(GateWithProperty(1.0)(q))


def test_to_proto_no_getattr():
    """Checks that naming a non-existent attribute as `op_getter` raises a ValueError."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='nope')],
    )
    q = cirq.GridQubit(1, 2)
    with pytest.raises(ValueError, match='does not have'):
        serializer.to_proto(GateWithProperty(1.0)(q))


def test_to_proto_not_required_ok():
    """Checks that a `required=False` arg with no value is omitted from the proto."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val'),
            cg.SerializingArg(
                serialized_name='not_req',
                serialized_type=float,
                op_getter='not_req',
                required=False,
            ),
        ],
    )
    # Only 'my_val' is expected; 'not_req' must not appear in the args map.
    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
        }
    )

    q = cirq.GridQubit(1, 2)
    assert serializer.to_proto(GateWithProperty(0.125)(q)) == expected


@pytest.mark.parametrize(
    ('val_type', 'val'),
    (
        (float, 's'),
        (str, 1.0),
        (sympy.Symbol, 1.0),
        (List[bool], [1.0]),
        (List[bool], 'a'),
        (List[bool], (1.0,)),
    ),
)
def test_to_proto_type_mismatch(val_type, val):
    """Checks that a value not matching `serialized_type` raises a ValueError
    whose message contains the value's type."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(serialized_name='my_val', serialized_type=val_type, op_getter='val')
        ],
    )
    q = cirq.GridQubit(1, 2)
    with pytest.raises(ValueError, match=str(type(val))):
        serializer.to_proto(GateWithProperty(val)(q))


def test_can_serialize_operation_subclass():
    """Checks `can_serialize_operation` on a subclass of the registered gate type."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
        can_serialize_predicate=lambda x: x.gate.val == 1,
    )
    q = cirq.GridQubit(1, 1)
    assert serializer.can_serialize_operation(SubclassGate(1)(q))
    assert not serializer.can_serialize_operation(SubclassGate(0)(q))


def test_defaults_not_serialized():
    """Checks that an arg equal to its declared `default` is left out of the proto."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[
            cg.SerializingArg(
                serialized_name='my_val', serialized_type=float, default=1.0, op_getter='val'
            )
        ],
    )
    q = cirq.GridQubit(1, 2)
    # A non-default value is written out as usual.
    no_default = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
        }
    )
    assert no_default == serializer.to_proto(GateWithAttribute(0.125)(q))
    # The default value (1.0) yields a proto without any args.
    with_default = op_proto({'gate': {'id': 'my_gate'}, 'qubits': [{'id': '1_2'}]})
    assert with_default == serializer.to_proto(GateWithAttribute(1.0)(q))


def test_token_serialization():
    """Checks that a `cg.CalibrationTag` is serialized into `token_value`."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    q = cirq.GridQubit(1, 2)
    tag = cg.CalibrationTag('my_token')
    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
            'token_value': 'my_token',
        }
    )
    assert expected == serializer.to_proto(GateWithAttribute(0.125)(q).with_tags(tag))


# Constants tables used by test_token_serialization_with_constant_reference.
ONE_CONSTANT = [v2.program_pb2.Constant(string_value='my_token')]
TWO_CONSTANTS = [
    v2.program_pb2.Constant(string_value='other_token'),
    v2.program_pb2.Constant(string_value='my_token'),
]


@pytest.mark.parametrize(
    ('constants', 'expected_index', 'expected_constants'),
    (
        ([], 0, ONE_CONSTANT),
        (ONE_CONSTANT, 0, ONE_CONSTANT),
        (TWO_CONSTANTS, 1, TWO_CONSTANTS),
    ),
)
def test_token_serialization_with_constant_reference(constants, expected_index, expected_constants):
    """Checks that, given a constants list, the token is written as
    `token_constant_index` and the list ends up equal to `expected_constants`."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    # Make a local copy since we are modifying the array in-place.
    constants = copy.copy(constants)
    q = cirq.GridQubit(1, 2)
    tag = cg.CalibrationTag('my_token')
    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
            'token_constant_index': expected_index,
        }
    )
    assert expected == serializer.to_proto(
        GateWithAttribute(0.125)(q).with_tags(tag), constants=constants
    )
    assert constants == expected_constants


def default_circuit_proto():
    """Returns the hand-built Circuit proto used as the `circuit_value` constant
    in the circuit-op tests below."""
    op1 = v2.program_pb2.Operation()
    op1.gate.id = 'x_pow'
    op1.args['half_turns'].arg_value.string_value = 'k'
    op1.qubits.add().id = '1_1'

    op2 = v2.program_pb2.Operation()
    op2.gate.id = 'x_pow'
    op2.args['half_turns'].arg_value.float_value = 1.0
    op2.qubits.add().id = '1_2'
    op2.token_constant_index = 0

    return v2.program_pb2.Circuit(
        scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT,
        moments=[
            v2.program_pb2.Moment(
                operations=[op1, op2],
            ),
        ],
    )


def default_circuit():
    """Returns the `cirq.FrozenCircuit` wrapped by the circuit-op tests below."""
    return cirq.FrozenCircuit(
        cirq.X(cirq.GridQubit(1, 1)) ** sympy.Symbol('k'),
        cirq.X(cirq.GridQubit(1, 2)).with_tags(DEFAULT_TOKEN),
        cirq.measure(cirq.GridQubit(1, 1), key='m'),
    )


def test_circuit_op_serializer_properties():
    """Checks the `internal_type` and `serialized_id` of `cg.CircuitOpSerializer`."""
    serializer = cg.CircuitOpSerializer()
    assert serializer.internal_type == cirq.FrozenCircuit
    assert serializer.serialized_id == 'circuit'


def test_can_serialize_circuit_op():
    """Checks that only `cirq.CircuitOperation` instances are reported serializable."""
    serializer = cg.CircuitOpSerializer()
    assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))
    assert not serializer.can_serialize_operation(cirq.X(cirq.GridQubit(1, 1)))


def test_circuit_op_to_proto_errors():
    """Checks the ValueErrors raised by `CircuitOpSerializer.to_proto` on bad input."""
    serializer = cg.CircuitOpSerializer()
    to_serialize = cirq.CircuitOperation(default_circuit())

    constants = [
        v2.program_pb2.Constant(string_value=DEFAULT_TOKEN),
        v2.program_pb2.Constant(circuit_value=default_circuit_proto()),
    ]
    raw_constants = {
        DEFAULT_TOKEN: 0,
        default_circuit(): 1,
    }

    # Both `constants` and `raw_constants` must be supplied.
    with pytest.raises(ValueError, match='CircuitOp serialization requires a constants list'):
        serializer.to_proto(to_serialize)

    with pytest.raises(ValueError, match='CircuitOp serialization requires a constants list'):
        serializer.to_proto(to_serialize, constants=constants)

    with pytest.raises(ValueError, match='CircuitOp serialization requires a constants list'):
        serializer.to_proto(to_serialize, raw_constants=raw_constants)

    # Anything other than a CircuitOperation is rejected.
    with pytest.raises(ValueError, match='Serializer expected CircuitOperation'):
        serializer.to_proto(
            v2.program_pb2.Operation(), constants=constants, raw_constants=raw_constants
        )

    # The wrapped circuit must already be present in `raw_constants`.
    bad_raw_constants = {cirq.FrozenCircuit(): 0}
    with pytest.raises(ValueError, match='Encountered a circuit not in the constants table'):
        serializer.to_proto(to_serialize, constants=constants, raw_constants=bad_raw_constants)


@pytest.mark.parametrize('repetitions', [1, 5, ['a', 'b', 'c']])
def test_circuit_op_to_proto(repetitions):
    """Checks the CircuitOperation proto produced for a fully specified circuit op.

    `repetitions` is either an integer count or a list of repetition ids; for a
    list, the count passed to `cirq.CircuitOperation` is the list's length.
    """
    serializer = cg.CircuitOpSerializer()
    if isinstance(repetitions, int):
        repetition_ids = None
    else:
        repetition_ids = repetitions
        repetitions = len(repetition_ids)
    to_serialize = cirq.CircuitOperation(
        circuit=default_circuit(),
        qubit_map={cirq.GridQubit(1, 1): cirq.GridQubit(1, 2)},
        measurement_key_map={'m': 'results'},
        param_resolver={'k': 1.0},
        repetitions=repetitions,
        repetition_ids=repetition_ids,
    )

    constants = [
        v2.program_pb2.Constant(string_value=DEFAULT_TOKEN),
        v2.program_pb2.Constant(circuit_value=default_circuit_proto()),
    ]
    raw_constants = {
        DEFAULT_TOKEN: 0,
        default_circuit(): 1,
    }

    # The expected spec carries either the count or the explicit ids.
    repetition_spec = v2.program_pb2.RepetitionSpecification()
    if repetition_ids is None:
        repetition_spec.repetition_count = repetitions
    else:
        for rep_id in repetition_ids:
            repetition_spec.repetition_ids.ids.append(rep_id)

    qubit_map = v2.program_pb2.QubitMapping()
    q_p1 = qubit_map.entries.add()
    q_p1.key.id = '1_1'
    q_p1.value.id = '1_2'

    measurement_key_map = v2.program_pb2.MeasurementKeyMapping()
    meas_p1 = measurement_key_map.entries.add()
    meas_p1.key.string_key = 'm'
    meas_p1.value.string_key = 'results'

    arg_map = v2.program_pb2.ArgMapping()
    arg_p1 = arg_map.entries.add()
    arg_p1.key.arg_value.string_value = 'k'
    arg_p1.value.arg_value.float_value = 1.0

    expected = v2.program_pb2.CircuitOperation(
        circuit_constant_index=1,
        repetition_specification=repetition_spec,
        qubit_map=qubit_map,
        measurement_key_map=measurement_key_map,
        arg_map=arg_map,
    )
    actual = serializer.to_proto(to_serialize, constants=constants, raw_constants=raw_constants)
    assert actual == expected