1# Copyright 2020 The Cirq developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import dataclasses
16from typing import Union, Iterable, Dict, TYPE_CHECKING, ItemsView, Tuple, FrozenSet
17
18from cirq import ops, value, protocols
19
20if TYPE_CHECKING:
21    import cirq
22    from cirq.value.product_state import _NamedOneQubitState
23
24
@dataclasses.dataclass(frozen=True)
class InitObsSetting:
    """A pair of initial state and observable.

    Usually, given a circuit you want to iterate through many
    InitObsSettings to vary the initial state preparation and output
    observable.

    Attributes:
        init_state: The product state the qubits are prepared in.
        observable: The Pauli string to measure. Its qubits must be a
            subset of `init_state`'s qubits.

    Raises:
        ValueError: On construction, if `observable` acts on any qubit that
            `init_state` does not cover.
    """

    init_state: value.ProductState
    observable: ops.PauliString

    def __post_init__(self):
        # Special validation for this dataclass: dataclasses call
        # __post_init__ after the generated __init__ has set both fields.
        init_qs = self.init_state.qubits
        obs_qs = self.observable.qubits
        # Use a (non-strict) subset test. A strict-superset test (`>`) would
        # miss incomparable sets, e.g. an observable on q1 paired with an
        # init_state on q0 only.
        if not set(obs_qs) <= set(init_qs):
            raise ValueError(
                "`observable`'s qubits should be a subset of those "
                "found in `init_state`. "
                f"observable qubits: {obs_qs}. init_state qubits: {init_qs}"
            )

    def __str__(self):
        return f'{self.init_state} → {self.observable}'

    def __repr__(self):
        return (
            f'cirq.work.InitObsSetting('
            f'init_state={self.init_state!r}, '
            f'observable={self.observable!r})'
        )

    def _json_dict_(self):
        return protocols.dataclass_json_dict(self)
60
61
def _max_weight_observable(observables: Iterable[ops.PauliString]) -> Union[None, ops.PauliString]:
    """Merge observables into the highest-weight observable compatible with all.

    Each qubit that appears in any input observable is assigned that input's
    non-identity single-qubit Pauli in the result. If two inputs disagree on
    the Pauli for the same qubit, they do not share a tensor product basis
    and `None` is returned instead.

    For example, merging ["XI", "IZ"] gives "XZ", while merging
    ["XI", "ZI"] gives None.

    The result need not be one of the inputs, and input coefficients are
    discarded.
    """
    merged: Dict['cirq.Qid', 'cirq.Pauli'] = {}
    for obs in observables:
        for q, pauli in obs.items():
            # `setdefault` stores the first Pauli seen on `q` and hands back
            # whatever is stored; a mismatch means the inputs conflict.
            if merged.setdefault(q, pauli) != pauli:
                return None
    return ops.PauliString(merged)
87
88
def _max_weight_state(states: Iterable[value.ProductState]) -> Union[None, value.ProductState]:
    """Merge product states into the largest state compatible with all of them.

    The returned ProductState assigns every qubit mentioned by any input the
    single-qubit state that input gives it. If two inputs assign different
    single-qubit states to the same qubit, they are incompatible and `None`
    is returned.

    For example, merging [+X(0), -Z(1)] gives "+X(0) * -Z(1)", while merging
    [+X(0), +Z(0)] gives None.
    """
    merged: Dict['cirq.Qid', '_NamedOneQubitState'] = {}
    for product_state in states:
        for q, one_qubit_state in product_state:
            # Record the first state seen on `q`; any later disagreement
            # makes the inputs incompatible.
            if merged.setdefault(q, one_qubit_state) != one_qubit_state:
                return None
    return value.ProductState(merged)
111
112
def zeros_state(qubits: Iterable['cirq.Qid']):
    """Build the all-zeros ProductState |00..00> over `qubits`.

    Every qubit is mapped to `value.KET_ZERO`.
    """
    return value.ProductState(dict.fromkeys(qubits, value.KET_ZERO))
116
117
def observables_to_settings(
    observables: Iterable['cirq.PauliString'], qubits: Iterable['cirq.Qid']
) -> Iterable[InitObsSetting]:
    """Transform observables to InitObsSettings initialized in the all-zeros state.

    Args:
        observables: The observables to wrap. One setting is yielded per
            observable, in iteration order.
        qubits: The qubits the all-zeros initial state is defined on. It is
            iterated only once, so a one-shot iterator (e.g. a generator)
            is acceptable.

    Yields:
        An `InitObsSetting` pairing `zeros_state(qubits)` with each observable.
    """
    # Build the initial state once. It is the same for every observable, and
    # rebuilding it inside the loop would exhaust a one-shot `qubits`
    # iterator after the first setting, leaving later settings with an empty
    # init_state.
    init_state = zeros_state(qubits)
    for observable in observables:
        yield InitObsSetting(init_state=init_state, observable=observable)
126
127
128def _fix_precision(val: float, precision) -> int:
129    """Convert floating point numbers to (implicitly) fixed point integers.
130
131    Circuit parameters can be floats but we also need to use them as
132    dictionary keys. We secretly use these fixed-precision integers.
133    """
134    return int(val * precision)
135
136
def _hashable_param(
    param_tuples: ItemsView[str, float], precision: float = 1e7
) -> FrozenSet[Tuple[str, int]]:
    """Make circuit parameters hashable using fixed precision.

    Circuit parameters can be floats but we also need to use them as
    dictionary keys. We secretly use these fixed-precision integers.

    Args:
        param_tuples: (name, value) pairs, e.g. `dict.items()`.
        precision: Scale factor passed to `_fix_precision` for each value.

    Returns:
        A frozenset of (name, fixed-point integer) pairs. The values are ints
        produced by `_fix_precision`, not floats.
    """
    return frozenset((k, _fix_precision(v, precision)) for k, v in param_tuples)
146
147
@dataclasses.dataclass(frozen=True)
class _MeasurementSpec:
    """Everything needed to specify one run of a quantum processor.

    Bundles the maximal input-output setting, which may allow many
    observables to be measured at once if they are consistent with
    `max_setting`, together with the circuit parameter assignment used when
    the circuit is parameterized.
    """

    max_setting: InitObsSetting
    circuit_params: Dict[str, float]

    def __hash__(self):
        # A dict cannot be hashed directly, so hash a frozen fixed-precision
        # view of its items alongside the setting.
        params_key = _hashable_param(self.circuit_params.items())
        return hash((self.max_setting, params_key))

    def __repr__(self):
        setting_repr = repr(self.max_setting)
        params_repr = repr(self.circuit_params)
        return (
            'cirq.work._MeasurementSpec('
            f'max_setting={setting_repr}, circuit_params={params_repr})'
        )

    def _json_dict_(self):
        return protocols.dataclass_json_dict(self)
172