1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Dict, List
16
17import copy
18import numpy as np
19import pytest
20import sympy
21
22from google.protobuf import json_format
23
24import cirq
25import cirq_google as cg
26from cirq_google.api import v2
27
28
# Plain-string tag attached to one operation in default_circuit(); the
# CircuitOp tests also register it as constant index 0 in their constants table.
DEFAULT_TOKEN = 'test_tag'
30
31
def op_proto(json: Dict) -> v2.program_pb2.Operation:
    """Build an ``Operation`` proto from its protobuf-JSON dict form.

    Args:
        json: Dict in protobuf JSON layout, parsed with ``json_format.ParseDict``.

    Returns:
        A freshly populated ``v2.program_pb2.Operation`` message.
    """
    return json_format.ParseDict(json, v2.program_pb2.Operation())
36
37
class GateWithAttribute(cirq.SingleQubitGate):
    """Test gate exposing its value as a plain public instance attribute."""

    def __init__(self, val):
        # Public attribute; the tests reference it by name via ``op_getter='val'``.
        self.val = val
41
42
class GateWithProperty(cirq.SingleQubitGate):
    """Test gate exposing its values only through read-only properties.

    Args:
        val: Value returned by the ``val`` property.
        not_req: Optional value returned by the ``not_req`` property
            (defaults to None).
    """

    def __init__(self, val, not_req=None):
        self._val = val
        self._not_req = not_req

    @property
    def val(self):
        return self._val

    @property
    def not_req(self):
        # Exposed so that ``op_getter='not_req'`` resolves to a real property
        # (possibly returning None) rather than to a missing attribute.
        return self._not_req
51
52
class GateWithMethod(cirq.SingleQubitGate):
    """Test gate whose value is reachable only through the ``get_val`` method."""

    def __init__(self, val):
        self._val = val

    def get_val(self):
        """Return the value supplied to the constructor."""
        return self._val
59
60
class SubclassGate(GateWithAttribute):
    """Trivial subclass used to check that serializers accept gate subclasses."""
64
65
def get_val(op):
    """Callable ``op_getter``: read the value via the gate's ``get_val`` method."""
    gate = op.gate
    return gate.get_val()
68
69
# Parametrization shared by the test_to_proto_* tests below. Each entry is
# (serialized_type, value read from the gate, expected arg proto as a dict).
TEST_CASES = (
    (float, 1.0, {'arg_value': {'float_value': 1.0}}),
    (str, 'abc', {'arg_value': {'string_value': 'abc'}}),
    # An int value declared as float is expected to come out as float_value.
    (float, 1, {'arg_value': {'float_value': 1.0}}),
    (List[bool], [True, False], {'arg_value': {'bool_values': {'values': [True, False]}}}),
    # Tuples and numpy bool arrays are expected to serialize the same as lists.
    (List[bool], (True, False), {'arg_value': {'bool_values': {'values': [True, False]}}}),
    (
        List[bool],
        np.array([True, False], dtype=bool),
        {'arg_value': {'bool_values': {'values': [True, False]}}},
    ),
    # A bare symbol serializes as 'symbol', whether declared as Symbol or float.
    (sympy.Symbol, sympy.Symbol('x'), {'symbol': 'x'}),
    (float, sympy.Symbol('x'), {'symbol': 'x'}),
    # x - y is expected as the linear expression add(x, mul(-1.0, y)).
    (
        float,
        sympy.Symbol('x') - sympy.Symbol('y'),
        {
            'func': {
                'type': 'add',
                'args': [
                    {'symbol': 'x'},
                    {
                        'func': {
                            'type': 'mul',
                            'args': [{'arg_value': {'float_value': -1.0}}, {'symbol': 'y'}],
                        }
                    },
                ],
            }
        },
    ),
)
102
103
@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_attribute(val_type, val, arg_value):
    """A value stored as a plain gate attribute serializes to the expected arg."""
    val_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=val_type, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute, serialized_gate_id='my_gate', args=[val_arg]
    )
    operation = GateWithAttribute(val)(cirq.GridQubit(1, 2))

    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert serializer.to_proto(operation, arg_function_language='linear') == expected
119
120
@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_property(val_type, val, arg_value):
    """A value exposed through a gate property serializes to the expected arg."""
    property_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=val_type, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[property_arg]
    )
    operation = GateWithProperty(val)(cirq.GridQubit(1, 2))

    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert serializer.to_proto(operation, arg_function_language='linear') == expected
136
137
@pytest.mark.parametrize(('val_type', 'val', 'arg_value'), TEST_CASES)
def test_to_proto_callable(val_type, val, arg_value):
    """A value fetched by a callable op_getter serializes to the expected arg."""
    getter_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=val_type, op_getter=get_val
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithMethod, serialized_gate_id='my_gate', args=[getter_arg]
    )
    operation = GateWithMethod(val)(cirq.GridQubit(1, 2))

    expected = op_proto(
        {'gate': {'id': 'my_gate'}, 'args': {'my_val': arg_value}, 'qubits': [{'id': '1_2'}]}
    )
    assert serializer.to_proto(operation, arg_function_language='linear') == expected
153
154
def test_to_proto_gate_predicate():
    """Operations rejected by ``can_serialize_predicate`` produce no proto."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
        can_serialize_predicate=lambda op: op.gate.val == 1,
    )
    qubit = cirq.GridQubit(1, 2)
    rejected = GateWithAttribute(0)(qubit)
    accepted = GateWithAttribute(1)(qubit)

    assert serializer.to_proto(rejected) is None
    assert serializer.to_proto(accepted) is not None
    assert not serializer.can_serialize_operation(rejected)
    assert serializer.can_serialize_operation(accepted)
167
168
def test_to_proto_gate_mismatch():
    """Serializing a gate of the wrong type raises a ValueError naming both types."""
    float_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=float, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[float_arg]
    )
    wrong_op = GateWithAttribute(1.0)(cirq.GridQubit(1, 2))
    with pytest.raises(ValueError, match='GateWithAttribute.*GateWithProperty'):
        serializer.to_proto(wrong_op)
178
179
def test_to_proto_unsupported_type():
    """Declaring an unsupported serialized_type (bytes) raises ValueError."""
    bytes_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=bytes, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[bytes_arg]
    )
    bytes_op = GateWithProperty(b's')(cirq.GridQubit(1, 2))
    with pytest.raises(ValueError, match='bytes'):
        serializer.to_proto(bytes_op)
189
190
def test_to_proto_named_qubit_supported():
    """A NamedQubit serializes using its name as the qubit id."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    value = 1.0
    operation = GateWithProperty(value)(cirq.NamedQubit('a'))

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': value}}},
            'qubits': [{'id': 'a'}],
        }
    )
    assert serializer.to_proto(operation) == expected
209
210
def test_to_proto_line_qubit_supported():
    """A LineQubit serializes using its integer position as the qubit id."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    # LineQubit's coordinate is an integer; build it with an int, not the string '10'.
    q = cirq.LineQubit(10)
    arg_value = 1.0
    result = serializer.to_proto(GateWithProperty(arg_value)(q))

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': arg_value}}},
            'qubits': [{'id': '10'}],
        }
    )
    assert result == expected
229
230
def test_to_proto_required_but_not_present():
    """A required arg whose getter yields None raises ValueError."""
    always_none_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=float, op_getter=lambda op: None
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[always_none_arg]
    )
    operation = GateWithProperty(1.0)(cirq.GridQubit(1, 2))
    with pytest.raises(ValueError, match='required'):
        serializer.to_proto(operation)
244
245
def test_to_proto_no_getattr():
    """An op_getter naming a nonexistent gate attribute raises ValueError."""
    missing_attr_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=float, op_getter='nope'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[missing_attr_arg]
    )
    operation = GateWithProperty(1.0)(cirq.GridQubit(1, 2))
    with pytest.raises(ValueError, match='does not have'):
        serializer.to_proto(operation)
255
256
def test_to_proto_not_required_ok():
    """An optional arg with no value is left out of the serialized args."""
    required_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=float, op_getter='val'
    )
    optional_arg = cg.SerializingArg(
        serialized_name='not_req', serialized_type=float, op_getter='not_req', required=False
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty,
        serialized_gate_id='my_gate',
        args=[required_arg, optional_arg],
    )
    operation = GateWithProperty(0.125)(cirq.GridQubit(1, 2))

    # Only 'my_val' appears; 'not_req' is absent from the args map.
    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
        }
    )
    assert serializer.to_proto(operation) == expected
281
282
@pytest.mark.parametrize(
    ('val_type', 'val'),
    (
        (float, 's'),
        (str, 1.0),
        (sympy.Symbol, 1.0),
        (List[bool], [1.0]),
        (List[bool], 'a'),
        (List[bool], (1.0,)),
    ),
)
def test_to_proto_type_mismatch(val_type, val):
    """A value whose Python type disagrees with serialized_type raises ValueError.

    The error message is expected to mention the offending value's type.
    """
    mismatched_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=val_type, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithProperty, serialized_gate_id='my_gate', args=[mismatched_arg]
    )
    operation = GateWithProperty(val)(cirq.GridQubit(1, 2))
    type_repr = str(type(val))
    with pytest.raises(ValueError, match=type_repr):
        serializer.to_proto(operation)
305
306
def test_can_serialize_operation_subclass():
    """Subclasses of gate_type are accepted, and the predicate still applies."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
        can_serialize_predicate=lambda op: op.gate.val == 1,
    )
    qubit = cirq.GridQubit(1, 1)
    accepted, rejected = SubclassGate(1)(qubit), SubclassGate(0)(qubit)
    assert serializer.can_serialize_operation(accepted)
    assert not serializer.can_serialize_operation(rejected)
317
318
def test_defaults_not_serialized():
    """An arg whose value equals its declared default is omitted from the proto."""
    default_arg = cg.SerializingArg(
        serialized_name='my_val', serialized_type=float, default=1.0, op_getter='val'
    )
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute, serialized_gate_id='my_gate', args=[default_arg]
    )
    qubit = cirq.GridQubit(1, 2)

    no_default = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
        }
    )
    # When the value matches the default, no 'args' entry is expected at all.
    with_default = op_proto({'gate': {'id': 'my_gate'}, 'qubits': [{'id': '1_2'}]})

    assert no_default == serializer.to_proto(GateWithAttribute(0.125)(qubit))
    assert with_default == serializer.to_proto(GateWithAttribute(1.0)(qubit))
340
341
def test_token_serialization():
    """Without a constants list, a CalibrationTag is serialized inline as token_value."""
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    tagged_op = GateWithAttribute(0.125)(cirq.GridQubit(1, 2)).with_tags(
        cg.CalibrationTag('my_token')
    )

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
            'token_value': 'my_token',
        }
    )
    assert expected == serializer.to_proto(tagged_op)
359
360
# Constants tables for test_token_serialization_with_constant_reference:
# 'my_token' alone at index 0, or preceded by another token so it sits at index 1.
ONE_CONSTANT = [v2.program_pb2.Constant(string_value='my_token')]
TWO_CONSTANTS = [
    v2.program_pb2.Constant(string_value='other_token'),
    v2.program_pb2.Constant(string_value='my_token'),
]
366
367
@pytest.mark.parametrize(
    ('constants', 'expected_index', 'expected_constants'),
    (
        ([], 0, ONE_CONSTANT),
        (ONE_CONSTANT, 0, ONE_CONSTANT),
        (TWO_CONSTANTS, 1, TWO_CONSTANTS),
    ),
)
def test_token_serialization_with_constant_reference(constants, expected_index, expected_constants):
    """With a constants list, a CalibrationTag is emitted as an index into it.

    ``to_proto`` modifies the list in place (appending the token when it is
    missing), so the test works on a shallow copy and never touches the shared
    module-level lists.
    """
    serializer = cg.GateOpSerializer(
        gate_type=GateWithAttribute,
        serialized_gate_id='my_gate',
        args=[cg.SerializingArg(serialized_name='my_val', serialized_type=float, op_getter='val')],
    )
    local_constants = copy.copy(constants)
    tagged_op = GateWithAttribute(0.125)(cirq.GridQubit(1, 2)).with_tags(
        cg.CalibrationTag('my_token')
    )

    expected = op_proto(
        {
            'gate': {'id': 'my_gate'},
            'args': {'my_val': {'arg_value': {'float_value': 0.125}}},
            'qubits': [{'id': '1_2'}],
            'token_constant_index': expected_index,
        }
    )
    assert expected == serializer.to_proto(tagged_op, constants=local_constants)
    assert local_constants == expected_constants
398
399
def default_circuit_proto():
    """Build the serialized ``Circuit`` proto paired with ``default_circuit()``.

    In the tests visible here this proto is only placed in a constants table
    (as ``circuit_value``) next to ``default_circuit()``; its contents are not
    compared against serializer output.

    NOTE(review): this does not mirror ``default_circuit()`` exactly — there is
    no measurement operation, and the 'k' exponent is stored as a
    ``string_value`` rather than a symbol. Confirm whether that is intentional.
    """
    # X**k on qubit (1, 1).
    op1 = v2.program_pb2.Operation()
    op1.gate.id = 'x_pow'
    op1.args['half_turns'].arg_value.string_value = 'k'
    op1.qubits.add().id = '1_1'

    # Tagged X on qubit (1, 2); index 0 matches the DEFAULT_TOKEN constant the
    # tests place first in their constants tables.
    op2 = v2.program_pb2.Operation()
    op2.gate.id = 'x_pow'
    op2.args['half_turns'].arg_value.float_value = 1.0
    op2.qubits.add().id = '1_2'
    op2.token_constant_index = 0

    return v2.program_pb2.Circuit(
        scheduling_strategy=v2.program_pb2.Circuit.MOMENT_BY_MOMENT,
        moments=[
            v2.program_pb2.Moment(
                operations=[op1, op2],
            ),
        ],
    )
420
421
def default_circuit():
    """Return the frozen fixture circuit used by the CircuitOperation tests.

    It contains a symbol-parameterized X on (1, 1), an X on (1, 2) tagged with
    ``DEFAULT_TOKEN``, and a measurement of (1, 1) under key 'm'.
    """
    q11, q12 = cirq.GridQubit(1, 1), cirq.GridQubit(1, 2)
    return cirq.FrozenCircuit(
        cirq.X(q11) ** sympy.Symbol('k'),
        cirq.X(q12).with_tags(DEFAULT_TOKEN),
        cirq.measure(q11, key='m'),
    )
428
429
def test_circuit_op_serializer_properties():
    """CircuitOpSerializer handles FrozenCircuit under the 'circuit' id."""
    serializer = cg.CircuitOpSerializer()
    assert serializer.serialized_id == 'circuit'
    assert serializer.internal_type == cirq.FrozenCircuit
434
435
def test_can_serialize_circuit_op():
    """CircuitOpSerializer accepts a CircuitOperation and rejects a plain gate op."""
    serializer = cg.CircuitOpSerializer()
    circuit_op = cirq.CircuitOperation(default_circuit())
    plain_op = cirq.X(cirq.GridQubit(1, 1))
    assert serializer.can_serialize_operation(circuit_op)
    assert not serializer.can_serialize_operation(plain_op)
440
441
def test_circuit_op_to_proto_errors():
    """to_proto rejects missing tables, non-CircuitOperation input and unknown circuits."""
    serializer = cg.CircuitOpSerializer()
    to_serialize = cirq.CircuitOperation(default_circuit())

    constants = [
        v2.program_pb2.Constant(string_value=DEFAULT_TOKEN),
        v2.program_pb2.Constant(circuit_value=default_circuit_proto()),
    ]
    raw_constants = {
        DEFAULT_TOKEN: 0,
        default_circuit(): 1,
    }

    # Supplying neither, or only one, of the two tables is an error.
    for partial_kwargs in ({}, {'constants': constants}, {'raw_constants': raw_constants}):
        with pytest.raises(ValueError, match='CircuitOp serialization requires a constants list'):
            serializer.to_proto(to_serialize, **partial_kwargs)

    not_a_circuit_op = v2.program_pb2.Operation()
    with pytest.raises(ValueError, match='Serializer expected CircuitOperation'):
        serializer.to_proto(not_a_circuit_op, constants=constants, raw_constants=raw_constants)

    bad_raw_constants = {cirq.FrozenCircuit(): 0}
    with pytest.raises(ValueError, match='Encountered a circuit not in the constants table'):
        serializer.to_proto(to_serialize, constants=constants, raw_constants=bad_raw_constants)
472
473
@pytest.mark.parametrize('repetitions', [1, 5, ['a', 'b', 'c']])
def test_circuit_op_to_proto(repetitions):
    """A CircuitOperation serializes to a proto that references its circuit constant.

    ``repetitions`` is either a repetition count or a list of repetition ids. In
    the latter case the count is the number of ids.
    """
    serializer = cg.CircuitOpSerializer()
    repetition_ids = None if isinstance(repetitions, int) else repetitions
    if repetition_ids is not None:
        repetitions = len(repetition_ids)

    to_serialize = cirq.CircuitOperation(
        circuit=default_circuit(),
        qubit_map={cirq.GridQubit(1, 1): cirq.GridQubit(1, 2)},
        measurement_key_map={'m': 'results'},
        param_resolver={'k': 1.0},
        repetitions=repetitions,
        repetition_ids=repetition_ids,
    )

    constants = [
        v2.program_pb2.Constant(string_value=DEFAULT_TOKEN),
        v2.program_pb2.Constant(circuit_value=default_circuit_proto()),
    ]
    raw_constants = {
        DEFAULT_TOKEN: 0,
        default_circuit(): 1,
    }

    # The expected proto records either the explicit ids or the plain count,
    # mirroring how the operation was built above.
    repetition_spec = v2.program_pb2.RepetitionSpecification()
    if repetition_ids is not None:
        repetition_spec.repetition_ids.ids.extend(repetition_ids)
    else:
        repetition_spec.repetition_count = repetitions

    qubit_map = v2.program_pb2.QubitMapping()
    qubit_entry = qubit_map.entries.add()
    qubit_entry.key.id = '1_1'
    qubit_entry.value.id = '1_2'

    measurement_key_map = v2.program_pb2.MeasurementKeyMapping()
    key_entry = measurement_key_map.entries.add()
    key_entry.key.string_key = 'm'
    key_entry.value.string_key = 'results'

    arg_map = v2.program_pb2.ArgMapping()
    arg_entry = arg_map.entries.add()
    arg_entry.key.arg_value.string_value = 'k'
    arg_entry.value.arg_value.float_value = 1.0

    expected = v2.program_pb2.CircuitOperation(
        circuit_constant_index=1,
        repetition_specification=repetition_spec,
        qubit_map=qubit_map,
        measurement_key_map=measurement_key_map,
        arg_map=arg_map,
    )
    serialized = serializer.to_proto(
        to_serialize, constants=constants, raw_constants=raw_constants
    )
    assert serialized == expected
531