1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16from typing import (
17    Any,
18    Callable,
19    Iterable,
20    Optional,
21    Tuple,
22    TypeVar,
23    TYPE_CHECKING,
24)
25
26from cirq.ops import raw_types
27
28if TYPE_CHECKING:
29    from cirq.ops import qubit_order_or_list
30
31
# Type variables used by `QubitOrder.map` to relate the two qubit
# representations it converts between.
# Qubit type understood by the underlying (wrapped) basis.
TInternalQubit = TypeVar('TInternalQubit')
# Qubit type understood by the caller of the mapped basis.
TExternalQubit = TypeVar('TExternalQubit')
34
35
class QubitOrder:
    """Defines the kronecker product order of qubits.

    A `QubitOrder` wraps a function that takes the qubits of interest and
    returns them as an ordered tuple. Instances are normally created through
    the `explicit` and `sorted_by` factories, converted from other values with
    `as_qubit_order`, or derived from an existing order with `map`.
    """

    def __init__(
        self, explicit_func: Callable[[Iterable[raw_types.Qid]], Tuple[raw_types.Qid, ...]]
    ) -> None:
        """Initializes the order from an ordering function.

        Args:
            explicit_func: A function that takes the qubits to be ordered and
                returns them as a tuple in basis order. It is invoked by
                `order_for`.
        """
        self._explicit_func = explicit_func

    DEFAULT = None  # type: QubitOrder
    """A basis that orders qubits in the same way that calling `sorted` does.

    Specifically, qubits are ordered first by their type name and then by
    whatever comparison value qubits of a given type provide (e.g. for LineQubit
    it is the x coordinate of the qubit).
    """

    @staticmethod
    def explicit(
        fixed_qubits: Iterable[raw_types.Qid], fallback: Optional['QubitOrder'] = None
    ) -> 'QubitOrder':
        """A basis that contains exactly the given qubits in the given order.

        Args:
            fixed_qubits: The qubits in basis order.
            fallback: A fallback order to use for extra qubits not in the
                fixed_qubits list. Extra qubits will always come after the
                fixed_qubits, but will be ordered based on the fallback. If no
                fallback is specified, a ValueError is raised when extra qubits
                are specified.

        Returns:
            A Basis instance that forces the given qubits in the given order.

        Raises:
            ValueError: If a qubit appears more than once in `fixed_qubits`.
                The returned basis additionally raises a ValueError from
                `order_for` when it is given qubits that are not in
                `fixed_qubits` and no fallback was specified.
        """
        result = tuple(fixed_qubits)
        # Built once here so that every `order_for` call on the returned basis
        # reuses it instead of rebuilding the set of fixed qubits.
        fixed_set = frozenset(result)
        if len(fixed_set) < len(result):
            raise ValueError(f'Qubits appear in fixed_order twice: {result}.')

        def func(qubits):
            remaining = set(qubits) - fixed_set
            if not remaining:
                return result
            # `None` is the "no fallback" sentinel, so test for it explicitly
            # rather than relying on the truthiness of the fallback object.
            if fallback is None:
                raise ValueError(f'Unexpected extra qubits: {remaining}.')
            # Fixed qubits always come first; the extras follow in fallback order.
            return result + fallback.order_for(remaining)

        return QubitOrder(func)

    @staticmethod
    def sorted_by(key: Callable[[raw_types.Qid], Any]) -> 'QubitOrder':
        """A basis that orders qubits ascending based on a key function.

        Args:
            key: A function that takes a qubit and returns a key value. The
                basis will be ordered ascending according to these key values.

        Returns:
            A basis that orders qubits ascending based on a key function.
        """
        return QubitOrder(lambda qubits: tuple(sorted(qubits, key=key)))

    def order_for(self, qubits: Iterable[raw_types.Qid]) -> Tuple[raw_types.Qid, ...]:
        """Returns a qubit tuple ordered corresponding to the basis.

        Args:
            qubits: Qubits that should be included in the basis. (Additional
                qubits may be added into the output by the basis.)

        Returns:
            A tuple of qubits in the same order that their single-qubit
            matrices would be passed into `np.kron` when producing a matrix for
            the entire system.
        """
        return self._explicit_func(qubits)

    @staticmethod
    def as_qubit_order(val: 'qubit_order_or_list.QubitOrderOrList') -> 'QubitOrder':
        """Converts a value into a basis.

        Args:
            val: An iterable or a basis.

        Returns:
            The basis implied by the value: an explicit basis over the items
            of an iterable, or the value itself when it is already a basis.

        Raises:
            ValueError: If `val` is neither an iterable nor a basis.
        """
        if isinstance(val, Iterable):
            return QubitOrder.explicit(val)
        if isinstance(val, QubitOrder):
            return val
        raise ValueError(f"Don't know how to interpret <{val}> as a Basis.")

    def map(
        self,
        internalize: Callable[[TExternalQubit], TInternalQubit],
        externalize: Callable[[TInternalQubit], TExternalQubit],
    ) -> 'QubitOrder':
        """Transforms the Basis so that it applies to wrapped qubits.

        Args:
            externalize: Converts an internal qubit understood by the underlying
                basis into an external qubit understood by the caller.
            internalize: Converts an external qubit understood by the caller
                into an internal qubit understood by the underlying basis.

        Returns:
            A basis that transforms qubits understood by the caller into qubits
            understood by an underlying basis, uses that to order the qubits,
            then wraps the ordered qubits back up for the caller.
        """

        def func(qubits):
            # Unwrap, order with this basis, then re-wrap in the resulting order.
            unwrapped_qubits = [internalize(q) for q in qubits]
            unwrapped_result = self.order_for(unwrapped_qubits)
            return tuple(externalize(q) for q in unwrapped_result)

        return QubitOrder(func)
158
159
# The default basis sorts qubits using each qubit itself as the sort key.
QubitOrder.DEFAULT = QubitOrder.sorted_by(key=lambda qubit: qubit)
161