1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Defines trial results."""
15
16import collections
17import io
18from typing import (
19    Any,
20    Callable,
21    Dict,
22    Iterable,
23    Optional,
24    Sequence,
25    TYPE_CHECKING,
26    Tuple,
27    TypeVar,
28    Union,
29    cast,
30)
31
32import numpy as np
33import pandas as pd
34
35from cirq import value, ops
36from cirq._compat import proper_repr
37from cirq.study import resolver
38
39if TYPE_CHECKING:
40    import cirq
41
# Generic type of the countable value produced by a caller-supplied
# `fold_func` in the histogram methods of `Result`.
T = TypeVar('T')
# The accepted ways to name a measurement: the key string itself, a single
# qubit, or an iterable of qubits. `_key_to_str` turns any of these into the
# string key.
TMeasurementKey = Union[str, 'cirq.Qid', Iterable['cirq.Qid']]
44
45
def _tuple_of_big_endian_int(bit_groups: Iterable[Any]) -> Tuple[int, ...]:
    """Converts each group of bits into its big-endian integer.

    Args:
        bit_groups: An iterable of bit sequences. Within each sequence the
            most significant bit comes first and the 1s bit comes last.

    Returns:
        A tuple with one integer per group, in the same order as the groups.
    """
    return tuple(map(value.big_endian_bits_to_int, bit_groups))
57
58
def _bitstring(vals: Iterable[Any]) -> str:
    """Renders a sequence of measured values as one compact string.

    Every value is passed through `int` and then rendered as text. When all
    rendered values are exactly one character long they are concatenated
    directly; otherwise they are separated by single spaces so that
    multi-character values remain distinguishable.

    Args:
        vals: The values to render.

    Returns:
        The rendered string (empty when `vals` is empty).
    """
    pieces = []
    needs_spaces = False
    for v in vals:
        text = str(int(v))
        if len(text) != 1:
            needs_spaces = True
        pieces.append(text)
    return (' ' if needs_spaces else '').join(pieces)
63
64
def _keyed_repeated_bitstrings(vals: Dict[str, np.ndarray]) -> str:
    """Formats keyed measurement arrays as one `key=...` line per key.

    Keys are emitted in sorted order. For each key, every column of the array
    is rendered with `_bitstring` (one string running over the rows), and the
    column strings are joined with ', '.

    Args:
        vals: Maps each key to an array indexed as [row, column].

    Returns:
        The lines joined with newlines (empty when `vals` is empty).
    """
    lines = []
    for key in sorted(vals):
        reps = vals[key]
        # The column count is read from the first row; with no rows there is
        # nothing to render for this key.
        width = len(reps[0]) if len(reps) else 0
        columns = [_bitstring(reps[:, col]) for col in range(width)]
        lines.append(f'{key}=' + ', '.join(columns))
    return '\n'.join(lines)
73
74
def _key_to_str(key: TMeasurementKey) -> str:
    """Normalizes a measurement key to its string form.

    Args:
        key: A string (returned unchanged), a single `Qid` (rendered with
            `str`), or an iterable of items (each rendered with `str` and
            joined with commas).

    Returns:
        The string form of the key.
    """
    if isinstance(key, str):
        return key
    # Treat a lone qubit as a one-element group so both cases share the join.
    qubits = [key] if isinstance(key, ops.Qid) else key
    return ','.join(str(q) for q in qubits)
81
82
class Result:
    """The results of multiple executions of a circuit with fixed parameters.

    Stored as a Pandas DataFrame that can be accessed through the "data"
    attribute. The repetition number is the row index and measurement keys
    are the columns of the DataFrame. Each element is a big endian integer
    representation of measurement outcomes for the measurement key in that
    repetition.  See `cirq.big_endian_int_to_bits` and similar functions
    for how to convert this integer into bits.

    Attributes:
        params: A ParamResolver of settings used when sampling result.
    """

    def __init__(
        self,
        *,  # Forces keyword args.
        params: resolver.ParamResolver,
        measurements: Dict[str, np.ndarray],
    ) -> None:
        """Inits Result.

        Args:
            params: A ParamResolver of settings used for this result.
            measurements: A dictionary from measurement gate key to measurement
                results. The value for each key is a 2-D array of booleans,
                with the first index running over the repetitions, and the
                second index running over the qubits for the corresponding
                measurements.
        """
        self.params = params
        # Built lazily by the `data` property and cached afterwards.
        self._data: Optional[pd.DataFrame] = None
        self._measurements = measurements

    @property
    def data(self) -> pd.DataFrame:
        """The measurements as a DataFrame of big endian integers.

        Columns are measurement keys and rows are repetitions. The frame is
        built on first access and cached.

        The frame uses `np.int64` when every value fits in a signed 64-bit
        integer, and `object` (holding Python ints) otherwise.
        """
        if self._data is None:
            # Convert to a DataFrame with columns as measurement keys, rows as
            # repetitions and a big endian integer for individual measurements.
            converted_dict = {
                key: [value.big_endian_bits_to_int(m_vals) for m_vals in val]
                for key, val in self._measurements.items()
            }
            # Requesting int64 for values that do not fit is not reliable
            # across pandas versions (it can raise OverflowError instead of
            # falling back to object), so pick the dtype explicitly.
            int64_max = np.iinfo(np.int64).max
            fits_in_int64 = all(
                v <= int64_max for column in converted_dict.values() for v in column
            )
            dtype = np.int64 if fits_in_int64 else object
            self._data = pd.DataFrame(converted_dict, dtype=dtype)
        return self._data

    @staticmethod
    def from_single_parameter_set(
        *,  # Forces keyword args.
        params: resolver.ParamResolver,
        measurements: Dict[str, np.ndarray],
    ) -> 'Result':
        """Packages runs of a single parameterized circuit into a Result.

        Args:
            params: A ParamResolver of settings used for this result.
            measurements: A dictionary from measurement gate key to measurement
                results. The value for each key is a 2-D array of booleans,
                with the first index running over the repetitions, and the
                second index running over the qubits for the corresponding
                measurements.

        Returns:
            A new `Result` holding the given params and measurements.
        """
        return Result(params=params, measurements=measurements)

    @property
    def measurements(self) -> Dict[str, np.ndarray]:
        """The measurement dictionary this result was constructed with."""
        return self._measurements

    @property
    def repetitions(self) -> int:
        """The number of repetitions, or 0 when there are no measurements."""
        if not self.measurements:
            return 0
        # Get the length quickly from one of the keyed results.
        return len(next(iter(self.measurements.values())))

    # Reason for 'type: ignore': https://github.com/python/mypy/issues/5273
    def multi_measurement_histogram(  # type: ignore
        self,
        *,  # Forces keyword args.
        keys: Iterable[TMeasurementKey],
        fold_func: Callable[[Tuple], T] = cast(Callable[[Tuple], T], _tuple_of_big_endian_int),
    ) -> collections.Counter:
        """Counts the number of times combined measurement results occurred.

        This is a more general version of the 'histogram' method. Instead of
        only counting how often results occurred for one specific measurement,
        this method tensors multiple measurement results together and counts
        how often the combined results occurred.

        For example, suppose that:

            - fold_func is not specified
            - keys=['abc', 'd']
            - the measurement with key 'abc' measures qubits a, b, and c.
            - the measurement with key 'd' measures qubit d.
            - the circuit was sampled 3 times.
            - the sampled measurement values were:
                1. a=1 b=0 c=0 d=0
                2. a=0 b=1 c=0 d=1
                3. a=1 b=0 c=0 d=0

        Then the counter returned by this method will be:

            collections.Counter({
                (0b100, 0): 2,
                (0b010, 1): 1
            })


        Where '0b100' is binary for '4' and '0b010' is binary for '2'. Notice
        that the bits are combined in a big-endian way by default, with the
        first measured qubit determining the highest-value bit.

        Args:
            keys: Keys of measurements to include in the histogram.
            fold_func: A function used to convert sampled measurement results
                into countable values. The input is a tuple containing the
                list of bits measured by each measurement specified by the
                keys argument. If this argument is not specified, it defaults
                to returning tuples of integers, where each integer is the big
                endian interpretation of the bits a measurement sampled.

        Returns:
            A counter indicating how often measurements sampled various
            results.

        Raises:
            KeyError: If a key is not present in the measurements.
        """
        fixed_keys = tuple(_key_to_str(key) for key in keys)
        samples: Iterable[Any]
        if fixed_keys:
            # One tuple per repetition, holding that repetition's bits for
            # each requested key.
            samples = zip(*(self.measurements[sub_key] for sub_key in fixed_keys))
        else:
            # Zipping nothing would yield no samples at all; with no keys each
            # repetition still contributes one empty combined sample.
            samples = [()] * self.repetitions
        c: collections.Counter = collections.Counter()
        for sample in samples:
            c[fold_func(sample)] += 1
        return c

    # Reason for 'type: ignore': https://github.com/python/mypy/issues/5273
    def histogram(  # type: ignore
        self,
        *,  # Forces keyword args.
        key: TMeasurementKey,
        fold_func: Callable[[Tuple], T] = cast(Callable[[Tuple], T], value.big_endian_bits_to_int),
    ) -> collections.Counter:
        """Counts the number of times a measurement result occurred.

        For example, suppose that:

            - fold_func is not specified
            - key='abc'
            - the measurement with key 'abc' measures qubits a, b, and c.
            - the circuit was sampled 3 times.
            - the sampled measurement values were:
                1. a=1 b=0 c=0
                2. a=0 b=1 c=0
                3. a=1 b=0 c=0

        Then the counter returned by this method will be:

            collections.Counter({
                0b100: 2,
                0b010: 1
            })

        Where '0b100' is binary for '4' and '0b010' is binary for '2'. Notice
        that the bits are combined in a big-endian way by default, with the
        first measured qubit determining the highest-value bit.

        Args:
            key: Key of the measurement to include in the histogram.
            fold_func: A function used to convert a sampled measurement result
                into a countable value. The input is a list of bits sampled
                together by a measurement. If this argument is not specified,
                it defaults to interpreting the bits as a big endian
                integer.

        Returns:
            A counter indicating how often a measurement sampled various
            results.
        """
        # The multi-key version hands fold_func a 1-tuple; unwrap it so the
        # caller's fold_func sees just this measurement's bits.
        return self.multi_measurement_histogram(keys=[key], fold_func=lambda e: fold_func(e[0]))

    def __repr__(self) -> str:
        def item_repr(entry):
            key, val = entry
            return f'{key!r}: {proper_repr(val)}'

        measurement_dict_repr = (
            '{' + ', '.join([item_repr(e) for e in self.measurements.items()]) + '}'
        )

        return f'cirq.Result(params={self.params!r}, measurements={measurement_dict_repr})'

    def _repr_pretty_(self, p: Any, cycle: bool) -> None:
        """Output to show in ipython and Jupyter notebooks."""
        if cycle:
            # There should never be a cycle.  This is just in case.
            p.text('Result(...)')
        else:
            p.text(str(self))

    def __str__(self) -> str:
        return _keyed_repeated_bitstrings(self.measurements)

    def __eq__(self, other):
        """Compares by the integer `data` frames and by `params`."""
        if not isinstance(other, type(self)):
            return NotImplemented
        return self.data.equals(other.data) and self.params == other.params

    def _measurement_shape(self):
        """Returns the params plus each key's second-axis length."""
        return self.params, {k: v.shape[1] for k, v in self.measurements.items()}

    def __add__(self, other: 'cirq.Result') -> 'cirq.Result':
        """Concatenates the repetitions of two compatible results.

        Args:
            other: The result whose repetitions are appended after this one's.

        Returns:
            A new `Result` with this result's params and, for every key, the
            two measurement arrays joined along the first axis.

        Raises:
            ValueError: If the params differ, or the keys or their
                second-axis lengths differ.
        """
        if not isinstance(other, type(self)):
            return NotImplemented
        if self._measurement_shape() != other._measurement_shape():
            raise ValueError(
                'TrialResults do not have the same parameters or do '
                'not have the same measurement keys.'
            )
        all_measurements: Dict[str, np.ndarray] = {}
        for key in other.measurements:
            all_measurements[key] = np.append(
                self.measurements[key], other.measurements[key], axis=0
            )
        return Result(params=self.params, measurements=all_measurements)

    def _json_dict_(self):
        """Returns a dictionary for serialization, with packed measurements.

        Each measurement array is encoded with `_pack_digits`; its dtype name
        and shape are stored next to the packed string.
        """
        packed_measurements = {}
        for key, digits in self.measurements.items():
            packed_digits, binary = _pack_digits(digits)
            packed_measurements[key] = {
                'packed_digits': packed_digits,
                'binary': binary,
                'dtype': digits.dtype.name,
                'shape': digits.shape,
            }
        return {
            'cirq_type': self.__class__.__name__,
            'params': self.params,
            'measurements': packed_measurements,
        }

    @classmethod
    def _from_json_dict_(cls, params, measurements, **kwargs):
        """Rebuilds an instance from the dictionary made by `_json_dict_`.

        Each measurement entry is decoded with `_unpack_digits`. Extra keyword
        arguments are accepted and ignored.
        """
        return cls(
            params=params,
            measurements={key: _unpack_digits(**val) for key, val in measurements.items()},
        )
333
334
335# TODO(#3388) Add documentation for Raises.
336# pylint: disable=missing-raises-doc
def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool]:
    """Encodes an array of digits as a hex string.

    Args:
        digits: A numpy array.
        pack_bits: If 'auto' (the default), automatically pack binary digits
            using `np.packbits` to save space. If 'never', do not pack binary
            digits. If 'force', use `np.packbits` without checking for
            compatibility.

    Returns:
        A pair of the hex-encoded string and a boolean that is True when the
        digits were packed as binary values, and False when the string is the
        hex form of the array saved with `np.save`.

    Raises:
        ValueError: If `pack_bits` is not 'auto', 'force', or 'never'.
    """
    if pack_bits == 'force':
        return _pack_bits(digits), True
    # Validate up front so the logic below only deals with 'auto' and 'never'.
    if pack_bits not in ('auto', 'never'):
        raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.")

    if pack_bits == 'auto':
        # Digits are binary exactly when converting to booleans loses nothing;
        # in that case they can be packed more compactly.
        as_bools = digits.astype(np.bool_)
        if np.array_equal(digits, as_bools):
            return _pack_bits(as_bools), True

    with io.BytesIO() as buffer:
        np.save(buffer, digits, allow_pickle=False)
        return buffer.getvalue().hex(), False
366
367
368# pylint: enable=missing-raises-doc
def _pack_bits(bits: np.ndarray) -> str:
    """Packs an array of bits with `np.packbits` and hex-encodes the bytes.

    Args:
        bits: The array of bits to pack.

    Returns:
        The hexadecimal string of the packed bytes.
    """
    packed = np.packbits(bits)
    return packed.tobytes().hex()
371
372
def _unpack_digits(
    packed_digits: str, binary: bool, dtype: Union[None, str], shape: Union[None, Sequence[int]]
) -> np.ndarray:
    """The opposite of `_pack_digits`.

    Args:
        packed_digits: The hex-encoded string representing a numpy array of
            digits. This is the first return value of `_pack_digits`.
        binary: Whether the digits have been packed as binary. This is the
            second return value of `_pack_digits`.
        dtype: If `binary` is True, you must also provide the datatype of the
            array. Otherwise, dtype information is contained within the hex
            string.
        shape: If `binary` is True, you must also provide the shape of the
            array. Otherwise, shape information is contained within the hex
            string.

    Returns:
        The decoded numpy array.
    """
    if not binary:
        # The hex string is a saved numpy array, which already records its
        # own dtype and shape.
        with io.BytesIO(bytes.fromhex(packed_digits)) as buffer:
            return np.load(buffer, allow_pickle=False)
    return _unpack_bits(packed_digits, cast(str, dtype), cast(Sequence[int], shape))
401
402
def _unpack_bits(packed_bits: str, dtype: str, shape: Sequence[int]) -> np.ndarray:
    """Decodes a hex string of packed bits into an array.

    Args:
        packed_bits: The hex-encoded packed bits.
        dtype: The dtype to convert the unpacked bits to.
        shape: The shape to give the resulting array.

    Returns:
        The unpacked bits, reshaped to `shape` and converted to `dtype`.
    """
    raw = np.frombuffer(bytes.fromhex(packed_bits), dtype=np.uint8)
    unpacked = np.unpackbits(raw)
    # Unpacking yields a multiple of 8 bits; keep only as many as the shape
    # holds and drop the trailing ones.
    n_bits = np.prod(shape).item()
    return unpacked[:n_bits].reshape(shape).astype(dtype)
407