1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17
18import cirq
19
20
@pytest.mark.parametrize(
    'key',
    [
        'q0_1_0',
        cirq.MeasurementKey(name='q0_1_0'),
        cirq.MeasurementKey(path=('a', 'b'), name='c'),
    ],
)
def test_eval_repr(key):
    """The repr of a single-qubit measurement operation must round-trip via eval.

    Exercised for a plain string key, an equivalent MeasurementKey, and a
    MeasurementKey carrying a non-empty path.
    """
    measurement_gate = cirq.MeasurementGate(1, key)
    target = cirq.GridQubit(0, 1)
    operation = cirq.GateOperation(gate=measurement_gate, qubits=[target])
    cirq.testing.assert_equivalent_repr(operation)
36
37
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
    """Checks how MeasurementGate's constructor handles key, mask and qid_shape."""
    positional = cirq.MeasurementGate(num_qubits, 'a')
    assert positional.num_qubits() == num_qubits
    # Without an explicit qid_shape every qubit defaults to dimension 2.
    assert cirq.qid_shape(positional) == (2,) * num_qubits

    # A string key and the equivalent MeasurementKey produce equal gates.
    from_str = cirq.MeasurementGate(num_qubits, key='a')
    assert from_str.key == 'a'
    assert from_str.mkey == cirq.MeasurementKey('a')
    from_mkey = cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a'))
    assert from_mkey.key == 'a'
    assert from_mkey == from_str

    masked = cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,))
    assert masked.invert_mask == (True,)

    # qid_shape is honoured, and num_qubits may be omitted when it is given.
    for shaped in (
        cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3)),
        cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3)),
    ):
        assert cirq.qid_shape(shaped) == (1, 2, 3)

    # Each entry: (expected error-message regex, constructor call that must fail).
    invalid_constructions = [
        ('len.* >', lambda: cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)),
        ('len.* !=', lambda: cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))),
        ('valid string', lambda: cirq.MeasurementGate(2, qid_shape=(1, 2), key=None)),
        ('Specify either', lambda: cirq.MeasurementGate()),
    ]
    for pattern, construct in invalid_constructions:
        with pytest.raises(ValueError, match=pattern):
            construct()
59
60
@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_has_stabilizer_effect(num_qubits):
    """A qubit MeasurementGate reports a stabilizer effect for any qubit count."""
    gate = cirq.MeasurementGate(num_qubits, 'a')
    assert cirq.has_stabilizer_effect(gate)
64
65
def test_measurement_eq():
    """Checks which MeasurementGate constructions compare equal to each other."""
    tester = cirq.testing.EqualsTester()

    # An empty invert mask and an explicit all-qubit qid_shape are the same as
    # leaving those arguments out.
    tester.make_equality_group(
        lambda: cirq.MeasurementGate(1, 'a'),
        lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
        lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
    )

    # Each inner list is a separate equality group. Note that an explicit
    # (False,) mask is NOT equal to the empty default mask.
    distinct_groups = [
        [cirq.MeasurementGate(1, 'a', invert_mask=(True,))],
        [cirq.MeasurementGate(1, 'a', invert_mask=(False,))],
        [cirq.MeasurementGate(1, 'b')],
        [cirq.MeasurementGate(2, 'a')],
        [cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))],
        [cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))],
    ]
    for group in distinct_groups:
        tester.add_equality_group(*group)
81
82
def test_measurement_full_invert_mask():
    """full_invert_mask pads the stored mask with False up to the qubit count."""
    cases = [
        (cirq.MeasurementGate(1, 'a'), (False,)),
        (cirq.MeasurementGate(2, 'a', invert_mask=(False, True)), (False, True)),
        (cirq.MeasurementGate(2, 'a', invert_mask=(True,)), (True, False)),
    ]
    for gate, expected_mask in cases:
        assert gate.full_invert_mask() == expected_mask
90
91
@pytest.mark.parametrize('use_protocol', [False, True])
@pytest.mark.parametrize(
    'gate',
    [
        cirq.MeasurementGate(1, 'a'),
        cirq.MeasurementGate(1, 'a', invert_mask=(True,)),
        cirq.MeasurementGate(1, 'a', qid_shape=(3,)),
        cirq.MeasurementGate(2, 'a', invert_mask=(True, False), qid_shape=(2, 3)),
    ],
)
def test_measurement_with_key(use_protocol, gate):
    """Re-keying a gate (via method or protocol) changes only its key."""

    def rekey(new_key):
        # Route through the protocol or the method depending on the parameter.
        if use_protocol:
            return cirq.with_measurement_key_mapping(gate, {'a': new_key})
        return gate.with_key(new_key)

    renamed = rekey('b')
    assert renamed.key == 'b'
    assert renamed.num_qubits() == gate.num_qubits()
    assert renamed.invert_mask == gate.invert_mask
    assert cirq.qid_shape(renamed) == cirq.qid_shape(gate)

    # Mapping the key onto itself must yield an equal gate.
    assert rekey('a') == gate
116
117
@pytest.mark.parametrize(
    'num_qubits, mask, bits, flipped',
    [
        (1, (), [0], (True,)),
        (3, (False,), [1], (False, True)),
        (3, (False, False), [0, 2], (True, False, True)),
    ],
)
def test_measurement_with_bits_flipped(num_qubits, mask, bits, flipped):
    """with_bits_flipped toggles the given mask bits and keeps everything else."""
    original = cirq.MeasurementGate(
        num_qubits, key='a', invert_mask=mask, qid_shape=(3,) * num_qubits
    )

    once = original.with_bits_flipped(*bits)
    assert once.invert_mask == flipped
    assert once.key == original.key
    assert once.num_qubits() == original.num_qubits()
    assert cirq.qid_shape(once) == cirq.qid_shape(original)

    # Flipping the same bits again undoes the first flip. Compare padded masks,
    # because the stored mask may have grown longer than the original one.
    twice = once.with_bits_flipped(*bits)
    assert twice.full_invert_mask() == original.full_invert_mask()
138
139
def test_qudit_measure_qasm():
    """Measuring a qutrit has no QASM translation, so cirq.qasm returns the default."""
    fallback = 'not implemented'
    qutrit = cirq.LineQid(0, 3)
    rendered = cirq.qasm(
        cirq.measure(qutrit, key='a'),
        args=cirq.QasmArgs(),
        default=fallback,
    )
    assert rendered == fallback
149
150
def test_qudit_measure_quil():
    """Measuring a qutrit has no Quil translation, so cirq.quil returns None."""
    q0 = cirq.LineQid(0, 3)
    qubit_id_map = {q0: '0'}
    result = cirq.quil(
        cirq.measure(q0, key='a'),
        formatter=cirq.QuilFormatter(qubit_id_map=qubit_id_map, measurement_id_map={}),
    )
    # Use an identity check: `== None` is flagged by linters (E711) and can be
    # satisfied by any object that overrides __eq__.
    assert result is None
161
162
def test_measurement_gate_diagram():
    """Checks circuit-diagram rendering of measurement gates.

    Covers the key label, per-qubit symbols when the qubit count is known, the
    ``!M`` symbol for inverted qubits, and omission of the default key.
    """
    # Shows key.
    assert cirq.circuit_diagram_info(
        cirq.MeasurementGate(1, key='test')
    ) == cirq.CircuitDiagramInfo(("M('test')",))

    # Uses known qubit count.
    assert (
        cirq.circuit_diagram_info(
            cirq.MeasurementGate(3, 'a'),
            cirq.CircuitDiagramInfoArgs(
                known_qubits=None,
                known_qubit_count=3,
                use_unicode_characters=True,
                precision=None,
                qubit_map=None,
            ),
        )
        == cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
    )

    # Shows invert mask.
    assert cirq.circuit_diagram_info(
        cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
    ) == cirq.CircuitDiagramInfo(("M('a')", "!M"))

    # Omits key when it is the default. A two-qubit measurement is drawn with a
    # vertical connector row between the qubit lines, so every expected diagram
    # below includes that row.
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b)),
        """
a: ───M───
      │
b: ───M───
""",
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b, invert_mask=(True,))),
        """
a: ───!M───
      │
b: ───M────
""",
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b, key='test')),
        """
a: ───M('test')───
      │
b: ───M───────────
""",
    )
216
217
def test_measurement_channel():
    """Kraus operators of a measurement are the computational-basis projectors.

    The expected operators are listed in basis order: the k-th operator is the
    diagonal matrix with a single 1 at position k.
    """

    def basis_projectors(dim):
        # Each row of the identity picks out one diagonal entry.
        return tuple(np.diag(row) for row in np.eye(dim))

    np.testing.assert_allclose(
        cirq.kraus(cirq.MeasurementGate(1, 'a')),
        basis_projectors(2),
    )
    np.testing.assert_allclose(
        cirq.kraus(cirq.MeasurementGate(2, 'a')),
        basis_projectors(4),
    )
    # Mixed qubit/qutrit shape (2, 3) spans a 6-dimensional space.
    np.testing.assert_allclose(
        cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
        basis_projectors(6),
    )
251
252
def test_measurement_qubit_count_vs_mask_length():
    """Gates apply when counts line up and reject mismatched masks or qubit counts."""
    a, b, c = (cirq.NamedQubit(name) for name in 'abc')

    # Matching mask length, gate size and qubit count all succeed.
    valid_cases = [
        ((True,), (a,)),
        ((True, False), (a, b)),
        ((True, False, True), (a, b, c)),
    ]
    for mask, qubits in valid_cases:
        cirq.MeasurementGate(num_qubits=len(qubits), key='a', invert_mask=mask).on(*qubits)

    # An invert mask longer than the gate's qubit count is rejected.
    with pytest.raises(ValueError):
        cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
    # Applying a three-qubit gate to only two qubits is rejected.
    with pytest.raises(ValueError):
        cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)
265
266
def test_consistent_protocols():
    """Qubit and qutrit measurement gates of 1-4 qubits pass the protocol checks."""
    for n in range(1, 5):
        # The default shape first, then an all-qutrit shape.
        constructor_variants = [{}, {'qid_shape': (3,) * n}]
        for extra_kwargs in constructor_variants:
            gate = cirq.MeasurementGate(num_qubits=n, key='a', **extra_kwargs)
            cirq.testing.assert_implements_consistent_protocols(gate)
274
275
def test_op_repr():
    """repr of cirq.measure ops lists only the arguments that were set."""
    q0, q1 = cirq.LineQubit.range(2)
    cases = [
        (cirq.measure(q0), 'cirq.measure(cirq.LineQubit(0))'),
        (cirq.measure(q0, q1), 'cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1))'),
        (
            cirq.measure(q0, q1, key='out', invert_mask=(False, True)),
            "cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), "
            "key=cirq.MeasurementKey(name='out'), "
            "invert_mask=(False, True))",
        ),
    ]
    for op, expected_repr in cases:
        assert repr(op) == expected_repr
285
286
def test_act_on_state_vector():
    """Measuring basis states via act_on logs outcomes with the invert mask applied.

    Measures LineQubit(3) and LineQubit(1), in that order, inside a 5-qubit
    register. Only the first measured qubit is inverted, because the mask
    (True,) is padded with False.
    """
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    shape = (2, 2, 2, 2, 2)

    def make_args(index):
        return cirq.ActOnStateVectorArgs(
            target_tensor=cirq.one_hot(index=index, shape=shape, dtype=np.complex64),
            # Match the state's complex dtype. A float64 scratch buffer would
            # silently drop imaginary parts if act_on ever wrote into it.
            available_buffer=np.empty(shape=shape, dtype=np.complex64),
            qubits=cirq.LineQubit.range(5),
            prng=np.random.RandomState(),
            log_of_measurement_results={},
        )

    # (basis-state index over qubits 0..4, expected logged [a, b] outcome)
    cases = [
        ((0, 0, 0, 0, 0), [1, 0]),
        ((0, 1, 0, 0, 0), [1, 1]),
        ((0, 1, 0, 1, 0), [0, 1]),
    ]
    for index, expected in cases:
        args = make_args(index)
        cirq.act_on(m, args)
        assert args.log_of_measurement_results == {'out': expected}

    # Logging to a key that already holds results is an error.
    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)
327
328
def test_act_on_clifford_tableau():
    """Measuring Clifford-tableau basis states logs mask-adjusted outcomes."""
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    # This check passes trivially: it skips non-unitary operations such as
    # measurement.
    cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m)

    # (initial basis state as an integer, expected logged [a, b] outcome)
    expectations = [(0, [1, 0]), (8, [1, 1]), (10, [0, 1])]
    for initial_state, expected_bits in expectations:
        args = cirq.ActOnCliffordTableauArgs(
            tableau=cirq.CliffordTableau(num_qubits=5, initial_state=initial_state),
            qubits=cirq.LineQubit.range(5),
            prng=np.random.RandomState(),
            log_of_measurement_results={},
        )
        cirq.act_on(m, args)
        assert args.log_of_measurement_results == {'out': expected_bits}

    # Re-measuring into the already-populated key on the last args must fail.
    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)
365
366
def test_act_on_stabilizer_ch_form():
    """Measuring CH-form stabilizer basis states logs mask-adjusted outcomes."""
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    # This check passes trivially: it skips non-unitary operations such as
    # measurement.
    cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m)

    # (initial basis state as an integer, expected logged [a, b] outcome)
    expectations = [(0, [1, 0]), (8, [1, 1]), (10, [0, 1])]
    for initial_state, expected_bits in expectations:
        args = cirq.ActOnStabilizerCHFormArgs(
            state=cirq.StabilizerStateChForm(num_qubits=5, initial_state=initial_state),
            qubits=cirq.LineQubit.range(5),
            prng=np.random.RandomState(),
            log_of_measurement_results={},
        )
        cirq.act_on(m, args)
        assert args.log_of_measurement_results == {'out': expected_bits}

    # Re-measuring into the already-populated key on the last args must fail.
    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)
403
404
def test_act_on_qutrit():
    """Measuring qutrit basis states via act_on logs outcomes with the mask applied.

    Measures LineQid(3) and LineQid(1) (dimension 3) in a 5-qutrit register.
    Only the first measured qutrit is inverted. Per the expectations below, the
    inversion maps 1 to 0, while a measured 2 is logged unchanged.
    """
    a, b = [cirq.LineQid(3, dimension=3), cirq.LineQid(1, dimension=3)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    shape = (3, 3, 3, 3, 3)

    def make_args(index):
        return cirq.ActOnStateVectorArgs(
            target_tensor=cirq.one_hot(index=index, shape=shape, dtype=np.complex64),
            # Match the state's complex dtype. A float64 scratch buffer would
            # silently drop imaginary parts if act_on ever wrote into it.
            available_buffer=np.empty(shape=shape, dtype=np.complex64),
            qubits=cirq.LineQid.range(5, dimension=3),
            prng=np.random.RandomState(),
            log_of_measurement_results={},
        )

    # (basis-state index over qutrits 0..4, expected logged [a, b] outcome)
    cases = [
        ((0, 2, 0, 2, 0), [2, 2]),
        ((0, 1, 0, 2, 0), [2, 1]),
        ((0, 2, 0, 1, 0), [0, 2]),
    ]
    for index, expected in cases:
        args = make_args(index)
        cirq.act_on(m, args)
        assert args.log_of_measurement_results == {'out': expected}
444