1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
from typing import Callable, cast, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np

from cirq import protocols, value
from cirq.ops import raw_types, pauli_string
from cirq.ops.measurement_gate import MeasurementGate
from cirq.ops.pauli_measurement_gate import PauliMeasurementGate

if TYPE_CHECKING:
    import cirq
26
27
def _default_measurement_key(qubits: Iterable[raw_types.Qid]) -> str:
    """Builds a measurement key by joining each qubit's ``str`` form with commas."""
    return ','.join(map(str, qubits))
30
31
def measure_single_paulistring(
    pauli_observable: pauli_string.PauliString,
    key: Optional[Union[str, value.MeasurementKey]] = None,
) -> raw_types.Operation:
    """Builds one PauliMeasurementGate operation that measures a Pauli observable.

    The gate receives the observable's Paulis and is applied to the qubits the
    observable acts on, in the observable's own iteration order.

    Args:
        pauli_observable: The `cirq.PauliString` observable to measure.
        key: Optional `str` or `cirq.MeasurementKey` for the gate. When omitted,
            the key is the comma-joined `str` of the observable's qubits.

    Returns:
        A single operation measuring the Pauli observable.

    Raises:
        ValueError: if `pauli_observable` is not a `cirq.PauliString`.
    """
    if isinstance(pauli_observable, pauli_string.PauliString):
        measurement_key = _default_measurement_key(pauli_observable) if key is None else key
        gate = PauliMeasurementGate(pauli_observable.values(), measurement_key)
        return gate.on(*pauli_observable.keys())
    raise ValueError(
        f'Pauli observable {pauli_observable} should be an instance of cirq.PauliString.'
    )
57
58
def measure_paulistring_terms(
    pauli_basis: pauli_string.PauliString, key_func: Callable[[raw_types.Qid], str] = str
) -> List[raw_types.Operation]:
    """Builds one single-qubit PauliMeasurementGate operation per qubit of a basis.

    Each qubit in `pauli_basis` is measured on its own, using the Pauli that
    `pauli_basis` assigns to it.

    Args:
        pauli_basis: The `cirq.PauliString` giving the basis to measure each
            qubit in.
        key_func: Maps a qubit to the key of its measurement. Defaults to str.

    Returns:
        The per-qubit measurement operations, in the iteration order of
        `pauli_basis`.

    Raises:
        ValueError: if `pauli_basis` is not a `cirq.PauliString`.
    """
    if isinstance(pauli_basis, pauli_string.PauliString):
        operations = []
        for qubit in pauli_basis:
            single_qubit_gate = PauliMeasurementGate([pauli_basis[qubit]], key=key_func(qubit))
            operations.append(single_qubit_gate.on(qubit))
        return operations
    raise ValueError(
        f'Pauli observable {pauli_basis} should be an instance of cirq.PauliString.'
    )
82
83
def measure(
    *target: Union['cirq.Qid', Iterable['cirq.Qid']],
    key: Optional[Union[str, value.MeasurementKey]] = None,
    invert_mask: Optional[Tuple[bool, ...]] = (),
) -> raw_types.Operation:
    """Returns a single MeasurementGate applied to all the given qubits.

    The qubits are measured in the computational basis.

    Args:
        *target: The qubits that the measurement gate should measure. They may
            be passed either as separate positional arguments
            (``measure(q0, q1)``) or as a single iterable of qubits
            (``measure([q0, q1])``).
        key: Optional `str` or `cirq.MeasurementKey` that the gate should use.
            If this is None, it defaults to a comma-separated list of the
            target qubits' str values.
        invert_mask: A sequence of Truthy or Falsey values indicating whether
            the corresponding qubits should be flipped. The default empty
            tuple, or None, indicates no inverting should be done.

    Returns:
        An operation targeting the given qubits with a measurement.

    Raises:
        ValueError: If a target is a numpy ndarray, or if any target is not an
            instance of Qid.
    """
    # Accept ``measure([q0, q1])`` as well as ``measure(q0, q1)``. Strings,
    # bytes, Qids and numpy arrays are deliberately not unpacked: an ndarray
    # must reach the check below so the caller sees the measure_state_vector hint.
    one_iterable_arg = (
        len(target) == 1
        and isinstance(target[0], Iterable)
        and not isinstance(target[0], (str, bytes, np.ndarray, raw_types.Qid))
    )
    # Materialize to a tuple: the targets are iterated several times below.
    targets = cast(Tuple['cirq.Qid', ...], tuple(target[0]) if one_iterable_arg else target)

    for qubit in targets:
        if isinstance(qubit, np.ndarray):
            raise ValueError(
                'measure() was called a numpy ndarray. Perhaps you meant '
                'to call measure_state_vector on numpy array?'
            )
        elif not isinstance(qubit, raw_types.Qid):
            raise ValueError('measure() was called with type different than Qid.')

    if key is None:
        key = _default_measurement_key(targets)
    # Normalize None here so the documented "no inverting" behavior does not
    # depend on how MeasurementGate itself treats a None mask.
    if invert_mask is None:
        invert_mask = ()
    qid_shape = protocols.qid_shape(targets)
    return MeasurementGate(len(targets), key, invert_mask, qid_shape).on(*targets)
120
121
def measure_each(
    *qubits: Union['cirq.Qid', Iterable['cirq.Qid']],
    key_func: Callable[[raw_types.Qid], str] = str,
) -> List[raw_types.Operation]:
    """Returns a list of operations individually measuring the given qubits.

    The qubits are measured in the computational basis.

    Args:
        *qubits: The qubits to measure. They may be passed either as separate
            positional arguments (``measure_each(q0, q1)``) or as a single
            iterable of qubits (``measure_each([q0, q1])``).
        key_func: Determines the key of the measurements of each qubit. Takes
            the qubit and returns the key for that qubit. Defaults to str.

    Returns:
        A list of operations individually measuring the given qubits.
    """
    # Mirror measure(): unpack a lone iterable argument, but never a str/bytes
    # or a Qid (in case a Qid subclass happens to be iterable).
    one_iterable_arg = (
        len(qubits) == 1
        and isinstance(qubits[0], Iterable)
        and not isinstance(qubits[0], (str, bytes, raw_types.Qid))
    )
    targets = cast(Iterable['cirq.Qid'], qubits[0] if one_iterable_arg else qubits)
    return [MeasurementGate(1, key_func(q), qid_shape=(q.dimension,)).on(q) for q in targets]
138