# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for state_histogram."""

import numpy as np
from matplotlib import pyplot as plt
import matplotlib as mpl

import cirq
from cirq.devices import GridQubit
from cirq.vis import state_histogram


def _rectangle_reprs(ax):
    """Return ``str()`` of every Rectangle artist among ``ax``'s children, in order."""
    return [str(child) for child in ax.get_children() if isinstance(child, mpl.patches.Rectangle)]


def _assert_same_rectangles(ax1, ax2):
    """Assert that two axes hold identical Rectangle artists (e.g. the histogram bars).

    Comparing the complete filtered lists also fails when one axes has more or
    fewer rectangles than the other. Zipping the raw children pairwise would
    silently truncate and let such a mismatch pass.
    """
    assert _rectangle_reprs(ax1) == _rectangle_reprs(ax2)


def test_get_state_histogram():
    """Two separately measured qubits flipped to |1> land all counts in state 0b11."""
    simulator = cirq.Simulator()

    q0 = GridQubit(0, 0)
    q1 = GridQubit(1, 0)
    circuit = cirq.Circuit()
    circuit.append([cirq.X(q0), cirq.X(q1)])
    circuit.append([cirq.measure(q0, key='q0'), cirq.measure(q1, key='q1')])
    result = simulator.run(program=circuit, repetitions=5)

    values_to_plot = state_histogram.get_state_histogram(result)
    expected_values = [0.0, 0.0, 0.0, 5.0]

    np.testing.assert_equal(values_to_plot, expected_values)


def test_get_state_histogram_multi_1():
    """One 4-qubit measurement of |0111> puts all 5 repetitions in bin 7."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits),  # One multi-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    values_to_plot = state_histogram.get_state_histogram(r)
    expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
    np.testing.assert_equal(values_to_plot, expected_values)


def test_get_state_histogram_multi_2():
    """Mixing multi- and single-qubit measurement keys yields the same bin-7 histogram."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits[:2]),  # One multi-qubit measurement
        cirq.measure_each(*qubits[2:]),  # Multiple single-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    values_to_plot = state_histogram.get_state_histogram(r)
    expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
    np.testing.assert_equal(values_to_plot, expected_values)


def test_plot_state_histogram_result():
    """Plotting a Result draws the same bars as plotting its explicit histogram values."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits),  # One multi-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0]
    fig, (ax1, ax2) = plt.subplots(1, 2)
    try:
        state_histogram.plot_state_histogram(r, ax1)
        state_histogram.plot_state_histogram(expected_values, ax2)
        _assert_same_rectangles(ax1, ax2)
    finally:
        # Close the figure so the test session does not accumulate open figures.
        plt.close(fig)


def test_plot_state_histogram_collection():
    """Plotting a measurement-key histogram matches an explicit single bar labelled '7'."""
    qubits = cirq.LineQubit.range(4)
    c = cirq.Circuit(
        cirq.X.on_each(*qubits[1:]),
        cirq.measure(*qubits),  # One multi-qubit measurement
    )
    r = cirq.sample(c, repetitions=5)
    fig, (ax1, ax2) = plt.subplots(1, 2)
    try:
        state_histogram.plot_state_histogram(r.histogram(key='0,1,2,3'), ax1)
        expected_values = [5]
        tick_label = ['7']
        state_histogram.plot_state_histogram(
            expected_values, ax2, tick_label=tick_label, xlabel=None
        )
        _assert_same_rectangles(ax1, ax2)
    finally:
        # Close the figure so the test session does not accumulate open figures.
        plt.close(fig)