1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17
18import cirq
19from cirq.protocols.apply_unitary_protocol import (
20    _incorporate_result_into_target,
21)
22
23
def test_apply_unitary_presence_absence():
    """cirq.apply_unitary succeeds or fails depending on which protocols a value offers.

    Values with neither a working ``_apply_unitary_`` nor a ``_unitary_`` must raise
    ``TypeError`` (or hand back the supplied default). Every other combination must
    apply diag(1, -1) to whichever axis is requested.
    """
    phase_flip = np.diag([1, -1])

    class NoUnitaryEffect:
        pass

    class HasUnitary:
        def _unitary_(self) -> np.ndarray:
            return phase_flip

    class HasApplyReturnsNotImplemented:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

    class HasApplyReturnsNotImplementedButHasUnitary:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

        def _unitary_(self) -> np.ndarray:
            return phase_flip

    class HasApplyOutputInBuffer:
        # Writes the result into available_buffer and returns the buffer.
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            lo, hi = args.subspace_index(0), args.subspace_index(1)
            args.available_buffer[lo] = args.target_tensor[lo]
            args.available_buffer[hi] = -args.target_tensor[hi]
            return args.available_buffer

    class HasApplyMutateInline:
        # Negates the |1> slice of target_tensor in place and returns it.
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs) -> np.ndarray:
            args.target_tensor[args.subspace_index(1)] *= -1
            return args.target_tensor

    # One scratch buffer is shared by every call below.
    scratch = np.empty(shape=(2, 2), dtype=np.complex128)

    def args_on(axis: int = 0) -> cirq.ApplyUnitaryArgs:
        # Always a fresh all-ones input, because in-place implementations mutate it.
        return cirq.ApplyUnitaryArgs(np.ones((2, 2)), scratch, [axis])

    expected_by_axis = (
        np.array([[1, 1], [-1, -1]]),
        np.array([[1, -1], [1, -1]]),
    )

    for bad in (NoUnitaryEffect(), HasApplyReturnsNotImplemented()):
        with pytest.raises(TypeError, match='failed to satisfy'):
            _ = cirq.apply_unitary(bad, args_on())
        for sentinel in (None, NotImplemented):
            assert cirq.apply_unitary(bad, args_on(), default=sentinel) is sentinel
        assert cirq.apply_unitary(bad, args_on(), default=1) == 1

    for good in (
        HasUnitary(),
        HasApplyReturnsNotImplementedButHasUnitary(),
        HasApplyOutputInBuffer(),
        HasApplyMutateInline(),
    ):
        for axis, expected in enumerate(expected_by_axis):
            np.testing.assert_allclose(cirq.apply_unitary(good, args_on(axis)), expected)
        assert cirq.apply_unitary(good, args_on(), default=None) is not None
105
106
def test_apply_unitary_args_tensor_manipulation():
    """Every way of producing an ``_apply_unitary_`` result is incorporated correctly.

    Each helper class implements the same two-qubit operation but produces its
    result differently:

    * mutating ``target_tensor`` in place;
    * returning a transposed, or transposed-and-reshaped, array derived from
      ``target_tensor``;
    * writing into ``available_buffer``, or into an array derived from it;
    * allocating a brand-new array.

    Each result is passed through ``_incorporate_result_into_target`` and compared
    with a hand-built expectation. This is done for a simple layout and for a
    layout whose operated axes are interleaved with, and permuted among,
    spectator axes.
    """
    # All below are qubit swap operations with 1j global phase

    class ModifyTargetTensor:
        # Exchanges the |01> and |10> amplitudes inside target_tensor itself,
        # using available_buffer only as scratch space, then returns target_tensor.
        # The buffer is clobbered afterwards so the result cannot depend on it.
        def _apply_unitary_(self, args):
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            args.available_buffer[zo] = args.target_tensor[zo]
            args.target_tensor[zo] = args.target_tensor[oz]
            args.target_tensor[oz] = args.available_buffer[zo]
            args.target_tensor[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return args.target_tensor

    class TransposeTargetTensor:
        # Exchanges the two operated axes via a transposed view of target_tensor.
        # Exchanging two qubit axes is a SWAP. The view is scaled in place and
        # returned, so the returned array's axis order differs from the input's.
        def _apply_unitary_(self, args):
            indices = list(range(len(args.target_tensor.shape)))
            indices[args.axes[0]], indices[args.axes[1]] = (
                indices[args.axes[1]],
                indices[args.axes[0]],
            )
            target = args.target_tensor.transpose(*indices)
            target[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return target

    class ReshapeTargetTensor:
        # Saves the four two-qubit amplitudes into available_buffer, then writes
        # them back, with |01> and |10> exchanged, into an array obtained by
        # transposing and reshaping target_tensor. That array may be a view or a
        # copy, depending on whether NumPy's reshape can avoid copying.
        # All four subspace slices are overwritten, so the scrambled layout should
        # not leak into the result. NOTE(review): this holds only when the operated
        # axes have dimension 2, as they do for the qid shapes used below.
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            args.available_buffer[zz] = args.target_tensor[zz]
            args.available_buffer[zo] = args.target_tensor[zo]
            args.available_buffer[oz] = args.target_tensor[oz]
            args.available_buffer[oo] = args.target_tensor[oo]
            # Do a pointless reshape and transpose
            target = args.target_tensor.transpose(
                *range(1, len(args.target_tensor.shape)), 0
            ).reshape(args.target_tensor.shape)
            target[zz] = args.available_buffer[zz]
            target[zo] = args.available_buffer[oz]
            target[oz] = args.available_buffer[zo]
            target[oo] = args.available_buffer[oo]
            target[...] *= 1j
            args.available_buffer[...] = 99  # Destroy buffer data just in case
            return target

    class ModifyAvailableBuffer:
        # Writes the swapped amplitudes into available_buffer and returns it.
        # target_tensor is clobbered afterwards, so the result must come from
        # the buffer.
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            args.available_buffer[zz] = args.target_tensor[zz]
            args.available_buffer[zo] = args.target_tensor[oz]
            args.available_buffer[oz] = args.target_tensor[zo]
            args.available_buffer[oo] = args.target_tensor[oo]
            args.available_buffer[...] *= 1j
            args.target_tensor[...] = 99  # Destroy buffer data just in case
            return args.available_buffer

    class TransposeAvailableBuffer:
        # Takes a view of available_buffer with the operated axes exchanged (a
        # SWAP), copies target_tensor into the buffer, scales the view by 1j and
        # returns the view.
        def _apply_unitary_(self, args):
            indices = list(range(len(args.target_tensor.shape)))
            indices[args.axes[0]], indices[args.axes[1]] = (
                indices[args.axes[1]],
                indices[args.axes[0]],
            )
            output = args.available_buffer.transpose(*indices)
            args.available_buffer[...] = args.target_tensor
            output *= 1j
            args.target_tensor[...] = 99  # Destroy buffer data just in case
            return output

    class ReshapeAvailableBuffer:
        # Writes the swapped amplitudes into an array obtained by transposing and
        # reshaping available_buffer, and returns that array. It may be a view or
        # a copy, depending on whether NumPy's reshape can avoid copying.
        def _apply_unitary_(self, args):
            zz = args.subspace_index(0b00)
            zo = args.subspace_index(0b01)
            oz = args.subspace_index(0b10)
            oo = args.subspace_index(0b11)
            # Do a pointless reshape and transpose
            output = args.available_buffer.transpose(
                *range(1, len(args.available_buffer.shape)), 0
            ).reshape(args.available_buffer.shape)
            output[zz] = args.target_tensor[zz]
            output[zo] = args.target_tensor[oz]
            output[oz] = args.target_tensor[zo]
            output[oo] = args.target_tensor[oo]
            output[...] *= 1j
            args.target_tensor[...] = 99  # Destroy buffer data just in case
            return output

    class CreateNewBuffer:
        # Computes the result into a newly allocated array using an explicit 4x4
        # SWAP matrix times 1j, then clobbers both provided arrays.
        # NOTE(review): this assumes the operated axes are the last two axes of
        # target_tensor, i.e. that _for_operation_with_qid_shape moves them to the
        # end. Confirm against that helper.
        def _apply_unitary_(self, args):
            u = (
                np.array(
                    [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
                    dtype=args.target_tensor.dtype,
                )
                * 1j
            )  # yapf: disable
            # Flatten last two axes and add a dummy index to the end of
            # target_tensor so np.matmul treats it like an array of two-qubit
            # column vectors.
            new_shape = args.target_tensor.shape[:-2] + (4, 1)
            ret = np.matmul(u, args.target_tensor.reshape(new_shape)).reshape(
                args.target_tensor.shape
            )
            args.target_tensor[...] = 99  # Destroy buffer data just in case
            args.available_buffer[...] = 98
            return ret

    operations = [
        ModifyTargetTensor(),
        TransposeTargetTensor(),
        ReshapeTargetTensor(),
        ModifyAvailableBuffer(),
        TransposeAvailableBuffer(),
        ReshapeAvailableBuffer(),
        CreateNewBuffer(),
    ]

    def assert_is_swap_simple(val: cirq.SupportsConsistentApplyUnitary) -> None:
        # Applies `val` to the lowest two levels of the two 3-level axes of a
        # (1, 3, 3) state. Expected result: amplitudes [..., 0, 1] and [..., 1, 0]
        # exchanged, the 2x2 block scaled by 1j, and level-2 entries untouched.
        qid_shape = (2, 2)
        op_indices = [0, 1]
        state = np.arange(3 * 3, dtype=np.complex64).reshape((1, 3, 3))
        expected = state.copy()
        buf = expected[..., 0, 1].copy()
        expected[..., 0, 1] = expected[..., 1, 0]
        expected[..., 1, 0] = buf
        expected[..., :2, :2] *= 1j

        args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [1, 2])
        sub_args = args._for_operation_with_qid_shape(
            op_indices, tuple(qid_shape[i] for i in op_indices)
        )
        sub_result = val._apply_unitary_(sub_args)
        result = _incorporate_result_into_target(args, sub_args, sub_result)
        np.testing.assert_allclose(result, expected, atol=1e-8)

    def assert_is_swap(val: cirq.SupportsConsistentApplyUnitary) -> None:
        # Same check on a 7-axis state whose operated axes are mixed in with
        # spectator axes. op_indices [1, 3] select args.axes[1] == 4 and
        # args.axes[3] == 3. qid_shape restricts each of them to its lowest two
        # levels. The expectation therefore indexes axes 3 and 4, i.e. the
        # fourth- and third-from-last axes.
        qid_shape = (1, 2, 4, 2)
        op_indices = [1, 3]
        state = np.arange(2 * (1 * 3 * 4 * 5), dtype=np.complex64).reshape((1, 2, 1, 5, 3, 1, 4))
        expected = state.copy()
        buf = expected[..., 0, 1, :, :].copy()
        expected[..., 0, 1, :, :] = expected[..., 1, 0, :, :]
        expected[..., 1, 0, :, :] = buf
        expected[..., :2, :2, :, :] *= 1j

        args = cirq.ApplyUnitaryArgs(state, np.empty_like(state), [5, 4, 6, 3])
        sub_args = args._for_operation_with_qid_shape(
            op_indices, tuple(qid_shape[i] for i in op_indices)
        )
        sub_result = val._apply_unitary_(sub_args)
        result = _incorporate_result_into_target(args, sub_args, sub_result)
        np.testing.assert_allclose(result, expected, atol=1e-8, verbose=True)

    # Every implementation must produce the same result under both layouts.
    for op in operations:
        assert_is_swap_simple(op)
        assert_is_swap(op)
269
270
def test_big_endian_subspace_index():
    """subspace_index maps the same integer differently under each bit ordering."""
    tensor = np.zeros(shape=(2, 3, 4, 5, 1, 6, 1, 1))
    args = cirq.ApplyUnitaryArgs(tensor, np.empty_like(tensor), [1, 3])
    everything = slice(None)
    trailing = (everything,) * 4
    # With axes [1, 3], the value 1 sets axis 1 (little endian) or axis 3 (big endian).
    little = (everything, 1, everything, 0) + trailing
    big = (everything, 0, everything, 1) + trailing
    assert args.subspace_index(little_endian_bits_int=1) == little
    assert args.subspace_index(big_endian_bits_int=1) == big
277
278
def test_apply_unitaries():
    """cirq.apply_unitaries composes operations, respects qubit order, and honours defaults."""
    a, b, c = cirq.LineQubit.range(3)
    circuit_ops = [cirq.H(a), cirq.CNOT(a, b), cirq.H(c).controlled_by(b)]
    r = np.sqrt(0.5)
    expected_abc = [r, 0, 0, 0, 0, 0, 0.5, 0.5]
    expected_acb = [r, 0, 0, 0, 0, 0.5, 0, 0.5]

    def assert_state(qubits, expected, **extra):
        # Runs the shared operation list and compares the flattened 8-entry result.
        state = cirq.apply_unitaries(unitary_values=circuit_ops, qubits=qubits, **extra)
        np.testing.assert_allclose(state.reshape(8), expected, atol=1e-8)

    assert_state([a, b, c], expected_abc)

    # Different order.
    assert_state([a, c, b], expected_acb)

    # Explicit arguments.
    assert_state([a, b, c], expected_abc, args=cirq.ApplyUnitaryArgs.default(num_qubits=3))

    # Empty.
    for extra in ({}, {'default': None}):
        np.testing.assert_allclose(cirq.apply_unitaries(unitary_values=[], qubits=[], **extra), [1])

    # Non-unitary operation.
    noisy = [cirq.depolarize(0.5).on(a)]
    with pytest.raises(TypeError, match='non-unitary'):
        _ = cirq.apply_unitaries(unitary_values=noisy, qubits=[a])
    assert cirq.apply_unitaries(unitary_values=noisy, qubits=[a], default=None) is None
    assert cirq.apply_unitaries(unitary_values=noisy, qubits=[a], default=1) == 1

    # Inconsistent arguments.
    with pytest.raises(ValueError, match='len'):
        _ = cirq.apply_unitaries(
            unitary_values=[], qubits=[], args=cirq.ApplyUnitaryArgs.default(num_qubits=1)
        )
363
364
def test_apply_unitaries_mixed_qid_shapes():
    """apply_unitaries works when operations address the same qids at different dimensions.

    ``a`` is a qutrit and ``b`` a ququart. Qubit gates (``cirq.X``, ``cirq.CNOT``)
    act on them through ``with_dimension(2)`` handles. The custom cyclic-shift
    gates act on their full dimension.
    """

    # These gates act on a qutrit and a ququart, so they derive from the generic
    # ``cirq.Gate`` and declare their shape via ``_qid_shape_``. They do not use
    # the qubit-specific ``cirq.SingleQubitGate``, which is deprecated.
    class PlusOneMod3Gate(cirq.Gate):
        def _qid_shape_(self):
            return (3,)

        def _unitary_(self):
            # Cyclic shift |k> -> |k + 1 mod 3>.
            return np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]])

    class PlusOneMod4Gate(cirq.Gate):
        def _qid_shape_(self):
            return (4,)

        def _unitary_(self):
            # Cyclic shift |k> -> |k + 1 mod 4>.
            return np.array([[0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])

    a, b = cirq.LineQid.for_qid_shape((3, 4))
    a2, a3 = a.with_dimension(2), a.with_dimension(3)
    b2, b4 = b.with_dimension(2), b.with_dimension(4)

    def identity_args():
        # Fresh identity tensors on every call, because apply_unitaries writes into
        # them. The result is compared as a 12x12 matrix below.
        return cirq.ApplyUnitaryArgs(
            target_tensor=cirq.eye_tensor((3, 4), dtype=np.complex64),
            available_buffer=cirq.eye_tensor((3, 4), dtype=np.complex64),
            axes=(0, 1),
        )

    # The two CNOTs cancel, as do the two X gates around them. That leaves three
    # +1 mod 3 shifts on ``a``, so the whole sequence is the identity.
    round_trip = [
        PlusOneMod3Gate().on(a3),
        cirq.X(a2),
        cirq.CNOT(a2, b2),
        cirq.CNOT(a2, b2),
        cirq.X(a2),
        PlusOneMod3Gate().on(a3),
        PlusOneMod3Gate().on(a3),
    ]

    result = cirq.apply_unitaries(unitary_values=round_trip, qubits=[a, b])
    np.testing.assert_allclose(result.reshape(12), [1] + [0] * 11, atol=1e-8)

    result = cirq.apply_unitaries(unitary_values=round_trip, qubits=[a, b], args=identity_args())
    np.testing.assert_allclose(result.reshape(12, 12), np.eye(12), atol=1e-8)

    result = cirq.apply_unitaries(
        unitary_values=[
            PlusOneMod3Gate().on(a3),
            cirq.X(a2),
            PlusOneMod4Gate().on(b4),
            PlusOneMod4Gate().on(b4),
            cirq.X(b2),
            PlusOneMod4Gate().on(b4),
            PlusOneMod4Gate().on(b4),
            cirq.CNOT(a2, b2),
            PlusOneMod4Gate().on(b4),
            cirq.X(b2),
            cirq.CNOT(a2, b2),
            cirq.X(a2),
            PlusOneMod3Gate().on(a3),
            PlusOneMod3Gate().on(a3),
        ],
        qubits=[a, b],
        args=identity_args(),
    )
    np.testing.assert_allclose(
        result.reshape(12, 12),
        np.array(
            [
                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            ]
        ),
        atol=1e-8,
    )
461
462
def test_incorporate_result_not_view():
    """Incorporating a result from args that do not view the target is rejected."""
    target = np.zeros((2, 2))
    unrelated = np.zeros((2, 2))
    shared_buffer = np.empty_like(target)
    parent_args = cirq.ApplyUnitaryArgs(target, shared_buffer, [0])
    # Built on an independent array, so it cannot be a view into `target`.
    detached_args = cirq.ApplyUnitaryArgs(unrelated, shared_buffer, [0])
    with pytest.raises(ValueError, match='view'):
        _incorporate_result_into_target(parent_args, detached_args, unrelated)
471
472
def test_default_method_arguments():
    """ApplyUnitaryArgs.default rejects receiving both num_qubits and qid_shape."""
    with pytest.raises(TypeError, match='exactly one of'):
        cirq.ApplyUnitaryArgs.default(num_qubits=1, qid_shape=(2,))
476
477
def test_apply_unitary_args_with_axes_transposed_to_start():
    """Transposed args move the operated axes first and alias the original arrays."""
    shape = (2, 3, 4, 5)
    args = cirq.ApplyUnitaryArgs(np.zeros(shape), np.zeros(shape), [1, 3])

    moved = args.with_axes_transposed_to_start()
    for tensor in (moved.target_tensor, moved.available_buffer):
        assert tensor.shape == (3, 5, 2, 4)
    assert moved.axes == (0, 1)

    # Writes through the transposed arrays must land in the originals. An index
    # (i0, i1, i2, i3) of the original corresponds to (i1, i3, i0, i2) after the move.
    moved.target_tensor[2, 4, 1, 3] = 1
    assert args.target_tensor[1, 2, 3, 4] == 1
    moved.available_buffer[2, 4, 1, 3] = 2
    assert args.available_buffer[1, 2, 3, 4] == 2
493