1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tests for state_histogram."""
16
17import numpy as np
18from matplotlib import pyplot as plt
19import matplotlib as mpl
20
21import cirq
22from cirq.devices import GridQubit
23from cirq.vis import state_histogram
24
25
def test_get_state_histogram():
    """Two flipped qubits measured under separate keys put all 5 counts in the last bin."""
    sim = cirq.Simulator()
    first, second = GridQubit(0, 0), GridQubit(1, 0)

    circuit = cirq.Circuit()
    circuit.append([cirq.X(first), cirq.X(second)])
    circuit.append([cirq.measure(first, key='q0'), cirq.measure(second, key='q1')])
    run_result = sim.run(program=circuit, repetitions=5)

    # Four possible two-bit outcomes; only the all-ones outcome is ever observed.
    histogram = state_histogram.get_state_histogram(run_result)
    np.testing.assert_equal(histogram, [0.0, 0.0, 0.0, 5.0])
40
41
def test_get_state_histogram_multi_1():
    """A single 4-qubit measurement with qubits 1-3 flipped puts every count in bin 7."""
    line = cirq.LineQubit.range(4)
    circuit = cirq.Circuit(
        cirq.X.on_each(*line[1:]),
        # All four qubits share one multi-qubit measurement key.
        cirq.measure(*line),
    )
    samples = cirq.sample(circuit, repetitions=5)

    # 16 possible outcomes; only 0b0111 (qubits 1, 2, 3 set) is observed.
    expected = [0] * 16
    expected[0b0111] = 5
    np.testing.assert_equal(state_histogram.get_state_histogram(samples), expected)
52
53
def test_get_state_histogram_multi_2():
    """Mixing one multi-qubit and several single-qubit measurements gives the same histogram."""
    line = cirq.LineQubit.range(4)
    circuit = cirq.Circuit(
        cirq.X.on_each(*line[1:]),
        # Qubits 0 and 1 share one measurement key ...
        cirq.measure(*line[:2]),
        # ... while qubits 2 and 3 are each measured under their own key.
        cirq.measure_each(*line[2:]),
    )
    samples = cirq.sample(circuit, repetitions=5)

    # 16 possible outcomes; only 0b0111 (qubits 1, 2, 3 set) is observed.
    expected = [0] * 16
    expected[0b0111] = 5
    np.testing.assert_equal(state_histogram.get_state_histogram(samples), expected)
65
66
def test_plot_state_histogram_result():
    """Plotting a ``cirq.Result`` draws the same bars as plotting its explicit bin values."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits),  # One multi-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
    fig, (ax1, ax2) = plt.subplots(1, 2)
    try:
        state_histogram.plot_state_histogram(r, ax1)
        state_histogram.plot_state_histogram(expected_values, ax2)
        # Compare the complete, ordered set of rectangles on each axis. Pairing
        # children with zip() would silently skip index-mismatched artists, and
        # it would pass vacuously if no bars were drawn at all.
        rects1 = [str(p) for p in ax1.get_children() if isinstance(p, mpl.patches.Rectangle)]
        rects2 = [str(p) for p in ax2.get_children() if isinstance(p, mpl.patches.Rectangle)]
        assert rects1
        assert rects1 == rects2
    finally:
        # pyplot keeps figures alive until closed explicitly; release this one so
        # figures do not accumulate across the test session.
        plt.close(fig)
81
82
def test_plot_state_histogram_collection():
    """Plotting a measurement-key histogram draws the same bars as explicit values + labels."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits),  # One multi-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    try:
        state_histogram.plot_state_histogram(r.histogram(key='0,1,2,3'), ax1)
        # The histogram only contains the observed outcome (7), so the
        # equivalent explicit plot is a single bar labelled '7'.
        expected_values = [5]
        tick_label = ['7']
        state_histogram.plot_state_histogram(
            expected_values, ax2, tick_label=tick_label, xlabel=None
        )
        # Compare the complete, ordered set of rectangles on each axis. Pairing
        # children with zip() would silently skip index-mismatched artists, and
        # it would pass vacuously if no bars were drawn at all.
        rects1 = [str(p) for p in ax1.get_children() if isinstance(p, mpl.patches.Rectangle)]
        rects2 = [str(p) for p in ax2.get_children() if isinstance(p, mpl.patches.Rectangle)]
        assert rects1
        assert rects1 == rects2
    finally:
        # pyplot keeps figures alive until closed explicitly; release this one so
        # figures do not accumulate across the test session.
        plt.close(fig)
98