1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Tool to visualize the results of a study."""
16
17from typing import Union, Optional, Sequence, SupportsFloat
18import collections
19import numpy as np
20import matplotlib.pyplot as plt
21import cirq.study.result as result
22
23
def get_state_histogram(result: 'result.Result') -> np.ndarray:
    """Computes a state histogram from a single result with repetitions.

    The arrays of all measurement keys are concatenated (in the iteration
    order of `result.measurements`) into one bitstring per repetition. Each
    bitstring is read as a big-endian integer: the first bit of the first key
    is the most significant bit.

    Args:
        result: The trial result containing measurement results from which the
                state histogram should be computed.

    Returns:
        The state histogram (a float numpy array of length 2 ** num_qubits,
        where num_qubits is the total number of measured bits). The i'th entry
        counts the repetitions whose concatenated bitstring equals i.

    Raises:
        ValueError: If `result` contains no measurements, or if a measured
            value is not 0/1 (or False/True).
    """
    measurement_arrays = list(result.measurements.values())
    if not measurement_arrays:
        raise ValueError('Cannot compute a state histogram: result has no measurements.')
    # measurements is a dict of {measurement key: array(repetitions, num_bits)}.
    # Stack the arrays side by side so each row holds every bit of one
    # repetition, e.g.
    #   {q1: array([[True], [True]]), q2: array([[False], [False]])}
    #   --> array([[True, False], [True, False]])
    bits = np.hstack(measurement_arrays).astype(np.int64)
    if bits.size and (bits.min() < 0 or bits.max() > 1):
        # Only binary outcomes map onto the 2 ** num_qubits state indices.
        raise ValueError('State histograms require binary (0/1) measurement results.')
    num_qubits = bits.shape[1]
    # Weights 2**(n-1), ..., 2, 1 turn each row into its big-endian integer,
    # e.g. [1, 0] -> 0b10 -> 2.
    weights = 1 << np.arange(num_qubits - 1, -1, -1, dtype=np.int64)
    state_indices = bits @ weights
    # minlength guarantees an entry for every state, including unseen ones.
    counts = np.bincount(state_indices, minlength=2**num_qubits)
    return counts.astype(float)
50
51
def plot_state_histogram(
    data: Union['result.Result', collections.Counter, Sequence[SupportsFloat]],
    ax: Optional['plt.Axes'] = None,
    *,
    tick_label: Optional[Sequence[str]] = None,
    xlabel: Optional[str] = 'qubit state',
    ylabel: Optional[str] = 'result count',
    title: Optional[str] = 'Result State Histogram',
) -> 'plt.Axes':
    """Plot a histogram of measurement outcomes as a bar chart.

    The input may be a single result with repetitions, a histogram computed
    using `result.histogram()`, or a flattened histogram of measurement
    results computed using `get_state_histogram`.

    Args:
        data:   The histogram values to plot. Possible options are:
                `result.Result`: Histogram is computed using
                    `get_state_histogram` and all 2 ** num_qubits values are
                    plotted, including 0s.
                `collections.Counter`: Only (key, value) pairs present in
                    collection are plotted, sorted by key. An empty Counter
                    produces an empty chart.
                `Sequence[SupportsFloat]`: Values in the input sequence are
                    plotted. i'th entry corresponds to height of the i'th
                    bar in histogram.
        ax:      The Axes to plot on. If not given, a new figure is created,
                 plotted on, and shown.
        tick_label: Tick labels for the histogram plot in case input is not
                    `collections.Counter` (for a Counter, its sorted keys are
                    used instead). If None or empty, the label for the i'th
                    entry is the string form of i.
        xlabel:  Label for the x-axis.
        ylabel:  Label for the y-axis.
        title:   Title of the plot.

    Returns:
        The axes that were plotted on.
    """
    # Only create (and later show) a figure when the caller supplied no axes.
    show_fig = ax is None
    if show_fig:
        fig, ax = plt.subplots(1, 1)
    if isinstance(data, result.Result):
        values = get_state_histogram(data)
    elif isinstance(data, collections.Counter):
        if data:
            # The Counter's keys (in sorted order) become the tick labels.
            tick_label, values = zip(*sorted(data.items()))
        else:
            # zip(*[]) yields nothing, so unpacking would fail on an empty Counter.
            tick_label, values = (), ()
    else:
        values = data
    # An explicit check: `not tick_label` raises for numpy arrays.
    if tick_label is None or len(tick_label) == 0:
        tick_label = [str(i) for i in range(len(values))]
    ax.bar(np.arange(len(values)), values, tick_label=tick_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if show_fig:
        fig.show()
    return ax
105