1# Copyright 2020 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Union, Tuple, Sequence, List, Optional
16
17import numpy as np
18
19import cirq
20from cirq import ops
21from cirq import optimizers as opt
22
23
24# TODO(#3388) Add documentation for Raises.
25# pylint: disable=missing-raises-doc
def three_qubit_matrix_to_operations(
    q0: ops.Qid, q1: ops.Qid, q2: ops.Qid, u: np.ndarray, atol: float = 1e-8
) -> Sequence[ops.Operation]:
    """Returns operations for a 3 qubit unitary.

    The algorithm is described in Shende et al.:
    Synthesis of Quantum Logic Circuits. Tech. rep. 2006,
    https://arxiv.org/abs/quant-ph/0406176

    The matrix is split by a cosine-sine decomposition into
    ``(u1 ⊕ u2) @ CS @ (v1h ⊕ v2h)``. The outer factors are two-qubit
    multiplexors acting on (q1, q2) and selected by q0. ``CS`` is a
    multiplexed Ry rotation on q0. The returned operations are in time
    order: the ``v1h ⊕ v2h`` multiplexor first, then CS, then ``u1 ⊕ u2``.

    Args:
        q0: first qubit
        q1: second qubit
        q2: third qubit
        u: unitary matrix
        atol: A limit on the amount of absolute error introduced by the
            construction.

    Returns:
        The resulting operations will have only known two-qubit and one-qubit
        gates based operations, namely CZ, CNOT and rx, ry, PhasedXPow gates.

    Raises:
        ValueError: If the u matrix is non-unitary or not of shape (8,8).
        ImportError: If SciPy 1.5.0+ (which provides
            ``scipy.linalg.cossin``) is not available.
    """
    if np.shape(u) != (8, 8):
        raise ValueError(f"Expected unitary matrix with shape (8,8) got {np.shape(u)}")
    if not cirq.is_unitary(u, atol=atol):
        raise ValueError(f"Matrix is not unitary: {u}")

    try:
        from scipy.linalg import cossin
    except ImportError as err:  # coverage: ignore
        # coverage: ignore
        raise ImportError(
            "cirq.three_qubit_matrix_to_operations requires "
            "SciPy 1.5.0+, as it uses the cossin function. Please"
            " upgrade scipy in your environment to use this "
            "function!"
        ) from err
    # u == block_diag(u1, u2) @ CS(theta) @ block_diag(v1h, v2h)
    (u1, u2), theta, (v1h, v2h) = cossin(u, 4, 4, separate=True)

    cs_ops = _cs_to_ops(q0, q1, q2, theta)
    if len(cs_ops) > 0 and cs_ops[-1] == cirq.CZ(q2, q0):
        # Optimization A.1: merge the trailing CZ of the CS part into the
        # u1 ⊕ u2 multiplexor that follows it in time.
        # CZ(c, a) == CZ(a, c) because CZ is symmetric.
        # u1(b, c) is the operator applied when a = |0>, and u2(b, c) is the
        # operator applied when a = |1>. On (b, c), the CZ is the identity
        # when a = |0> and I ⊗ Z when a = |1>. The merge therefore amounts
        # to phasing u2 with I ⊗ Z.
        u2 = u2 @ np.diag([1, -1, 1, -1])
        cs_ops = cs_ops[:-1]

    # The u multiplexor splits off a diagonal on (q1, q2). CS is built only
    # from ry on q0 and CZs, so it commutes with that diagonal. The diagonal
    # can therefore be handed to the v multiplexor and merged there.
    d_ud, ud_ops = _two_qubit_multiplexor_to_ops(q0, q1, q2, u1, u2, shift_left=True, atol=atol)

    _, vdh_ops = _two_qubit_multiplexor_to_ops(
        q0, q1, q2, v1h, v2h, shift_left=False, diagonal=d_ud, atol=atol
    )

    return list(cirq.Circuit(vdh_ops + cs_ops + ud_ops).all_operations())
86
87
88# pylint: enable=missing-raises-doc
def _cs_to_ops(q0: ops.Qid, q1: ops.Qid, q2: ops.Qid, theta: np.ndarray) -> List[ops.Operation]:
    """Converts theta angles based Cosine Sine matrix to operations.

    Using the optimization as per Appendix A.1, it uses CZ gates instead of
    CNOT gates. The multiplexed Ry acts on q0 and is selected by (q1, q2).

    The returned sequence is not stripped of its terminal CZ. When it ends in
    CZ(q2, q0), the caller (three_qubit_matrix_to_operations) folds that gate
    into the neighbouring multiplexor.

    Args:
        q0: first qubit (target of the ry rotations)
        q1: second qubit
        q2: third qubit
        theta: theta returned from the Cosine Sine decomposition

    Returns:
        The operations as a list. The list may be shorter than 8 after
        simplification, and is empty if the whole CS block is the identity.
    """
    # Note: we are using *2 as the thetas are already half angles from the
    # CSD decomposition, but cirq.ry takes full angles.
    angles = _multiplexed_angles(theta * 2)
    rys = [cirq.ry(angle).on(q0) for angle in angles]
    # Named `cs_ops` (not `ops`) so the imported `cirq.ops` module is not shadowed.
    cs_ops = [
        rys[0],
        cirq.CZ(q1, q0),
        rys[1],
        cirq.CZ(q2, q0),
        rys[2],
        cirq.CZ(q1, q0),
        rys[3],
        cirq.CZ(q2, q0),
    ]
    # Normalize to a list: the simplifier may hand back an empty cirq.Circuit
    # when the whole multiplexor reduces to the identity.
    return list(_optimize_multiplexed_angles_circuit(cs_ops))
119
120
121# TODO(#3388) Add documentation for Args.
122# pylint: disable=missing-param-doc
def _two_qubit_multiplexor_to_ops(
    q0: ops.Qid,
    q1: ops.Qid,
    q2: ops.Qid,
    u1: np.ndarray,
    u2: np.ndarray,
    shift_left: bool = True,
    diagonal: Optional[np.ndarray] = None,
    atol: float = 1e-8,
) -> Tuple[Optional[np.ndarray], List[ops.Operation]]:
    r"""Converts a two qubit double multiplexor to circuit.

    Input: U_1 ⊕ U_2, with select qubit a (i.e. a = |0> => U_1(b,c),
    a = |1> => U_2(b,c)).

    We want this:
        $$
        U_1 ⊕ U_2 = (V ⊕ V) @ (D ⊕ D^{\dagger}) @ (W ⊕ W)
        $$
    We can get it via:
        $$
        U_1 = V @ D @ W       (1)
        U_2 = V @ D^{\dagger} @ W (2)
        $$

    We can derive
        $$
        U_1 U_2^{\dagger}= V @ D^2 @ V^{\dagger}, (3)
        $$

    i.e the eigendecomposition of $U_1 U_2^{\dagger}$ will give us D and V,
    and (2) then gives $W = D V^{\dagger} U_2$.

    After computing V, D and W, the function builds operations for V and W
    on qubits b, c (q1, q2) and for the middle diagonal multiplexor on
    a, b, c. The operations use only known two-qubit and one-qubit gates,
    namely CZ, CNOT and rx, ry, PhasedXPow gates.

    Args:
        q0: first qubit (the select qubit a)
        q1: second qubit (b)
        q2: third qubit (c)
        u1: two-qubit operation on b,c for a = |0>
        u2: two-qubit operation on b,c for a = |1>
        shift_left: if True, a diagonal is split off W and returned instead
            of being implemented. If False, W is decomposed completely and
            no diagonal is returned.
        diagonal: an optional diagonal on (b, c) to be merged with this
            multiplexor. It is left-multiplied onto V, i.e. it acts after
            this multiplexor in time.
        atol: absolute tolerance forwarded to the two-qubit decompositions.

    Returns:
        A tuple ``(d, operations)``. ``d`` is the diagonal split off W when
        ``shift_left`` is True and None otherwise. ``operations`` lists, in
        time order, the operations for W, then the middle multiplexor, then V.
    """
    # Eigendecomposition (3): U_1 U_2^† = V D^2 V^†.
    eigvals, eigvecs = cirq.unitary_eig(u1 @ u2.conj().T)
    sqrt_d = np.diag(np.sqrt(eigvals))
    # From (2): W = D V^† U_2.
    w_mat = sqrt_d @ eigvecs.conj().T @ u2

    mid_ops = _middle_multiplexor_to_ops(q0, q1, q2, eigvals)

    v_mat = eigvecs if diagonal is None else diagonal @ eigvecs
    d_v, right_ops = opt.two_qubit_matrix_to_diagonal_and_operations(q1, q2, v_mat, atol=atol)
    # The diagonal split off V commutes with the (diagonal) middle
    # multiplexor, so it is pushed into W.
    w_mat = d_v @ w_mat

    if not shift_left:
        # End of the circuit: nothing is left to absorb a diagonal, so fall
        # back to a full KAK decomposition of W.
        left_ops = opt.two_qubit_matrix_to_operations(
            q1, q2, w_mat, allow_partial_czs=False, atol=atol
        )
        return None, left_ops + mid_ops + right_ops

    d_w, left_ops = opt.two_qubit_matrix_to_diagonal_and_operations(q1, q2, w_mat, atol=atol)
    return d_w, left_ops + mid_ops + right_ops
200
201
202# pylint: enable=missing-param-doc
def _optimize_multiplexed_angles_circuit(
    operations: Sequence[ops.Operation],
) -> List[ops.Operation]:
    """Removes two qubit gates that amount to identity.

    Exploiting the specific multiplexed structure, this method looks ahead
    to find stripes of 3 or 4 consecutive CZ or CNOT gates and simplifies them.

    The input is the 8-operation pattern built by ``_cs_to_ops`` or
    ``_middle_multiplexor_to_ops``. It consists of single-qubit rotations on
    one target qubit, interleaved with two-qubit gates controlled alternately
    by the two other qubits. Once negligible rotations are dropped, the
    two-qubit gates that become adjacent all commute and each squares to the
    identity. Therefore:

    * a stripe of 4 (``A B A B``) cancels completely;
    * a stripe of 3 (``A B A``) reduces to its middle gate ``B``.

    Only the first such stripe is simplified.

    Args:
        operations: operations to be optimized
    Returns:
        the optimized operations, always as a list (empty when the whole
        sequence is numerically the identity)
    """
    circuit = cirq.Circuit(operations)
    cirq.optimizers.DropNegligible().optimize_circuit(circuit)
    if np.allclose(circuit.unitary(), np.eye(8), atol=1e-14):
        # Return a list (not an empty Circuit) so every path yields the
        # same type and callers can concatenate results as plain lists.
        return []

    remaining = list(circuit.all_operations())

    def _stripe_length(start: int) -> int:
        """Counts consecutive two-qubit operations in `remaining` from `start`."""
        end = start
        while end < len(remaining) and remaining[end].gate.num_qubits() == 2:
            end += 1
        return end - start

    for i in range(len(remaining)):
        stripe = _stripe_length(i)
        if stripe == 4:
            # A B A B == I: drop the whole stripe.
            return remaining[:i] + remaining[i + 4 :]
        if stripe == 3:
            # A B A == B: keep only the middle gate.
            return remaining[:i] + [remaining[i + 1]] + remaining[i + 3 :]
    return remaining
240
241
def _middle_multiplexor_to_ops(
    q0: ops.Qid, q1: ops.Qid, q2: ops.Qid, eigvals: np.ndarray
) -> List[ops.Operation]:
    """Converts the middle diagonal multiplexor D ⊕ D^† to operations.

    The multiplexor is implemented as a multiplexed rz on q0, selected by
    (q1, q2), using CNOT gates (see ``_multiplexed_angles`` for the layout).

    Args:
        q0: first qubit (target of the rz rotations)
        q1: second qubit (controls the 1st and 3rd CNOT)
        q2: third qubit (controls the 2nd and 4th CNOT)
        eigvals: eigenvalues of U_1 U_2^†; D is ``diag(sqrt(eigvals))``.

    Returns:
        The operations as a list. The list may be shorter than 8 after
        simplification, and is empty if the multiplexor is the identity.
    """
    # For unit-modulus eigvals, log(sqrt(e)) == 1j * arg(sqrt(e)), so
    # theta[k] == -2 * arg(sqrt(eigvals[k])). With cirq.rz(t) == exp(-i Z t / 2),
    # the rotation therefore contributes sqrt(eigvals[k]) when q0 is |0> and
    # its conjugate when q0 is |1>.
    theta = np.real(np.log(np.sqrt(eigvals)) * 1j * 2)
    angles = _multiplexed_angles(theta)
    rzs = [cirq.rz(angle).on(q0) for angle in angles]
    # Named `mux_ops` (not `ops`) so the imported `cirq.ops` module is not shadowed.
    mux_ops = [
        rzs[0],
        cirq.CNOT(q1, q0),
        rzs[1],
        cirq.CNOT(q2, q0),
        rzs[2],
        cirq.CNOT(q1, q0),
        rzs[3],
        cirq.CNOT(q2, q0),
    ]
    # Normalize to a list: the simplifier may return an empty cirq.Circuit.
    return list(_optimize_multiplexed_angles_circuit(mux_ops))
257
258
def _multiplexed_angles(theta: Union[Sequence[float], np.ndarray]) -> np.ndarray:
    """Calculates the angles for a 4-way multiplexed rotation.

    Suppose we want rz(theta[i]) applied whenever the select qubits are in
    state |i>. This function returns the angles a[i] to use in a circuit
    shaped like this:

    ---rz(a[0])-X---rz(a[1])--X--rz(a[2])-X--rz(a[3])--X
                |             |           |            |
    ------------@-------------|-----------@------------|
                              |                        |
    --------------------------@------------------------@

    Only the first four entries of ``theta`` are read.

    Args:
        theta: the desired angles for each basis state of the select qubits
    Returns:
        the angles to be used in actual rotations in the circuit implementation
    """
    t0, t1, t2, t3 = theta[0], theta[1], theta[2], theta[3]
    # Each row is a signed combination of the target angles (a Walsh-style
    # transform), normalized by 4 below.
    combos = [
        t0 + t1 + t2 + t3,
        t0 + t1 - t2 - t3,
        t0 - t1 - t2 + t3,
        t0 - t1 + t2 - t3,
    ]
    return np.array(combos) / 4
288