1"""Tests for the Quantum Volume utilities."""
2
3from unittest.mock import Mock, MagicMock
4import io
5import numpy as np
6import pytest
7import cirq
8import cirq.contrib.routing as ccr
9from cirq.contrib.quantum_volume import CompilationResult
10
11
class TestDevice(cirq.Device):
    """Fake device exposing a 5x5 grid of qubits, used as a routing target.

    The class name matches pytest's default ``Test*`` collection pattern, so
    ``__test__ = False`` is set to keep pytest from treating this helper as a
    test class.
    """

    # Opt out of pytest collection: this is a fixture, not a test case.
    __test__ = False

    qubits = cirq.GridQubit.rect(5, 5)
14
15
def test_generate_model_circuit():
    """A seeded model circuit has the requested depth and no measurements."""
    rng = np.random.RandomState(1)
    circuit = cirq.contrib.quantum_volume.generate_model_circuit(3, 3, random_state=rng)

    assert len(circuit) == 3
    # The generated model circuit must not contain any measurement gates.
    found = list(circuit.findall_operations_with_gate_type(cirq.MeasurementGate))
    assert found == []
25
26
def test_generate_model_circuit_without_seed():
    """An unseeded model circuit still has the requested depth and no measurements."""
    circuit = cirq.contrib.quantum_volume.generate_model_circuit(3, 3)

    assert len(circuit) == 3
    # No measurement gates may appear in the generated circuit.
    assert not list(circuit.findall_operations_with_gate_type(cirq.MeasurementGate))
34
35
def test_generate_model_circuit_seed():
    """Test that a model circuit is determined by its seed."""
    generate = cirq.contrib.quantum_volume.generate_model_circuit
    first, second, other = [
        generate(3, 3, random_state=np.random.RandomState(seed)) for seed in (1, 1, 2)
    ]

    # Equal seeds reproduce the same circuit; a different seed yields another.
    assert first == second
    assert second != other
50
51
def test_compute_heavy_set():
    """The heavy set of a small hand-built circuit matches the expected outputs."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    moments = [
        cirq.Moment([]),
        cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
        cirq.Moment([]),
        cirq.Moment([cirq.CNOT(q0, q2)]),
        cirq.Moment([cirq.Z(q0), cirq.H(q1)]),
    ]
    heavy_set = cirq.contrib.quantum_volume.compute_heavy_set(cirq.Circuit(moments))
    assert heavy_set == [5, 7]
65
66
def test_sample_heavy_set():
    """Sampling reports the fraction of measured bitstrings that fall in the heavy set."""
    # Four canned samples which decode (big-endian) to 1, 2, 3 and 0.
    canned = np.array([[0, 1], [1, 0], [1, 1], [0, 0]])
    result = cirq.Result.from_single_parameter_set(
        params=cirq.ParamResolver({}), measurements={'mock': canned}
    )
    sampler = Mock(spec=cirq.Simulator)
    sampler.run = MagicMock(return_value=result)

    measured = cirq.Circuit(cirq.measure(*cirq.LineQubit.range(2)))
    compiled = CompilationResult(circuit=measured, mapping={}, parity_map={})
    probability = cirq.contrib.quantum_volume.sample_heavy_set(
        compiled, [1, 2, 3], sampler=sampler, repetitions=10
    )

    # Three of the four samples (1, 2, 3) are heavy; the last one (0) is not.
    assert probability == 0.75
85
86
def test_sample_heavy_set_with_parity():
    """Test that we correctly sample a circuit's heavy set with a parity map."""

    sampler = Mock(spec=cirq.Simulator)
    # Each key holds one qubit's results over two repetitions (the mock ignores
    # the requested repetition count). Qubits 0 and 2 are the physical qubits and
    # qubits 1 and 3 are their parity (ancilla) qubits, per the parity_map below.
    #
    #   repetition 0: q0=1, q1=0, q2=1, q3=0
    #   repetition 1: q0=0, q1=1, q2=1, q3=0
    #
    # In both repetitions every physical qubit disagrees with its ancilla, which
    # is what the X-CNOT-X readout-correction check produces for an error-free
    # readout, so neither sample is discarded. Reading (q0, q2) as a big-endian
    # bitstring, repetition 0 gives 0b11 = 3 (not heavy) and repetition 1 gives
    # 0b01 = 1 (heavy).
    result = cirq.Result.from_single_parameter_set(
        params=cirq.ParamResolver({}),
        measurements={
            '0': np.array([[1], [0]]),
            '1': np.array([[0], [1]]),
            '2': np.array([[1], [1]]),
            '3': np.array([[0], [0]]),
        },
    )
    sampler.run = MagicMock(return_value=result)
    circuit = cirq.Circuit(cirq.measure(*cirq.LineQubit.range(4)))
    compilation_result = CompilationResult(
        circuit=circuit,
        mapping={q: q for q in cirq.LineQubit.range(4)},
        parity_map={cirq.LineQubit(0): cirq.LineQubit(1), cirq.LineQubit(2): cirq.LineQubit(3)},
    )
    probability = cirq.contrib.quantum_volume.sample_heavy_set(
        compilation_result, [1], sampler=sampler, repetitions=1
    )
    # Both samples are valid and exactly one of them is heavy, so half of the
    # valid samples land in the heavy set.
    assert probability == 0.5
118
119
def test_compile_circuit_router():
    """compile_circuit must delegate routing to the router it is given."""
    router = MagicMock()
    graph = ccr.gridqubits_to_graph_device(TestDevice().qubits)

    cirq.contrib.quantum_volume.compile_circuit(
        cirq.Circuit(), device_graph=graph, router=router, routing_attempts=1
    )

    router.assert_called()
130
131
def test_compile_circuit():
    """Tests that we are able to compile a model circuit."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    a, b, c = cirq.LineQubit.range(3)
    model_circuit = cirq.Circuit([cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)])])
    # Build the device graph once so that the routing target and the
    # consistency check below are guaranteed to be the same graph.
    device_graph = ccr.gridqubits_to_graph_device(TestDevice().qubits)

    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=device_graph,
        compiler=compiler_mock,
        routing_attempts=1,
    )

    # The mapping has one entry per model-circuit qubit.
    assert len(compilation_result.mapping) == 3
    # Every compiled operation must be consistent with the device graph.
    assert ccr.ops_are_consistent_with_device_graph(
        compilation_result.circuit.all_operations(), device_graph
    )
    # The identity compiler mock returns its input, so its last argument is
    # exactly the resulting circuit.
    compiler_mock.assert_called_with(compilation_result.circuit)
154
155
def test_compile_circuit_replaces_swaps():
    """Tests that the compiler never sees the SwapPermutationGates from the
    router."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    a, b, c = cirq.LineQubit.range(3)
    # The three CNOTs form an interaction triangle. A grid graph has no 3-cycles,
    # so no static placement makes all three pairs adjacent and routing must
    # insert at least one swap.
    model_circuit = cirq.Circuit(
        [
            cirq.Moment([cirq.CNOT(a, b)]),
            cirq.Moment([cirq.CNOT(a, c)]),
            cirq.Moment([cirq.CNOT(b, c)]),
        ]
    )
    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        model_circuit,
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        routing_attempts=1,
    )
    compiled = compilation_result.circuit

    # The identity compiler mock was handed exactly the resulting circuit, so
    # inspecting the result inspects what the compiler saw.
    compiler_mock.assert_called_with(compiled)
    # Assert that there were some swaps in the result...
    assert list(compiled.findall_operations_with_gate_type(cirq.ops.SwapPowGate))
    # ...and that none of them remained SwapPermutationGates.
    assert not list(
        compiled.findall_operations_with_gate_type(cirq.contrib.acquaintance.SwapPermutationGate)
    )
195
196
def test_compile_circuit_with_readout_correction():
    """Readout error correction appends an X / CNOT-to-parity / X check per qubit."""
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)
    # Router stub: wraps the circuit unchanged with an empty initial mapping.
    router_mock = MagicMock(side_effect=lambda circuit, network: ccr.SwapNetwork(circuit, {}))
    a, b, c = cirq.LineQubit.range(3)
    ap, bp, cp = cirq.LineQubit.range(3, 6)
    payload = cirq.Moment([cirq.X(a), cirq.Y(b), cirq.Z(c)])

    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        cirq.Circuit([payload]),
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        router=router_mock,
        routing_attempts=1,
        add_readout_error_correction=True,
    )

    flip_all = cirq.Moment([cirq.X(a), cirq.X(b), cirq.X(c)])
    copy_to_parity = cirq.Moment([cirq.CNOT(a, ap), cirq.CNOT(b, bp), cirq.CNOT(c, cp)])
    expected = cirq.Circuit([payload, flip_all, copy_to_parity, flip_all])
    assert compilation_result.circuit == expected
226
227
def test_compile_circuit_multiple_routing_attempts():
    """With several routing attempts, compile_circuit keeps the best routing."""
    qubits = cirq.LineQubit.range(3)
    identity_mapping = {q: q for q in qubits}

    # Candidate routings, handed out by the mocked router in this order: one with
    # extra operations, the well-routed one, and one touching an extra qubit.
    more_operations = cirq.Circuit([cirq.X.on_each(qubits), cirq.Y.on_each(qubits)])
    well_routed = cirq.Circuit([cirq.X.on_each(qubits)])
    more_qubits = cirq.Circuit([cirq.X.on_each(cirq.LineQubit.range(4))])
    router_mock = MagicMock(
        side_effect=[
            ccr.SwapNetwork(candidate, identity_mapping)
            for candidate in (more_operations, well_routed, more_qubits)
        ]
    )
    compiler_mock = MagicMock(side_effect=lambda circuit: circuit)

    compilation_result = cirq.contrib.quantum_volume.compile_circuit(
        cirq.Circuit([cirq.X.on_each(qubits)]),
        device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
        compiler=compiler_mock,
        router=router_mock,
        routing_attempts=3,
    )

    assert compilation_result.mapping == identity_mapping
    assert router_mock.call_count == 3
    # The compiler's last call received the well-routed candidate.
    compiler_mock.assert_called_with(well_routed)
269
270
def test_compile_circuit_no_routing_attempts():
    """Requesting zero routing attempts raises an AssertionError."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    model_circuit = cirq.Circuit([cirq.Moment([cirq.X(q0), cirq.Y(q1), cirq.Z(q2)])])

    with pytest.raises(AssertionError, match='Unable to get routing for circuit'):
        cirq.contrib.quantum_volume.compile_circuit(
            model_circuit,
            device_graph=ccr.gridqubits_to_graph_device(TestDevice().qubits),
            routing_attempts=0,
        )
287
288
def test_calculate_quantum_volume_result():
    """The main loop returns the seeded model circuit and its heavy set, and the
    results serialize to JSON."""
    qv = cirq.contrib.quantum_volume
    results = qv.calculate_quantum_volume(
        num_qubits=3,
        depth=3,
        num_circuits=1,
        device_graph=ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3)),
        samplers=[cirq.Simulator()],
        routing_attempts=2,
        random_state=1,
    )

    # Regenerate the circuit from the same seed to compare against.
    expected_circuit = qv.generate_model_circuit(3, 3, random_state=1)
    assert len(results) == 1
    first = results[0]
    assert first.model_circuit == expected_circuit
    assert first.heavy_set == qv.compute_heavy_set(expected_circuit)
    # Serializing the results must not raise.
    cirq.to_json(results, io.StringIO())
308
309
def test_calculate_quantum_volume_result_with_device_graph():
    """Test that running the main loop routes the circuit onto the given device
    graph."""
    device_qubits = [cirq.GridQubit(i, j) for i in range(2) for j in range(3)]
    device_graph = ccr.gridqubits_to_graph_device(device_qubits)

    results = cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=3,
        depth=3,
        num_circuits=1,
        device_graph=device_graph,
        samplers=[cirq.Simulator()],
        routing_attempts=2,
        random_state=1,
    )

    assert len(results) == 1
    # Check against the exact graph that was passed in, rather than an
    # independently rebuilt 2x3 grid that would silently diverge if
    # `device_qubits` were changed.
    assert ccr.ops_are_consistent_with_device_graph(
        results[0].compiled_circuit.all_operations(), device_graph
    )
329
330
def test_calculate_quantum_volume_loop():
    """calculate_quantum_volume runs end to end without raising."""
    graph = ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3))
    # A single circuit and only two routing attempts keep this test fast.
    cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=5,
        depth=5,
        num_circuits=1,
        routing_attempts=2,
        random_state=1,
        device_graph=graph,
        samplers=[cirq.Simulator()],
    )
344
345
def test_calculate_quantum_volume_loop_with_readout_correction():
    """calculate_quantum_volume runs end to end with readout error correction
    enabled, without raising."""
    graph = ccr.gridqubits_to_graph_device(cirq.GridQubit.rect(3, 3))
    # A single circuit and only two routing attempts keep this test fast.
    cirq.contrib.quantum_volume.calculate_quantum_volume(
        num_qubits=4,
        depth=4,
        num_circuits=1,
        routing_attempts=2,
        random_state=1,
        device_graph=graph,
        samplers=[cirq.Simulator()],
        add_readout_error_correction=True,
    )
361