1import itertools
2from typing import (
3    Any,
4    Dict,
5    Iterable,
6    List,
7    Mapping,
8    Optional,
9    Union,
10)
11
12import numpy as np
13from scipy.sparse import csr_matrix
14
15from cirq import value
16from cirq.ops import raw_types
17
18
19def _check_qids_dimension(qids):
20    """A utility to check that we only have Qubits."""
21    for qid in qids:
22        if qid.dimension != 2:
23            raise ValueError(f"Only qubits are supported, but {qid} has dimension {qid.dimension}")
24
25
@value.value_equality(approximate=True)
class ProjectorString:
    """A scaled product of computational-basis projectors on qubits.

    Represents ``coefficient * prod_q |b_q><b_q|``, where each qubit ``q`` in
    ``projector_dict`` is projected onto the basis state ``b_q`` (0 or 1).
    Qubits that do not appear in ``projector_dict`` are acted on by the identity.
    """

    def __init__(
        self,
        projector_dict: Dict[raw_types.Qid, int],
        coefficient: Union[int, float, complex] = 1,
    ):
        """Constructor for ProjectorString.

        Args:
            projector_dict: A python dictionary mapping from cirq.Qid to integers. A key value pair
                represents the desired computational basis state (0 or 1) for that qubit. The
                dictionary is copied, so later mutation of the caller's dict does not affect
                this object.
            coefficient: Initial scalar coefficient. Defaults to 1.

        Raises:
            ValueError: If any qid is not a qubit (dimension != 2), or if any basis-state
                value is not 0 or 1.
        """
        _check_qids_dimension(projector_dict.keys())
        # Copy and normalize the values to plain ints. Reasons:
        #  - the object is hashable by value, so aliasing the caller's dict would let
        #    outside mutation change its hash;
        #  - numpy indexing in the expectation methods would silently misread -1
        #    (as "last element") or a bool (as a mask).
        normalized = {}
        for qid, bit in projector_dict.items():
            if bit not in (0, 1):
                raise ValueError(f"Projector values must be 0 or 1, but {qid} maps to {bit!r}")
            normalized[qid] = int(bit)
        self._projector_dict = normalized
        self._coefficient = complex(coefficient)

    @property
    def projector_dict(self) -> Dict[raw_types.Qid, int]:
        """The qubit -> basis-state (0 or 1) mapping; do not mutate it."""
        return self._projector_dict

    @property
    def coefficient(self) -> complex:
        """The scalar coefficient multiplying the projector."""
        return self._coefficient

    def matrix(self, projector_qids: Optional[Iterable[raw_types.Qid]] = None) -> csr_matrix:
        """Returns the matrix of self in computational basis of qubits.

        Args:
            projector_qids: Ordered collection of qubits that determine the subspace
                in which the matrix representation of the ProjectorString is to
                be computed. Qubits in this collection that are absent from
                projector_dict are acted on by the identity; qubits of projector_dict
                that are absent from this collection are ignored. Defaults to the
                qubits of the projector_dict. Any iterable, including a one-shot
                generator, is accepted.

        Returns:
            A sparse diagonal matrix that is the projection in the specified basis,
            scaled by the coefficient. The first qubit is the most significant one.

        Raises:
            ValueError: If any qid in projector_qids is not a qubit.
        """
        # Materialize once. The qids are traversed several times below, so a generator
        # would be exhausted by the first pass and produce a silently wrong matrix.
        qids = list(self._projector_dict if projector_qids is None else projector_qids)
        _check_qids_dimension(qids)

        # Per qid, the basis values that survive the projection.
        idx_to_keep = [
            [self._projector_dict[qid]] if qid in self._projector_dict else [0, 1] for qid in qids
        ]

        total_d = np.prod([qid.dimension for qid in qids], dtype=np.int64)

        # The projector is diagonal. Collect the flat (big-endian) index of every
        # basis state it keeps.
        ones_idx = []
        for idx in itertools.product(*idx_to_keep):
            d = total_d
            kron_idx = 0
            for i, qid in zip(idx, qids):
                d //= qid.dimension
                kron_idx += i * d
            ones_idx.append(kron_idx)

        return csr_matrix(
            ([self._coefficient] * len(ones_idx), (ones_idx, ones_idx)), shape=(total_d, total_d)
        )

    def _get_idx_to_keep(self, qid_map: Mapping[raw_types.Qid, int]):
        """Builds a numpy index tuple selecting the projected basis states.

        Every axis starts as the full slice ``0:2``. The axis of each projected qubit
        (looked up through ``qid_map``) is replaced by that qubit's basis value.
        """
        num_qubits = len(qid_map)
        idx_to_keep: List[Any] = [slice(0, 2)] * num_qubits
        for q in self.projector_dict.keys():
            idx_to_keep[qid_map[q]] = self.projector_dict[q]
        return tuple(idx_to_keep)

    def expectation_from_state_vector(
        self,
        state_vector: np.ndarray,
        qid_map: Mapping[raw_types.Qid, int],
    ) -> complex:
        """Expectation of the projection from a state vector.

        Computes the expectation value of this ProjectorString on the provided state vector.

        Args:
            state_vector: An array representing a valid state vector. It is reshaped to
                ``(2,) * len(qid_map)``.
            qid_map: A map from all qubits used in this ProjectorString to the
                indices of the qubits that `state_vector` is defined over.

        Returns:
            The expectation value of the input state.

        Raises:
            ValueError: If any qid in qid_map is not a qubit.
        """
        _check_qids_dimension(qid_map.keys())
        num_qubits = len(qid_map)
        index = self._get_idx_to_keep(qid_map)
        # Total probability of the amplitudes consistent with the projector.
        return self._coefficient * np.sum(
            np.abs(np.reshape(state_vector, (2,) * num_qubits)[index]) ** 2
        )

    def expectation_from_density_matrix(
        self,
        state: np.ndarray,
        qid_map: Mapping[raw_types.Qid, int],
    ) -> complex:
        """Expectation of the projection from a density matrix.

        Computes the expectation value of this ProjectorString on the provided state.

        Args:
            state: An array representing a valid density matrix. It is reshaped to
                ``(2,) * (2 * len(qid_map))``.
            qid_map: A map from all qubits used in this ProjectorString to the
                indices of the qubits that `state` is defined over.

        Returns:
            The expectation value of the input state.

        Raises:
            ValueError: If any qid in qid_map is not a qubit.
        """
        _check_qids_dimension(qid_map.keys())
        num_qubits = len(qid_map)
        # The same projection is applied to the row axes and the column axes.
        index = self._get_idx_to_keep(qid_map) * 2
        result = np.reshape(state, (2,) * (2 * num_qubits))[index]
        # Remaining free axes come as (rows..., cols...). Trace them out pairwise,
        # first row axis with first column axis, until a scalar is left.
        while any(result.shape):
            result = np.trace(result, axis1=0, axis2=len(result.shape) // 2)
        return self._coefficient * result

    def __repr__(self) -> str:
        return (
            f"cirq.ProjectorString(projector_dict={self._projector_dict},"
            + f"coefficient={self._coefficient})"
        )

    def _json_dict_(self) -> Dict[str, Any]:
        """Serializes the projector dict as a list of (qid, value) pairs."""
        return {
            'cirq_type': self.__class__.__name__,
            'projector_dict': list(self._projector_dict.items()),
            'coefficient': self._coefficient,
        }

    @classmethod
    def _from_json_dict_(cls, projector_dict, coefficient, **kwargs):
        """Inverse of `_json_dict_`: rebuilds the dict from its list of pairs."""
        return cls(projector_dict=dict(projector_dict), coefficient=coefficient)

    def _value_equality_values_(self) -> Any:
        # Sorted so that equality and hashing do not depend on dict insertion order.
        projector_dict = sorted(self._projector_dict.items())
        return (tuple(projector_dict), self._coefficient)
162