1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17import sympy
18
19import cirq
20import cirq_google as cg
21import cirq_google.api.v1.programs as programs
22from cirq_google.api.v1 import operations_pb2
23
24
def assert_proto_dict_convert(gate: cirq.Gate, proto: operations_pb2.Operation, *qubits: cirq.Qid):
    """Checks that `gate` applied to `qubits` serializes to `proto` and back.

    Args:
        gate: The gate under test.
        proto: The v1 operation proto the gate is expected to serialize to.
        *qubits: The qubits the gate is applied to.
    """
    serialized = programs.gate_to_proto(gate, qubits, delay=0)
    assert serialized == proto

    deserialized = programs.xmon_op_from_proto(proto)
    assert deserialized == gate(*qubits)
28
29
def test_protobuf_round_trip():
    """A Foxtail circuit survives conversion to schedule protos and back."""
    device = cg.Foxtail
    origin = cirq.GridQubit(0, 0)
    half_x_layer = [cirq.X(qubit) ** 0.5 for qubit in device.qubits]
    cz_layer = [cirq.CZ(origin, neighbor) for neighbor in device.neighbors_of(origin)]
    circuit = cirq.Circuit(half_x_layer, cz_layer, device=device)

    protos = list(programs.circuit_as_schedule_to_protos(circuit))
    restored = programs.circuit_from_schedule_from_protos(device, protos)
    assert restored == circuit
41
42
def make_bytes(s: str) -> bytes:
    """Helper function to convert a string of digits into packed bytes.

    Ignores any characters other than 0 and 1, in particular whitespace. The
    bits are packed in little-endian order within each byte: the first bit
    read becomes the least significant bit of the first byte. A trailing
    partial byte is zero-filled in its high bits.

    Args:
        s: Text containing the bits to pack, in order.

    Returns:
        The packed bits as an immutable ``bytes`` object (empty when ``s``
        contains no '0'/'1' characters).
    """
    buf = []
    byte = 0
    idx = 0
    for c in s:
        if c not in '01':
            # coverage: ignore
            # Whitespace, separators and annotations carry no bits.
            continue
        if c == '1':
            byte |= 1 << idx
        idx += 1
        if idx == 8:
            buf.append(byte)
            byte = 0
            idx = 0
    if idx:
        # Flush the final, partially filled byte.
        buf.append(byte)
    # Return real `bytes` to match the annotation (was a mutable bytearray).
    return bytes(buf)
68
69
def test_pack_results():
    """Each row of every measurement is concatenated and bit-packed in order."""

    def bit_rows(width, values):
        # Spell each integer as a fixed-width row of 0/1 ints, MSB first.
        return np.array([[int(bit) for bit in format(value, f'0{width}b')] for value in values])

    measurements = [
        ('a', bit_rows(3, range(7))),
        ('b', bit_rows(2, [0, 1, 2, 3, 0, 1, 2])),
    ]
    data = programs.pack_results(measurements)
    expected = make_bytes(
        """
        000 00
        001 01
        010 10
        011 11
        100 00
        101 01
        110 10

        000 00 -- padding
    """
    )
    assert data == expected
116
117
def test_pack_results_no_measurements():
    """Packing an empty list of measurements yields no bytes."""
    packed = programs.pack_results([])
    assert packed == b''
120
121
def test_pack_results_incompatible_shapes():
    """A 1-D array, or arrays whose row counts disagree, raise ValueError."""
    one_dimensional = np.zeros((10,), dtype=bool)
    with pytest.raises(ValueError):
        programs.pack_results([('a', one_dimensional)])

    seven_rows = np.zeros((7, 3), dtype=bool)
    eight_rows = np.zeros((8, 2), dtype=bool)
    with pytest.raises(ValueError):
        programs.pack_results([('a', seven_rows), ('b', eight_rows)])
131
132
def test_unpack_results():
    """Packed bytes are split back into boolean arrays keyed by measurement key."""
    data = make_bytes(
        """
        000 00
        001 01
        010 10
        011 11
        100 00
        101 01
        110 10
    """
    )
    assert len(data) == 5  # 35 data bits + 5 padding bits
    results = programs.unpack_results(data, 7, [('a', 3), ('b', 2)])

    expected_rows = {
        'a': [
            [0, 0, 0],
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
        ],
        'b': [
            [0, 0],
            [0, 1],
            [1, 0],
            [1, 1],
            [0, 0],
            [0, 1],
            [1, 0],
        ],
    }
    for key, rows in expected_rows.items():
        assert key in results
        assert results[key].shape == np.shape(rows)
        assert results[key].dtype == bool
        np.testing.assert_array_equal(results[key], rows)
178
179
def test_single_qubit_measurement_proto_convert():
    """A one-qubit measurement round-trips through the v1 Measurement proto."""
    qubit = cirq.GridQubit(2, 3)
    target = operations_pb2.Qubit(row=2, col=3)
    measurement = operations_pb2.Measurement(targets=[target], key='test')
    expected = operations_pb2.Operation(measurement=measurement)
    assert_proto_dict_convert(cirq.MeasurementGate(1, 'test'), expected, qubit)
188
189
def test_single_qubit_measurement_to_proto_convert_invert_mask():
    """A one-qubit measurement with an invert mask keeps the mask on round trip."""
    qubit = cirq.GridQubit(2, 3)
    measurement = operations_pb2.Measurement(
        targets=[operations_pb2.Qubit(row=2, col=3)],
        key='test',
        invert_mask=[True],
    )
    expected = operations_pb2.Operation(measurement=measurement)
    inverted_gate = cirq.MeasurementGate(1, 'test', invert_mask=(True,))
    assert_proto_dict_convert(inverted_gate, expected, qubit)
198
199
def test_single_qubit_measurement_to_proto_pad_invert_mask():
    """An invert mask shorter than the qubit count is padded with False."""
    qubits = (cirq.GridQubit(2, 3), cirq.GridQubit(2, 4))
    expected = operations_pb2.Operation(
        measurement=operations_pb2.Measurement(
            targets=[operations_pb2.Qubit(row=q.row, col=q.col) for q in qubits],
            key='test',
            invert_mask=[True, False],
        )
    )
    short_mask_gate = cirq.MeasurementGate(2, 'test', invert_mask=(True,))
    serialized = programs.gate_to_proto(short_mask_gate, qubits, delay=0)
    assert serialized == expected
212
213
def test_multi_qubit_measurement_to_proto():
    """A two-qubit measurement round-trips with both targets in order."""
    qubits = (cirq.GridQubit(2, 3), cirq.GridQubit(3, 4))
    targets = [operations_pb2.Qubit(row=q.row, col=q.col) for q in qubits]
    expected = operations_pb2.Operation(
        measurement=operations_pb2.Measurement(targets=targets, key='test')
    )
    assert_proto_dict_convert(cirq.MeasurementGate(2, 'test'), expected, *qubits)
223
224
def test_z_proto_convert():
    """Z powers, symbolic and numeric, round-trip through ExpZ protos."""
    qubit = cirq.GridQubit(2, 3)
    cases = [
        (sympy.Symbol('k'), operations_pb2.ParameterizedFloat(parameter_key='k')),
        (0.5, operations_pb2.ParameterizedFloat(raw=0.5)),
    ]
    for exponent, half_turns in cases:
        expected = operations_pb2.Operation(
            exp_z=operations_pb2.ExpZ(
                target=operations_pb2.Qubit(row=2, col=3),
                half_turns=half_turns,
            )
        )
        assert_proto_dict_convert(cirq.Z ** exponent, expected, qubit)
243
244
def test_cz_proto_convert():
    """CZ powers, symbolic and numeric, round-trip through Exp11 protos."""
    first, second = cirq.GridQubit(2, 3), cirq.GridQubit(3, 4)
    cases = [
        (sympy.Symbol('k'), operations_pb2.ParameterizedFloat(parameter_key='k')),
        (0.5, operations_pb2.ParameterizedFloat(raw=0.5)),
    ]
    for exponent, half_turns in cases:
        expected = operations_pb2.Operation(
            exp_11=operations_pb2.Exp11(
                target1=operations_pb2.Qubit(row=2, col=3),
                target2=operations_pb2.Qubit(row=3, col=4),
                half_turns=half_turns,
            )
        )
        assert_proto_dict_convert(cirq.CZ ** exponent, expected, first, second)
265
266
def test_w_to_proto():
    """Single-qubit X, Y and PhasedX rotations round-trip through ExpW protos.

    Cases: symbolic exponent, symbolic phase exponent, plain X and Y powers
    (serialized with axis half turns 0.0 and 0.5 respectively), and a
    PhasedXPowGate whose parameters are both numeric. The original test
    repeated the symbolic-phase case verbatim; the numeric case replaces that
    duplicate.
    """
    qubit = cirq.GridQubit(2, 3)
    cases = [
        # Symbolic exponent, numeric axis.
        (
            cirq.PhasedXPowGate(exponent=sympy.Symbol('k'), phase_exponent=1),
            operations_pb2.ParameterizedFloat(raw=1),
            operations_pb2.ParameterizedFloat(parameter_key='k'),
        ),
        # Numeric exponent, symbolic axis.
        (
            cirq.PhasedXPowGate(exponent=0.5, phase_exponent=sympy.Symbol('j')),
            operations_pb2.ParameterizedFloat(parameter_key='j'),
            operations_pb2.ParameterizedFloat(raw=0.5),
        ),
        # X rotation: axis 0.
        (
            cirq.X ** 0.25,
            operations_pb2.ParameterizedFloat(raw=0.0),
            operations_pb2.ParameterizedFloat(raw=0.25),
        ),
        # Y rotation: axis 0.5.
        (
            cirq.Y ** 0.25,
            operations_pb2.ParameterizedFloat(raw=0.5),
            operations_pb2.ParameterizedFloat(raw=0.25),
        ),
        # Fully numeric PhasedX with an axis that is neither X nor Y.
        (
            cirq.PhasedXPowGate(exponent=0.5, phase_exponent=0.25),
            operations_pb2.ParameterizedFloat(raw=0.25),
            operations_pb2.ParameterizedFloat(raw=0.5),
        ),
    ]
    for gate, axis_half_turns, half_turns in cases:
        expected = operations_pb2.Operation(
            exp_w=operations_pb2.ExpW(
                target=operations_pb2.Qubit(row=2, col=3),
                axis_half_turns=axis_half_turns,
                half_turns=half_turns,
            )
        )
        assert_proto_dict_convert(gate, expected, qubit)
317
318
def test_unsupported_op():
    """An empty operation proto and a non-native gate are both rejected."""
    empty_op = operations_pb2.Operation()
    with pytest.raises(ValueError, match='invalid operation'):
        programs.xmon_op_from_proto(empty_op)

    three_qubits = tuple(cirq.GridQubit(0, col) for col in range(3))
    with pytest.raises(ValueError, match='know how to serialize'):
        programs.gate_to_proto(cirq.CCZ, three_qubits, delay=0)
326
327
def test_invalid_to_proto_dict_qubit_number():
    """Serializing a gate onto the wrong number of qubits raises ValueError."""
    one_qubit = (cirq.GridQubit(2, 3),)
    two_qubits = (cirq.GridQubit(2, 3), cirq.GridQubit(3, 4))
    mismatches = [
        (cirq.CZ ** 0.5, one_qubit),
        (cirq.Z ** 0.5, two_qubits),
        (cirq.PhasedXPowGate(exponent=0.5, phase_exponent=0), two_qubits),
    ]
    for gate, qubits in mismatches:
        with pytest.raises(ValueError, match='Wrong number of qubits'):
            programs.gate_to_proto(gate, qubits, delay=0)
339
340
def test_parameterized_value_from_proto():
    """Raw values and parameter keys decode; an unset value raises ValueError."""
    decode = programs._parameterized_value_from_proto

    assert decode(operations_pb2.ParameterizedFloat(raw=5)) == 5

    with pytest.raises(ValueError):
        decode(operations_pb2.ParameterizedFloat())

    symbolic = operations_pb2.ParameterizedFloat(parameter_key='rr')
    assert decode(symbolic) == sympy.Symbol('rr')
352
353
def test_invalid_measurement_gate():
    """Measurement serialization rejects a length mismatch and an empty qubit tuple."""
    two_qubits = (cirq.GridQubit(2, 3), cirq.GridQubit(3, 4))
    with pytest.raises(ValueError, match='length'):
        programs.gate_to_proto(
            cirq.MeasurementGate(3, 'test', invert_mask=(True,)), two_qubits, delay=0
        )

    with pytest.raises(ValueError, match='no qubits'):
        programs.gate_to_proto(cirq.MeasurementGate(1, 'test'), (), delay=0)
363
364
def test_is_supported():
    """CZ and single-qubit X/Y/Z/PhasedX powers are native ops; CCZ and SWAP are not."""
    a, b, c = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(1, 0)
    native_ops = [
        cirq.CZ(a, b),
        cirq.X(a) ** 0.5,
        cirq.Y(a) ** 0.5,
        cirq.Z(a) ** 0.5,
        cirq.PhasedXPowGate(phase_exponent=0.2).on(a) ** 0.5,
        cirq.Z(a) ** 1,
    ]
    for op in native_ops:
        assert programs.is_native_xmon_op(op)

    for op in [cirq.CCZ(a, b, c), cirq.SWAP(a, b)]:
        assert not programs.is_native_xmon_op(op)
377
378
def test_is_native_xmon_gate():
    """CZ and single-qubit X/Y/Z/PhasedX powers are native gates; CCZ and SWAP are not."""
    native_gates = [
        cirq.CZ,
        cirq.X ** 0.5,
        cirq.Y ** 0.5,
        cirq.Z ** 0.5,
        cirq.PhasedXPowGate(phase_exponent=0.2) ** 0.5,
        cirq.Z ** 1,
    ]
    for gate in native_gates:
        assert programs.is_native_xmon_gate(gate)

    for gate in [cirq.CCZ, cirq.SWAP]:
        assert not programs.is_native_xmon_gate(gate)
388