"""Tests for the Quantum Volume utilities."""

from unittest.mock import Mock, MagicMock
import io
import numpy as np
import pytest
import cirq
import cirq.contrib.routing as ccr
from cirq.contrib.quantum_volume import CompilationResult


class TestDevice(cirq.Device):
    """Helper device exposing a 5x5 grid of qubits for building device graphs."""

    # The name matches pytest's ``Test*`` collection pattern; mark it as a
    # helper so pytest does not try to collect it as a test class.
    __test__ = False

    qubits = cirq.GridQubit.rect(5, 5)


def test_generate_model_circuit():
    """Test that a model circuit is randomly generated."""
    model_circuit = cirq.contrib.quantum_volume.generate_model_circuit(
        3, 3, random_state=np.random.RandomState(1)
    )

    assert len(model_circuit) == 3
    # Ensure there are no measurement gates.
    assert list(model_circuit.findall_operations_with_gate_type(cirq.MeasurementGate)) == []


def test_generate_model_circuit_without_seed():
    """Test that a model circuit is randomly generated without a seed."""
    model_circuit = cirq.contrib.quantum_volume.generate_model_circuit(3, 3)

    assert len(model_circuit) == 3
    # Ensure there are no measurement gates.
    assert list(model_circuit.findall_operations_with_gate_type(cirq.MeasurementGate)) == []


def test_generate_model_circuit_seed():
    """Test that a model circuit is determined by its seed."""
    model_circuit_1 = cirq.contrib.quantum_volume.generate_model_circuit(
        3, 3, random_state=np.random.RandomState(1)
    )
    model_circuit_2 = cirq.contrib.quantum_volume.generate_model_circuit(
        3, 3, random_state=np.random.RandomState(1)
    )
    model_circuit_3 = cirq.contrib.quantum_volume.generate_model_circuit(
        3, 3, random_state=np.random.RandomState(2)
    )

    # Same seed -> same circuit; different seed -> different circuit.
    assert model_circuit_1 == model_circuit_2
    assert model_circuit_2 != model_circuit_3


def test_compute_heavy_set():
    """Test that the heavy set can be computed from a given circuit."""
    a, b, c = cirq.LineQubit.range(3)
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([]),
            cirq.Moment([cirq.X(a), cirq.Y(b)]),
            cirq.Moment([]),
            cirq.Moment([cirq.CNOT(a, c)]),
            cirq.Moment([cirq.Z(a), cirq.H(b)]),
        ]
    )
    assert cirq.contrib.quantum_volume.compute_heavy_set(model_circuit) == [5, 7]


def test_sample_heavy_set():
    """Test that we correctly sample a circuit's heavy set."""

    sampler = Mock(spec=cirq.Simulator)
    # Construct a result that returns "1", "2", "3", "0". The mocked sampler
    # returns these four shots regardless of the requested repetitions.
    result = cirq.Result.from_single_parameter_set(
        params=cirq.ParamResolver({}),
        measurements={'mock': np.array([[0, 1], [1, 0], [1, 1], [0, 0]])},
    )
    sampler.run = MagicMock(return_value=result)
    circuit = cirq.Circuit(cirq.measure(*cirq.LineQubit.range(2)))
    compilation_result = CompilationResult(circuit=circuit, mapping={}, parity_map={})
    probability = cirq.contrib.quantum_volume.sample_heavy_set(
        compilation_result, [1, 2, 3], sampler=sampler, repetitions=10
    )
    # The first 3 of our outputs are in the heavy set, and the last one is
    # not: 3 / 4.
    assert probability == 0.75


def test_sample_heavy_set_with_parity():
    """Test that we correctly sample a circuit's heavy set with a parity map."""

    sampler = Mock(spec=cirq.Simulator)
    # Mocked result with two shots. Measurement keys are the mapped qubit
    # indices; the parity map pairs data qubit 0 with qubit 1 and data qubit 2
    # with qubit 3:
    #   shot 0: q0=1, q1=0, q2=1, q3=0
    #   shot 1: q0=0, q1=1, q2=1, q3=0
    # In both shots every data qubit disagrees with its parity partner.
    # NOTE(review): this assumes the readout check keeps such shots and only
    # discards shots where a data qubit agrees with its partner. Under that
    # assumption no shot is dropped, and the data-qubit bitstrings are
    # "11" (3) and "01" (1).
    result = cirq.Result.from_single_parameter_set(
        params=cirq.ParamResolver({}),
        measurements={
            '0': np.array([[1], [0]]),
            '1': np.array([[0], [1]]),
            '2': np.array([[1], [1]]),
            '3': np.array([[0], [0]]),
        },
    )
    sampler.run = MagicMock(return_value=result)
    circuit = cirq.Circuit(cirq.measure(*cirq.LineQubit.range(4)))
    compilation_result = CompilationResult(
        circuit=circuit,
        mapping={q: q for q in cirq.LineQubit.range(4)},
        parity_map={cirq.LineQubit(0): cirq.LineQubit(1), cirq.LineQubit(2): cirq.LineQubit(3)},
    )
    # `repetitions` is ignored by the mocked sampler, which always returns
    # the two shots above.
    probability = cirq.contrib.quantum_volume.sample_heavy_set(
        compilation_result, [1], sampler=sampler, repetitions=1
    )
    # Exactly one of the two shots ("01" == 1) is in the heavy set [1].
    assert probability == 0.5


def test_compile_circuit_router():
    """Tests that the given router is used."""
    router_mock = MagicMock()
    cirq.contrib.quantum_volume.compile_circuit(
        cirq.Circuit(),
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        router=router_mock,
        routing_attempts=1,
    )
    router_mock.assert_called()


def test_compile_circuit():
    """Tests that we are able to compile a model circuit."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    a, b, c = cirq.LineQubit.range(3)
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)]),
        ]
    )
    device_graph = ccr.gridqubits_to_graph_device(TestDevice().qubits)
    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=device_graph,
        compiler=compiler_mock,
        routing_attempts=1,
    )

    assert len(compilation_result.mapping) == 3
    assert ccr.ops_are_consistent_with_device_graph(
        compilation_result.circuit.all_operations(), device_graph
    )
    compiler_mock.assert_called_with(compilation_result.circuit)


def test_compile_circuit_replaces_swaps():
    """Tests that the compiler never sees the SwapPermutationGates from the
    router."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    a, b, c = cirq.LineQubit.range(3)
    # Create a circuit that will require some swaps.
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([cirq.CNOT(a, b)]),
            cirq.Moment([cirq.CNOT(a, c)]),
            cirq.Moment([cirq.CNOT(b, c)]),
        ]
    )
    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        routing_attempts=1,
    )

    compiler_mock.assert_called_with(compilation_result.circuit)

    # Assert that there were some swaps in the result.
    swaps = list(compilation_result.circuit.findall_operations_with_gate_type(cirq.SwapPowGate))
    assert swaps

    # Assert that there were no SwapPermutations in the result.
    swap_permutations = list(
        compilation_result.circuit.findall_operations_with_gate_type(
            cirq.contrib.acquaintance.SwapPermutationGate
        )
    )
    assert not swap_permutations


def test_compile_circuit_with_readout_correction():
    """Tests that we are able to compile a model circuit with readout error
    correction."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    # The router returns the circuit unchanged so the expected output can be
    # written down exactly.
    router_mock = MagicMock(side_effect=lambda circuit, network: ccr.SwapNetwork(circuit, {}))
    a, b, c = cirq.LineQubit.range(3)
    ap, bp, cp = cirq.LineQubit.range(3, 6)
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)]),
        ]
    )
    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        router=router_mock,
        routing_attempts=1,
        add_readout_error_correction=True,
    )

    assert compilation_result.circuit == cirq.Circuit(
        [
            cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)]),
            cirq.Moment([cirq.X(a), cirq.X(b), cirq.X(c)]),
            cirq.Moment([cirq.CNOT(a, ap), cirq.CNOT(b, bp), cirq.CNOT(c, cp)]),
            cirq.Moment([cirq.X(a), cirq.X(b), cirq.X(c)]),
        ]
    )


def test_compile_circuit_multiple_routing_attempts():
    """Tests that we make multiple attempts at routing and keep the best one."""
    qubits = cirq.LineQubit.range(3)
    initial_mapping = dict(zip(qubits, qubits))
    # Three candidate routings: one with extra operations, one minimal, and
    # one touching an extra qubit. Only the minimal one should be compiled.
    more_operations = cirq.Circuit(
        [
            cirq.X.on_each(qubits),
            cirq.Y.on_each(qubits),
        ]
    )
    more_qubits = cirq.Circuit(
        [
            cirq.X.on_each(cirq.LineQubit.range(4)),
        ]
    )
    well_routed = cirq.Circuit(
        [
            cirq.X.on_each(qubits),
        ]
    )
    router_mock = MagicMock(
        side_effect=[
            ccr.SwapNetwork(more_operations, initial_mapping),
            ccr.SwapNetwork(well_routed, initial_mapping),
            ccr.SwapNetwork(more_qubits, initial_mapping),
        ]
    )
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    model_circuit = cirq.Circuit([cirq.X.on_each(qubits)])

    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        router=router_mock,
        routing_attempts=3,
    )

    assert compilation_result.mapping == initial_mapping
    assert router_mock.call_count == 3
    compiler_mock.assert_called_with(well_routed)


def test_compile_circuit_no_routing_attempts():
    """Tests that setting no routing attempts throws an error."""
    a, b, c = cirq.LineQubit.range(3)
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)]),
        ]
    )

    # Use `match=` so the message is actually checked; an assertion placed
    # after the raising call inside the `with` block would never execute.
    with pytest.raises(AssertionError, match='Unable to get routing for circuit'):
        cirq.contrib.quantum_volume.compile_circuit(
            model_circuit,
            device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
            routing_attempts=0,
        )


def test_calculate_quantum_volume_result():
    """Test that running the main loop returns the desired result."""
    results = cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=3,
        depth=3,
        num_circuits=1,
        device_graph=ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3)),
        samplers=[cirq.Simulator()],
        routing_attempts=2,
        random_state=1,
    )

    model_circuit = cirq.contrib.quantum_volume.generate_model_circuit(3, 3, random_state=1)
    assert len(results) == 1
    assert results[0].model_circuit == model_circuit
    assert results[0].heavy_set == cirq.contrib.quantum_volume.compute_heavy_set(model_circuit)
    # Ensure that calling to_json on the results does not err and produces
    # some output.
    buffer = io.StringIO()
    cirq.to_json(results, buffer)
    assert buffer.getvalue()


def test_calculate_quantum_volume_result_with_device_graph():
    """Test that running the main loop routes the circuit onto the given device
    graph."""
    device_qubits = [cirq.GridQubit(i, j) for i in range(2) for j in range(3)]

    results = cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=3,
        depth=3,
        num_circuits=1,
        device_graph=ccr.gridqubits_to_graph_device(device_qubits),
        samplers=[cirq.Simulator()],
        routing_attempts=2,
        random_state=1,
    )

    assert len(results) == 1
    assert ccr.ops_are_consistent_with_device_graph(
        results[0].compiled_circuit.all_operations(), ccr.get_grid_device_graph(2, 3)
    )


def test_calculate_quantum_volume_loop():
    """Test that calculate_quantum_volume is able to run without erring."""
    # Keep test from taking a long time by lowering circuits and routing
    # attempts.
    cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=5,
        depth=5,
        num_circuits=1,
        routing_attempts=2,
        random_state=1,
        device_graph=ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3)),
        samplers=[cirq.Simulator()],
    )


def test_calculate_quantum_volume_loop_with_readout_correction():
    """Test that calculate_quantum_volume is able to run without erring with
    readout error correction."""
    # Keep test from taking a long time by lowering circuits and routing
    # attempts.
    cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=4,
        depth=4,
        num_circuits=1,
        routing_attempts=2,
        random_state=1,
        device_graph=ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3)),
        samplers=[cirq.Simulator()],
        add_readout_error_correction=True,
    )