# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for `cirq.MeasurementGate` and measurement operations built from it."""

import numpy as np
import pytest

import cirq


@pytest.mark.parametrize(
    'key',
    [
        'q0_1_0',
        cirq.MeasurementKey(name='q0_1_0'),
        cirq.MeasurementKey(path=('a', 'b'), name='c'),
    ],
)
def test_eval_repr(key):
    """The repr of a measurement operation round-trips to an equal object."""
    # Basic safeguard against repr-inequality.
    op = cirq.GateOperation(
        gate=cirq.MeasurementGate(1, key),
        qubits=[cirq.GridQubit(0, 1)],
    )
    cirq.testing.assert_equivalent_repr(op)


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_measure_init(num_qubits):
    """Constructor arguments are stored, and invalid combinations raise ValueError."""
    assert cirq.MeasurementGate(num_qubits, 'a').num_qubits() == num_qubits
    assert cirq.MeasurementGate(num_qubits, key='a').key == 'a'
    assert cirq.MeasurementGate(num_qubits, key='a').mkey == cirq.MeasurementKey('a')
    assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')).key == 'a'
    # A string key and the equivalent MeasurementKey produce equal gates.
    assert cirq.MeasurementGate(num_qubits, key=cirq.MeasurementKey('a')) == cirq.MeasurementGate(
        num_qubits, key='a'
    )
    assert cirq.MeasurementGate(num_qubits, 'a', invert_mask=(True,)).invert_mask == (True,)
    assert cirq.qid_shape(cirq.MeasurementGate(num_qubits, 'a')) == (2,) * num_qubits
    assert cirq.qid_shape(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3))) == (1, 2, 3)
    # num_qubits may be omitted when qid_shape is given.
    assert cirq.qid_shape(cirq.MeasurementGate(key='a', qid_shape=(1, 2, 3))) == (1, 2, 3)
    with pytest.raises(ValueError, match='len.* >'):
        cirq.MeasurementGate(5, 'a', invert_mask=(True,) * 6)
    with pytest.raises(ValueError, match='len.* !='):
        cirq.MeasurementGate(5, 'a', qid_shape=(1, 2))
    with pytest.raises(ValueError, match='valid string'):
        cirq.MeasurementGate(2, qid_shape=(1, 2), key=None)
    with pytest.raises(ValueError, match='Specify either'):
        cirq.MeasurementGate()


@pytest.mark.parametrize('num_qubits', [1, 2, 4])
def test_has_stabilizer_effect(num_qubits):
    """Measurement gates report a stabilizer effect."""
    assert cirq.has_stabilizer_effect(cirq.MeasurementGate(num_qubits, 'a'))


def test_measurement_eq():
    """Equality depends on key, qubit count, invert mask and qid shape."""
    eq = cirq.testing.EqualsTester()
    # An empty invert mask and an explicit all-qubit shape equal the defaults.
    eq.make_equality_group(
        lambda: cirq.MeasurementGate(1, 'a'),
        lambda: cirq.MeasurementGate(1, 'a', invert_mask=()),
        lambda: cirq.MeasurementGate(1, 'a', qid_shape=(2,)),
    )
    eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(True,)))
    eq.add_equality_group(cirq.MeasurementGate(1, 'a', invert_mask=(False,)))
    eq.add_equality_group(cirq.MeasurementGate(1, 'b'))
    eq.add_equality_group(cirq.MeasurementGate(2, 'a'))
    eq.add_equality_group(
        cirq.MeasurementGate(3, 'a'), cirq.MeasurementGate(3, 'a', qid_shape=(2, 2, 2))
    )
    eq.add_equality_group(cirq.MeasurementGate(3, 'a', qid_shape=(1, 2, 3)))


def test_measurement_full_invert_mask():
    """full_invert_mask pads a short invert mask with False up to num_qubits."""
    assert cirq.MeasurementGate(1, 'a').full_invert_mask() == (False,)
    assert cirq.MeasurementGate(2, 'a', invert_mask=(False, True)).full_invert_mask() == (
        False,
        True,
    )
    assert cirq.MeasurementGate(2, 'a', invert_mask=(True,)).full_invert_mask() == (True, False)


@pytest.mark.parametrize('use_protocol', [False, True])
@pytest.mark.parametrize(
    'gate',
    [
        cirq.MeasurementGate(1, 'a'),
        cirq.MeasurementGate(1, 'a', invert_mask=(True,)),
        cirq.MeasurementGate(1, 'a', qid_shape=(3,)),
        cirq.MeasurementGate(2, 'a', invert_mask=(True, False), qid_shape=(2, 3)),
    ],
)
def test_measurement_with_key(use_protocol, gate):
    """Re-keying, directly or via the mapping protocol, changes only the key."""
    if use_protocol:
        gate1 = cirq.with_measurement_key_mapping(gate, {'a': 'b'})
    else:
        gate1 = gate.with_key('b')
    assert gate1.key == 'b'
    assert gate1.num_qubits() == gate.num_qubits()
    assert gate1.invert_mask == gate.invert_mask
    assert cirq.qid_shape(gate1) == cirq.qid_shape(gate)
    # Mapping a key onto itself yields an equal gate.
    if use_protocol:
        gate2 = cirq.with_measurement_key_mapping(gate, {'a': 'a'})
    else:
        gate2 = gate.with_key('a')
    assert gate2 == gate


@pytest.mark.parametrize(
    'num_qubits, mask, bits, flipped',
    [
        (1, (), [0], (True,)),
        (3, (False,), [1], (False, True)),
        (3, (False, False), [0, 2], (True, False, True)),
    ],
)
def test_measurement_with_bits_flipped(num_qubits, mask, bits, flipped):
    """with_bits_flipped toggles (and if needed extends) the invert mask."""
    gate = cirq.MeasurementGate(num_qubits, key='a', invert_mask=mask, qid_shape=(3,) * num_qubits)

    gate1 = gate.with_bits_flipped(*bits)
    assert gate1.key == gate.key
    assert gate1.num_qubits() == gate.num_qubits()
    assert gate1.invert_mask == flipped
    assert cirq.qid_shape(gate1) == cirq.qid_shape(gate)

    # Flipping bits again restores the mask (but may have extended it).
    gate2 = gate1.with_bits_flipped(*bits)
    assert gate2.full_invert_mask() == gate.full_invert_mask()


def test_qudit_measure_qasm():
    """cirq.qasm falls back to the supplied default for a qutrit measurement."""
    assert (
        cirq.qasm(
            cirq.measure(cirq.LineQid(0, 3), key='a'),
            args=cirq.QasmArgs(),
            default='not implemented',
        )
        == 'not implemented'
    )


def test_qudit_measure_quil():
    """cirq.quil returns None for a qutrit measurement."""
    q0 = cirq.LineQid(0, 3)
    qubit_id_map = {q0: '0'}
    # Compare against the None singleton by identity, not equality.
    assert (
        cirq.quil(
            cirq.measure(q0, key='a'),
            formatter=cirq.QuilFormatter(qubit_id_map=qubit_id_map, measurement_id_map={}),
        )
        is None
    )


def test_measurement_gate_diagram():
    """Diagram labels show the key, the invert mask and per-qubit symbols."""
    # Shows key.
    assert cirq.circuit_diagram_info(
        cirq.MeasurementGate(1, key='test')
    ) == cirq.CircuitDiagramInfo(("M('test')",))

    # Uses known qubit count.
    assert (
        cirq.circuit_diagram_info(
            cirq.MeasurementGate(3, 'a'),
            cirq.CircuitDiagramInfoArgs(
                known_qubits=None,
                known_qubit_count=3,
                use_unicode_characters=True,
                precision=None,
                qubit_map=None,
            ),
        )
        == cirq.CircuitDiagramInfo(("M('a')", 'M', 'M'))
    )

    # Shows invert mask.
    assert cirq.circuit_diagram_info(
        cirq.MeasurementGate(2, 'a', invert_mask=(False, True))
    ) == cirq.CircuitDiagramInfo(("M('a')", "!M"))

    # Omits key when it is the default.
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b)),
        """
a: ───M───
      │
b: ───M───
""",
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b, invert_mask=(True,))),
        """
a: ───!M───
      │
b: ───M────
""",
    )
    cirq.testing.assert_has_diagram(
        cirq.Circuit(cirq.measure(a, b, key='test')),
        """
a: ───M('test')───
      │
b: ───M───────────
""",
    )


def test_measurement_channel():
    """Kraus operators of a measurement are the computational-basis projectors."""
    np.testing.assert_allclose(
        cirq.kraus(cirq.MeasurementGate(1, 'a')),
        (np.array([[1, 0], [0, 0]]), np.array([[0, 0], [0, 1]])),
    )
    # yapf: disable
    np.testing.assert_allclose(
            cirq.kraus(cirq.MeasurementGate(2, 'a')),
            (np.array([[1, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0]]),
             np.array([[0, 0, 0, 0],
                       [0, 1, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0]]),
             np.array([[0, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 0, 0]]),
             np.array([[0, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 0],
                       [0, 0, 0, 1]])))
    np.testing.assert_allclose(
            cirq.kraus(cirq.MeasurementGate(2, 'a', qid_shape=(2, 3))),
            (np.diag([1, 0, 0, 0, 0, 0]),
             np.diag([0, 1, 0, 0, 0, 0]),
             np.diag([0, 0, 1, 0, 0, 0]),
             np.diag([0, 0, 0, 1, 0, 0]),
             np.diag([0, 0, 0, 0, 1, 0]),
             np.diag([0, 0, 0, 0, 0, 1])))
    # yapf: enable


def test_measurement_qubit_count_vs_mask_length():
    """Mismatched qubit counts and invert-mask lengths raise ValueError."""
    a = cirq.NamedQubit('a')
    b = cirq.NamedQubit('b')
    c = cirq.NamedQubit('c')

    _ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True,)).on(a)
    _ = cirq.MeasurementGate(num_qubits=2, key='a', invert_mask=(True, False)).on(a, b)
    _ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b, c)
    with pytest.raises(ValueError):
        _ = cirq.MeasurementGate(num_qubits=1, key='a', invert_mask=(True, False)).on(a)
    with pytest.raises(ValueError):
        _ = cirq.MeasurementGate(num_qubits=3, key='a', invert_mask=(True, False, True)).on(a, b)


def test_consistent_protocols():
    """Qubit and qutrit measurement gates implement the protocols consistently."""
    for n in range(1, 5):
        gate = cirq.MeasurementGate(num_qubits=n, key='a')
        cirq.testing.assert_implements_consistent_protocols(gate)

        gate = cirq.MeasurementGate(num_qubits=n, key='a', qid_shape=(3,) * n)
        cirq.testing.assert_implements_consistent_protocols(gate)


def test_op_repr():
    """repr of cirq.measure omits the default key and invert mask."""
    a, b = cirq.LineQubit.range(2)
    assert repr(cirq.measure(a)) == 'cirq.measure(cirq.LineQubit(0))'
    assert repr(cirq.measure(a, b)) == 'cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1))'
    assert repr(cirq.measure(a, b, key='out', invert_mask=(False, True))) == (
        "cirq.measure(cirq.LineQubit(0), cirq.LineQubit(1), "
        "key=cirq.MeasurementKey(name='out'), "
        "invert_mask=(False, True))"
    )


def test_act_on_state_vector():
    """Measuring a state vector logs results with the invert mask applied."""
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    # Qubit `a` (LineQubit(3)) is inverted; `b` (LineQubit(1)) is not.
    m = cirq.measure(a, b, key='out', invert_mask=(True,))

    # The scratch buffer is given the same dtype as the complex64 target tensor.
    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(shape=(2, 2, 2, 2, 2), dtype=np.complex64),
        available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 0]}

    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(
            index=(0, 1, 0, 0, 0), shape=(2, 2, 2, 2, 2), dtype=np.complex64
        ),
        available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 1]}

    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(
            index=(0, 1, 0, 1, 0), shape=(2, 2, 2, 2, 2), dtype=np.complex64
        ),
        available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [0, 1]}

    # Re-measuring into a key that already has a result is an error.
    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)


def test_act_on_clifford_tableau():
    """Measuring a Clifford tableau logs results with the invert mask applied."""
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    # The below assertion does not fail since it ignores non-unitary operations
    cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m)

    args = cirq.ActOnCliffordTableauArgs(
        tableau=cirq.CliffordTableau(num_qubits=5, initial_state=0),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 0]}

    args = cirq.ActOnCliffordTableauArgs(
        tableau=cirq.CliffordTableau(num_qubits=5, initial_state=8),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )

    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 1]}

    args = cirq.ActOnCliffordTableauArgs(
        tableau=cirq.CliffordTableau(num_qubits=5, initial_state=10),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [0, 1]}

    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)


def test_act_on_stabilizer_ch_form():
    """Measuring a stabilizer CH-form state logs results with the invert mask applied."""
    a, b = [cirq.LineQubit(3), cirq.LineQubit(1)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))
    # The below assertion does not fail since it ignores non-unitary operations
    cirq.testing.assert_all_implemented_act_on_effects_match_unitary(m)

    args = cirq.ActOnStabilizerCHFormArgs(
        state=cirq.StabilizerStateChForm(num_qubits=5, initial_state=0),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 0]}

    args = cirq.ActOnStabilizerCHFormArgs(
        state=cirq.StabilizerStateChForm(num_qubits=5, initial_state=8),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )

    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [1, 1]}

    args = cirq.ActOnStabilizerCHFormArgs(
        state=cirq.StabilizerStateChForm(num_qubits=5, initial_state=10),
        qubits=cirq.LineQubit.range(5),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [0, 1]}

    with pytest.raises(ValueError, match="already logged to key"):
        cirq.act_on(m, args)


def test_act_on_qutrit():
    """Qutrit measurement results in these cases read as 2 - value on the inverted qid."""
    a, b = [cirq.LineQid(3, dimension=3), cirq.LineQid(1, dimension=3)]
    m = cirq.measure(a, b, key='out', invert_mask=(True,))

    # The scratch buffer is given the same dtype as the complex64 target tensor.
    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(
            index=(0, 2, 0, 2, 0), shape=(3, 3, 3, 3, 3), dtype=np.complex64
        ),
        available_buffer=np.empty(shape=(3, 3, 3, 3, 3), dtype=np.complex64),
        qubits=cirq.LineQid.range(5, dimension=3),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [2, 2]}

    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(
            index=(0, 1, 0, 2, 0), shape=(3, 3, 3, 3, 3), dtype=np.complex64
        ),
        available_buffer=np.empty(shape=(3, 3, 3, 3, 3), dtype=np.complex64),
        qubits=cirq.LineQid.range(5, dimension=3),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [2, 1]}

    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(
            index=(0, 2, 0, 1, 0), shape=(3, 3, 3, 3, 3), dtype=np.complex64
        ),
        available_buffer=np.empty(shape=(3, 3, 3, 3, 3), dtype=np.complex64),
        qubits=cirq.LineQid.range(5, dimension=3),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )
    cirq.act_on(m, args)
    assert args.log_of_measurement_results == {'out': [0, 2]}