1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from unittest import mock
16
17import numpy as np
18import pytest
19
20import cirq
21
22
def test_decomposed_fallback():
    """A gate that only defines `_decompose_` is applied through its decomposition.

    Acting with the composite gate on the middle qubit of a three-qubit
    register is expected to leave the one-hot state at index (0, 1, 0).
    """

    class Composite(cirq.Gate):
        """One-qubit gate whose only protocol is a decomposition into X."""

        def num_qubits(self) -> int:
            return 1

        def _decompose_(self, qubits):
            yield cirq.X(*qubits)

    shape = (2, 2, 2)
    register = cirq.LineQubit.range(3)
    initial_tensor = cirq.one_hot(shape=shape, dtype=np.complex64)
    scratch = np.empty(shape, dtype=np.complex64)

    args = cirq.ActOnStateVectorArgs(
        target_tensor=initial_tensor,
        available_buffer=scratch,
        qubits=register,
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )

    cirq.act_on(Composite(), args, [cirq.LineQubit(1)])

    expected = cirq.one_hot(index=(0, 1, 0), shape=shape, dtype=np.complex64)
    np.testing.assert_allclose(args.target_tensor, expected)
43
44
def test_cannot_act():
    """An object exposing no simulation protocols makes `cirq.act_on` raise TypeError."""

    class NoDetails:
        """Placeholder with no unitary, Kraus, or decomposition methods."""

    shape = (2, 2, 2)
    args = cirq.ActOnStateVectorArgs(
        target_tensor=cirq.one_hot(shape=shape, dtype=np.complex64),
        available_buffer=np.empty(shape, dtype=np.complex64),
        qubits=cirq.LineQubit.range(3),
        prng=np.random.RandomState(),
        log_of_measurement_results={},
    )

    unsupported = NoDetails()
    with pytest.raises(TypeError, match="Can't simulate operations"):
        cirq.act_on(unsupported, args, qubits=())
59
60
def test_act_using_probabilistic_single_qubit_channel():
    """The mocked random sample decides which Kraus branch of the channel is applied.

    The channel's Kraus operators are S scaled by sqrt(1/3) and X scaled by
    sqrt(2/3). A sample just above 1/3 is compared against applying
    `X ** -1`, and a sample just below 1/3 against applying S.
    """

    class ProbabilisticSorX(cirq.Gate):
        def num_qubits(self) -> int:
            return 1

        def _kraus_(self):
            s_branch = cirq.unitary(cirq.S) * np.sqrt(1 / 3)
            x_branch = cirq.unitary(cirq.X) * np.sqrt(2 / 3)
            return [s_branch, x_branch]

    initial_state = cirq.testing.random_superposition(dim=16).reshape((2,) * 4)
    mock_prng = mock.Mock()
    register = cirq.LineQubit.range(4)
    target = cirq.LineQubit(2)

    def simulate(sample):
        # Runs the channel on a fresh copy of the initial state with the
        # mocked generator returning `sample`, and returns the flat state.
        mock_prng.random.return_value = sample
        args = cirq.ActOnStateVectorArgs(
            target_tensor=np.copy(initial_state),
            available_buffer=np.empty_like(initial_state),
            qubits=register,
            prng=mock_prng,
            log_of_measurement_results={},
        )
        cirq.act_on(ProbabilisticSorX(), args, [target])
        return args.target_tensor.reshape(16)

    def reference(operation):
        # State vector obtained by applying `operation` to the same initial state.
        return cirq.final_state_vector(
            operation,
            initial_state=initial_state,
            qubit_order=register,
        )

    # Sample slightly above 1/3.
    actual = simulate(1 / 3 + 1e-6)
    np.testing.assert_allclose(actual, reference(cirq.X(target) ** -1), atol=1e-8)

    # Sample slightly below 1/3.
    actual = simulate(1 / 3 - 1e-6)
    np.testing.assert_allclose(actual, reference(cirq.S(target)), atol=1e-8)
112
113
def test_act_using_adaptive_two_qubit_channel():
    """Checks how a non-unitary two-qubit channel responds to the sampled value.

    The channel acts on qubits 1 and 3 of a four-qubit register. The random
    generator is a mock, so each scenario pins the exact sample that is drawn.
    """

    class Decay11(cirq.Gate):
        """Two-qubit channel with three Kraus operators.

        The operators are sqrt(3/4) * I, sqrt(1/4) * (I - |11><11|), and
        sqrt(1/4) * |00><11|.
        """

        def num_qubits(self) -> int:
            return 2

        def _kraus_(self):
            bottom_right = cirq.one_hot(index=(3, 3), shape=(4, 4), dtype=np.complex64)
            top_right = cirq.one_hot(index=(0, 3), shape=(4, 4), dtype=np.complex64)
            return [
                np.eye(4) * np.sqrt(3 / 4),
                (np.eye(4) - bottom_right) * np.sqrt(1 / 4),
                top_right * np.sqrt(1 / 4),
            ]

    mock_prng = mock.Mock()

    def get_result(state: np.ndarray, sample: float):
        """Applies Decay11 to a copy of `state` with the mock returning `sample`."""
        mock_prng.random.return_value = sample
        args = cirq.ActOnStateVectorArgs(
            target_tensor=np.copy(state),
            available_buffer=np.empty_like(state),
            qubits=cirq.LineQubit.range(4),
            prng=mock_prng,
            log_of_measurement_results={},
        )
        cirq.act_on(Decay11(), args, [cirq.LineQubit(1), cirq.LineQubit(3)])
        return args.target_tensor

    def assert_not_affected(state: np.ndarray, sample: float):
        """Asserts that the channel leaves `state` unchanged for the given sample."""
        np.testing.assert_allclose(get_result(state, sample), state, atol=1e-8)

    all_zeroes = cirq.one_hot(index=(0, 0, 0, 0), shape=(2,) * 4, dtype=np.complex128)
    all_ones = cirq.one_hot(index=(1, 1, 1, 1), shape=(2,) * 4, dtype=np.complex128)
    decayed_all_ones = cirq.one_hot(index=(1, 0, 1, 0), shape=(2,) * 4, dtype=np.complex128)

    # Decays the 11 state to 00.
    np.testing.assert_allclose(get_result(all_ones, 3 / 4 - 1e-8), all_ones)
    np.testing.assert_allclose(get_result(all_ones, 3 / 4 + 1e-8), decayed_all_ones)

    # Decoheres the 11 subspace from other subspaces as sample rises.
    superpose = all_ones * np.sqrt(1 / 2) + all_zeroes * np.sqrt(1 / 2)
    np.testing.assert_allclose(get_result(superpose, 3 / 4 - 1e-8), superpose)
    np.testing.assert_allclose(get_result(superpose, 3 / 4 + 1e-8), all_zeroes)
    np.testing.assert_allclose(get_result(superpose, 7 / 8 - 1e-8), all_zeroes)
    np.testing.assert_allclose(get_result(superpose, 7 / 8 + 1e-8), decayed_all_ones)

    # Always acts like identity when sample < p=3/4.
    for _ in range(10):
        assert_not_affected(
            cirq.testing.random_superposition(dim=16).reshape((2,) * 4),
            sample=3 / 4 - 1e-8,
        )

    # Acts like identity on superpositions of first three states.
    for _ in range(10):
        projected_state = cirq.testing.random_superposition(dim=16).reshape((2,) * 4)
        # Zero the component where qubits 1 and 3 are both 1, then renormalize.
        projected_state[cirq.slice_for_qubits_equal_to([1, 3], 3)] = 0
        projected_state /= np.linalg.norm(projected_state)
        assert abs(np.linalg.norm(projected_state) - 1) < 1e-8
        assert_not_affected(
            projected_state,
            sample=3 / 4 + 1e-8,
        )
178
179
def test_probability_comes_up_short_results_in_fallback():
    """A sample beyond the accumulated Kraus probability still produces a valid state.

    The Kraus operators are X scaled by sqrt(0.999) and a zero matrix, while
    the mocked sample is 0.9999. The state is expected to end up as [0, 1].
    """

    class Short(cirq.Gate):
        def num_qubits(self) -> int:
            return 1

        def _kraus_(self):
            scaled_x = cirq.unitary(cirq.X) * np.sqrt(0.999)
            zero_op = np.eye(2) * 0
            return [scaled_x, zero_op]

    mock_prng = mock.Mock()
    mock_prng.random.return_value = 0.9999

    start = np.array([1, 0], dtype=np.complex64)
    scratch = np.empty(2, dtype=np.complex64)
    args = cirq.ActOnStateVectorArgs(
        target_tensor=start,
        available_buffer=scratch,
        qubits=cirq.LineQubit.range(1),
        prng=mock_prng,
        log_of_measurement_results={},
    )

    cirq.act_on(Short(), args, cirq.LineQubit.range(1))

    flipped = np.array([0, 1])
    np.testing.assert_allclose(args.target_tensor, flipped)
208
209
def test_random_channel_has_random_behavior():
    """Sampling an amplitude-damped qubit yields both measurement outcomes.

    The simulator is unseeded; the test only requires that each of the two
    outcomes appears more than once across 100 repetitions.
    """
    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(
        cirq.X(qubit),
        cirq.amplitude_damp(0.4).on(qubit),
        cirq.measure(qubit, key='out'),
    )

    samples = cirq.Simulator().sample(circuit, repetitions=100)
    counts = samples['out'].value_counts()

    assert counts[0] > 1
    assert counts[1] > 1
223
224
def test_measured_channel():
    """A keyed KrausChannel records its selected operator under the given key.

    After H is applied to the qubit, all 100 repetitions are expected to
    record 0 under key 'm'.
    """
    # This behaves like an X-basis measurement.
    first_op = np.array([[1, 1], [1, 1]]) * 0.5
    second_op = np.array([[1, -1], [-1, 1]]) * 0.5
    channel = cirq.KrausChannel(kraus_ops=(first_op, second_op), key='m')

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.H(qubit), channel.on(qubit))

    simulator = cirq.Simulator(seed=0)
    results = simulator.run(circuit, repetitions=100)

    assert results.histogram(key='m') == {0: 100}
239
240
def test_measured_mixture():
    """The key recorded by a keyed mixture agrees with a following measurement.

    The mixture applies the identity or a bit flip (X), each with weight 0.5,
    and stores its outcome under 'flip'. The qubit is then measured under 'm',
    and the two histograms are expected to be equal.
    """
    identity = np.array([[1, 0], [0, 1]])
    bit_flip = np.array([[0, 1], [1, 0]])
    mixture_channel = cirq.MixedUnitaryChannel(
        mixture=((0.5, identity), (0.5, bit_flip)),
        key='flip',
    )

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(mixture_channel.on(qubit), cirq.measure(qubit, key='m'))

    simulator = cirq.Simulator(seed=0)
    results = simulator.run(circuit, repetitions=100)

    flip_counts = results.histogram(key='flip')
    measure_counts = results.histogram(key='m')
    assert flip_counts == measure_counts
255
256
def test_axes_deprecation():
    """Passing or reading `axes` on ActOnStateVectorArgs is flagged as deprecated.

    Four calling conventions are exercised (all positional, positional plus
    keywords, `axes` as a keyword, and all keywords). Each one checks that the
    constructor and the `axes` attribute emit a deprecation with deadline
    v0.13, and that the remaining fields hold the exact objects passed in.
    """
    rng = np.random.RandomState()
    state = np.array([1, 0], dtype=np.complex64)
    buf = np.array([1, 0], dtype=np.complex64)
    qids = tuple(cirq.LineQubit.range(1))
    log = {}

    def assert_fields_are_inputs(built):
        # The non-deprecated fields must be the very objects supplied above.
        assert built.prng is rng
        assert built.target_tensor is state
        assert built.available_buffer is buf
        assert built.qubits is qids
        assert built.log_of_measurement_results is log

    # No kwargs
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        args = cirq.ActOnStateVectorArgs(state, buf, (1,), rng, log, qids)  # type: ignore
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        assert args.axes == (1,)
    assert_fields_are_inputs(args)

    # kwargs no axes
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        args = cirq.ActOnStateVectorArgs(
            state,
            buf,
            (1,),  # type: ignore
            qubits=qids,
            prng=rng,
            log_of_measurement_results=log,
        )
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        assert args.axes == (1,)
    assert_fields_are_inputs(args)

    # kwargs incl axes
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        args = cirq.ActOnStateVectorArgs(
            state,
            buf,
            axes=(1,),
            qubits=qids,
            prng=rng,
            log_of_measurement_results=log,
        )
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        assert args.axes == (1,)
    assert_fields_are_inputs(args)

    # all kwargs
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        args = cirq.ActOnStateVectorArgs(
            target_tensor=state,
            available_buffer=buf,
            axes=(1,),
            qubits=qids,
            prng=rng,
            log_of_measurement_results=log,
        )
    with cirq.testing.assert_deprecated("axes", deadline="v0.13"):
        assert args.axes == (1,)
    assert_fields_are_inputs(args)
328