# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tool to visualize the results of a study."""

from typing import Union, Optional, Sequence, SupportsFloat
import collections
import numpy as np
import matplotlib.pyplot as plt
import cirq.study.result as result


def get_state_histogram(result: 'result.Result') -> np.ndarray:
    """Computes a state histogram from a single result with repetitions.

    For every repetition, the measured bits of all measurement keys are
    concatenated, in the iteration order of `result.measurements`. The bits
    are then read as a big-endian binary number: the first bit of the first
    key is the most significant.

    Args:
        result: The trial result containing measurement results from which the
            state histogram should be computed.

    Returns:
        The state histogram as a float numpy array of length
        `2 ** num_qubits`. Entry i counts the repetitions that produced
        state i.

    Raises:
        ValueError: If a measured value is not 0/1 (or False/True).
    """
    measurements = list(result.measurements.values())
    num_qubits = sum(value.shape[1] for value in measurements)
    # measurements is a dict of {measurement key: array(repetitions, bits)}.
    # Stack the arrays side by side to get one row per repetition and one
    # column per measured qubit, e.g.
    #   {q1: array([[True], [True]]), q2: array([[False], [False]])}
    #   --> array([[True, False], [True, False]])
    bits = np.hstack(measurements).astype(np.int64)
    # Only binary outcomes can be interpreted as a base-2 state index.
    if bits.size and (bits.min() < 0 or bits.max() > 1):
        raise ValueError('Measurement results must be binary (0 or 1).')
    # Big-endian place values, so a row [1, 0] -> 0b10 -> state index 2.
    place_values = 1 << np.arange(num_qubits - 1, -1, -1, dtype=np.int64)
    state_indices = bits @ place_values
    # minlength keeps unobserved states as explicit zeros. Cast to float to
    # keep the float64 dtype this function has always returned.
    return np.bincount(state_indices, minlength=2**num_qubits).astype(float)


def plot_state_histogram(
    data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]],
    ax: Optional['plt.Axis'] = None,
    *,
    tick_label: Optional[Sequence[str]] = None,
    xlabel: Optional[str] = 'qubit state',
    ylabel: Optional[str] = 'result count',
    title: Optional[str] = 'Result State Histogram',
) -> 'plt.Axis':
    """Plot the state histogram from either a single result with repetitions or
    a histogram computed using `result.histogram()` or a flattened histogram
    of measurement results computed using `get_state_histogram`.

    Args:
        data: The histogram values to plot. Possible options are:
            `result.Result`: Histogram is computed using
                `get_state_histogram` and all 2 ** num_qubits values are
                plotted, including 0s.
            `collections.Counter`: Only (key, value) pairs present in
                collection are plotted, in sorted key order, and the keys
                are used as tick labels. An empty Counter plots no bars.
            `Sequence[SupportsFloat]`: Values in the input sequence are
                plotted. i'th entry corresponds to height of the i'th
                bar in histogram.
        ax: The Axes to plot on. If not given, a new figure is created,
            plotted on, and shown.
        tick_label: Tick labels for the histogram plot in case input is not
            `collections.Counter` (ignored for Counter input). If None or
            empty, the label for the i'th entry is the integer index i.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.
        title: Title of the plot.

    Returns:
        The axis that was plotted on.
    """
    # Compare with None explicitly; the truth value of an Axes object is not
    # part of matplotlib's API.
    show_fig = ax is None
    if show_fig:
        fig, ax = plt.subplots(1, 1)
    if isinstance(data, result.Result):
        values = get_state_histogram(data)
    elif isinstance(data, collections.Counter):
        # Build the lists explicitly. An empty Counter then yields no bars
        # instead of a ValueError from unpacking an empty zip().
        items = sorted(data.items())
        tick_label = [key for key, _ in items]
        values = [count for _, count in items]
    else:
        values = data
    # tick_label may be a numpy array, whose truth value is ambiguous, so test
    # for None / emptiness explicitly rather than with `not tick_label`.
    if tick_label is None or len(tick_label) == 0:
        tick_label = np.arange(len(values))
    ax.bar(np.arange(len(values)), values, tick_label=tick_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if show_fig:
        fig.show()
    return ax