1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for state_vector.py"""
15
16import itertools
17import pytest
18
19import numpy as np
20
21import cirq
22import cirq.testing
23
24
def test_state_mixin():
    """StateVectorMixin derives notation, Bloch vectors, density matrices and qid shapes."""

    class FixedState(cirq.StateVectorMixin):
        def state_vector(self) -> np.ndarray:
            # Amplitude 1 at index 2, i.e. the basis state |10⟩.
            return np.array([0, 0, 1, 0])

    a, b = cirq.LineQubit.range(2)
    mixin = FixedState(qubit_map={a: 0, b: 1})
    assert mixin.dirac_notation() == '|10⟩'
    np.testing.assert_almost_equal(mixin.bloch_vector_of(a), np.array([0, 0, -1]))
    np.testing.assert_almost_equal(mixin.density_matrix_of([a]), np.array([[0, 0], [0, 1]]))

    assert cirq.qid_shape(FixedState({a: 1, b: 0})) == (2, 2)
    qids = {cirq.LineQid(k, k + 1): 2 - k for k in range(3)}
    assert cirq.qid_shape(FixedState(qids)) == (3, 2, 1)
    assert cirq.qid_shape(FixedState(), 'no shape') == 'no shape'

    # Each malformed qubit map must be rejected at construction time.
    invalid_maps = [
        ('Qubit index out of bounds', {a: 1}),
        ('Duplicate qubit index', {a: 0, b: 0}),
        ('Duplicate qubit index', {a: 1, b: 1}),
        ('Duplicate qubit index', {a: -1, b: 1}),
    ]
    for message, qubit_map in invalid_maps:
        with pytest.raises(ValueError, match=message):
            FixedState(qubit_map)
48
49
def test_sample_state_big_endian():
    """Sampling qubits [2, 1, 0] of basis state x returns x's bits least-significant first."""
    bit_patterns = itertools.product([False, True], repeat=3)
    for x, bits in zip(range(8), bit_patterns):
        sample = cirq.sample_state_vector(cirq.to_valid_state_vector(x, 3), [2, 1, 0])
        np.testing.assert_equal(sample, [list(bits)[::-1]])
59
60
def test_sample_state_partial_indices():
    """Sampling a single qubit of a basis state returns that qubit's big-endian bit."""
    for index, x in itertools.product(range(3), range(8)):
        basis = cirq.to_valid_state_vector(x, 3)
        expected_bit = bool((x >> (2 - index)) & 1)
        np.testing.assert_equal(cirq.sample_state_vector(basis, [index]), [[expected_bit]])
68
69
def test_sample_state_partial_indices_oder():
    """Sampling qubits [2, 1] returns the two lowest bits, least significant first.

    ('oder' in the name is a typo for 'order'; kept so the test ID stays stable.)
    """
    for x in range(8):
        lowest, second = bool(x & 1), bool(x & 2)
        sample = cirq.sample_state_vector(cirq.to_valid_state_vector(x, 3), [2, 1])
        np.testing.assert_equal(sample, [[lowest, second]])
75
76
def test_sample_state_partial_indices_all_orders():
    """Bits come back in exactly the qubit order that was requested."""
    for order, x in itertools.product(itertools.permutations(range(3)), range(8)):
        vector = cirq.to_valid_state_vector(x, 3)
        want = [[bool(x & (1 << (2 - q))) for q in order]]
        np.testing.assert_equal(cirq.sample_state_vector(vector, order), want)
83
84
def test_sample_state():
    """Samples of an equal superposition of |000⟩ and |010⟩ only vary on qubit 1."""
    state = np.zeros(8, dtype=np.complex64)
    state[[0, 2]] = 1 / np.sqrt(2)

    allowed = ([[False, False, False]], [[False, True, False]])
    for _ in range(10):
        sample = cirq.sample_state_vector(state, [2, 1, 0])
        assert any(np.array_equal(sample, option) for option in allowed)

    # Qubits 2 and 0 are |0⟩ in both branches, so sampling either alone is deterministic.
    for _ in range(10):
        for index in (2, 0):
            np.testing.assert_equal(cirq.sample_state_vector(state, [index]), [[False]])
98
99
def test_sample_empty_state():
    """A zero-qubit state sampled on no indices yields a single empty sample."""
    empty_vector = np.array([1.0])
    samples = cirq.sample_state_vector(empty_vector, [])
    np.testing.assert_almost_equal(samples, np.zeros(shape=(1, 0)))
103
104
def test_sample_no_repetitions():
    """repetitions=0 yields zero rows with one column per requested index."""
    state = cirq.to_valid_state_vector(0, 3)
    for indices in ([1], [1, 2]):
        samples = cirq.sample_state_vector(state, indices, repetitions=0)
        np.testing.assert_almost_equal(samples, np.zeros(shape=(0, len(indices))))
113
114
def test_sample_state_repetitions():
    """Repeated samples of a basis state are identical rows."""
    for perm, x in itertools.product(itertools.permutations([0, 1, 2]), range(8)):
        basis = cirq.to_valid_state_vector(x, 3)
        row = [bool((x >> (2 - p)) & 1) for p in perm]
        result = cirq.sample_state_vector(basis, perm, repetitions=3)
        np.testing.assert_equal(result, [row, row, row])
123
124
def test_sample_state_seed():
    """An int seed and an equivalently seeded RandomState produce the same samples."""
    state = np.ones(2) / np.sqrt(2)
    expected = [[False], [True], [False], [True], [True], [False], [False], [True], [True], [True]]
    for seed in (1234, np.random.RandomState(1234)):
        samples = cirq.sample_state_vector(state, [0], repetitions=10, seed=seed)
        assert np.array_equal(samples, expected)
139
140
def test_sample_state_negative_repetitions():
    """A negative repetition count is rejected with an error mentioning the value."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    with pytest.raises(ValueError, match='-1'):
        cirq.sample_state_vector(zero_state, [1], repetitions=-1)
145
146
def test_sample_state_not_power_of_two():
    """Vectors whose length is not a power of two are rejected; the error names the length."""
    for bad_vector in (np.array([1, 0, 0]), np.array([0, 1, 0, 0, 0])):
        with pytest.raises(ValueError, match=str(len(bad_vector))):
            cirq.sample_state_vector(bad_vector, [1])
152
153
def test_sample_state_index_out_of_range():
    """Out-of-range qubit indices raise IndexError mentioning the offending index."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    for bad_index in (-2, 3):
        with pytest.raises(IndexError, match=str(bad_index)):
            cirq.sample_state_vector(zero_state, [bad_index])
160
161
def test_sample_no_indices():
    """Sampling no qubits yields a single empty sample."""
    samples = cirq.sample_state_vector(cirq.to_valid_state_vector(0, 3), [])
    np.testing.assert_almost_equal(samples, np.zeros(shape=(1, 0)))
165
166
def test_sample_no_indices_repetitions():
    """Sampling no qubits twice yields two empty samples."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    samples = cirq.sample_state_vector(zero_state, [], repetitions=2)
    np.testing.assert_almost_equal(samples, np.zeros(shape=(2, 0)))
172
173
def test_measure_state_computational_basis():
    """Measuring a basis state returns its bits and a state equal to the input."""
    measured = []
    for x in range(8):
        basis_state = cirq.to_valid_state_vector(x, 3)
        bits, post_state = cirq.measure_state_vector(basis_state, [2, 1, 0])
        np.testing.assert_almost_equal(post_state, basis_state)
        measured.append(bits)
    assert measured == [list(b)[::-1] for b in itertools.product([False, True], repeat=3)]
183
184
def test_measure_state_reshape():
    """Measurement also accepts a state already reshaped into a (2, 2, 2) tensor."""
    expected = [list(reversed(b)) for b in itertools.product([False, True], repeat=3)]
    actual = []
    for x in range(8):
        tensor = cirq.to_valid_state_vector(x, 3).reshape((2, 2, 2))
        bits, post_state = cirq.measure_state_vector(tensor, [2, 1, 0])
        actual.append(bits)
        np.testing.assert_almost_equal(post_state, tensor)
    assert actual == expected
194
195
def test_measure_state_partial_indices():
    """Measuring one qubit of a basis state returns its bit and an equal state."""
    for index, x in itertools.product(range(3), range(8)):
        basis = cirq.to_valid_state_vector(x, 3)
        bits, post_state = cirq.measure_state_vector(basis, [index])
        np.testing.assert_almost_equal(post_state, basis)
        assert bits == [bool((x >> (2 - index)) & 1)]
203
204
def test_measure_state_partial_indices_order():
    """Measuring qubits [2, 1] returns the two lowest bits, least significant first."""
    for x in range(8):
        basis = cirq.to_valid_state_vector(x, 3)
        bits, post_state = cirq.measure_state_vector(basis, [2, 1])
        np.testing.assert_almost_equal(post_state, basis)
        assert bits == [bool(x & 1), bool(x & 2)]
211
212
def test_measure_state_partial_indices_all_orders():
    """Measured bits follow the requested qubit order for every permutation."""
    for order, x in itertools.product(itertools.permutations([0, 1, 2]), range(8)):
        basis = cirq.to_valid_state_vector(x, 3)
        bits, post_state = cirq.measure_state_vector(basis, order)
        np.testing.assert_almost_equal(post_state, basis)
        assert bits == [bool(x & (1 << (2 - q))) for q in order]
220
221
def test_measure_state_collapse():
    """Measurement collapses a superposition without mutating the caller's array.

    The state is an equal superposition of |000⟩ and |010⟩, so only qubit 1 is
    random; qubits 0 and 2 always read 0.
    """
    initial_state = np.zeros(8, dtype=np.complex64)
    initial_state[0] = 1 / np.sqrt(2)
    initial_state[2] = 1 / np.sqrt(2)
    # Snapshot the input. An identity check (`is not`) cannot detect writes
    # through a view, and comparing against `initial_state` itself would still
    # pass if the input had been collapsed in place.
    pristine = initial_state.copy()

    for _ in range(10):
        bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0])
        assert bits in [[False, False, False], [False, True, False]]
        expected = np.zeros(8, dtype=np.complex64)
        expected[2 if bits[1] else 0] = 1.0
        np.testing.assert_almost_equal(state, expected)
        assert state is not initial_state
        # Without `out`, the input must be left untouched.
        np.testing.assert_array_equal(initial_state, pristine)

    # Partial sample is correct: qubits 2 and 0 are deterministic, so the state is unchanged.
    for _ in range(10):
        bits, state = cirq.measure_state_vector(initial_state, [2])
        np.testing.assert_almost_equal(state, pristine)
        assert bits == [False]

        bits, state = cirq.measure_state_vector(initial_state, [0])
        np.testing.assert_almost_equal(state, pristine)
        assert bits == [False]
243
244
def test_measure_state_seed():
    """Int and RandomState seeds give identical bits and identical collapsed states."""
    n = 10
    initial_state = np.ones(2 ** n) / 2 ** (n / 2)
    expected_bits = [False, False, True, True, False, False, False, True, False, False]

    collapsed = []
    for seed in (1234, np.random.RandomState(1234)):
        bits, post_state = cirq.measure_state_vector(initial_state, range(n), seed=seed)
        np.testing.assert_equal(bits, expected_bits)
        collapsed.append(post_state)

    np.testing.assert_allclose(*collapsed)
262
263
def test_measure_state_out_is_state():
    """Passing the input as `out` collapses it in place and returns that same array."""
    state_in = np.zeros(8, dtype=np.complex64)
    state_in[[0, 2]] = 1 / np.sqrt(2)
    bits, returned = cirq.measure_state_vector(state_in, [2, 1, 0], out=state_in)
    collapsed = np.zeros(8, dtype=np.complex64)
    collapsed[2 if bits[1] else 0] = 1.0
    np.testing.assert_array_almost_equal(state_in, collapsed)
    assert returned is state_in
273
274
def test_measure_state_out_is_not_state():
    """A separate `out` buffer receives the collapsed state and the input is left intact."""
    initial_state = np.zeros(8, dtype=np.complex64)
    initial_state[0] = 1 / np.sqrt(2)
    initial_state[2] = 1 / np.sqrt(2)
    pristine = initial_state.copy()
    out = np.zeros_like(initial_state)

    bits, state = cirq.measure_state_vector(initial_state, [2, 1, 0], out=out)
    assert out is not initial_state
    assert out is state

    # Identity alone does not prove the result was written into `out`; check its contents.
    expected = np.zeros(8, dtype=np.complex64)
    expected[2 if bits[1] else 0] = 1.0
    np.testing.assert_array_almost_equal(out, expected)
    # Writing to a distinct `out` must not touch the source vector.
    np.testing.assert_array_equal(initial_state, pristine)
283
284
def test_measure_state_not_power_of_two():
    """Vectors whose length is not a power of two are rejected; the error names the length."""
    for bad_vector in (np.array([1, 0, 0]), np.array([0, 1, 0, 0, 0])):
        with pytest.raises(ValueError, match=str(len(bad_vector))):
            cirq.measure_state_vector(bad_vector, [1])
290
291
def test_measure_state_index_out_of_range():
    """Out-of-range qubit indices raise IndexError mentioning the offending index."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    for bad_index in (-2, 3):
        with pytest.raises(IndexError, match=str(bad_index)):
            cirq.measure_state_vector(zero_state, [bad_index])
298
299
def test_measure_state_no_indices():
    """Measuring no qubits returns no bits and a state equal to the input."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    bits, post_state = cirq.measure_state_vector(zero_state, [])
    assert bits == []
    np.testing.assert_almost_equal(post_state, zero_state)
305
306
def test_measure_state_no_indices_out_is_state():
    """With no indices and `out` set to the input, the input array itself is returned."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    bits, returned = cirq.measure_state_vector(zero_state, [], out=zero_state)
    assert bits == []
    np.testing.assert_almost_equal(returned, zero_state)
    assert returned is zero_state
313
314
def test_measure_state_no_indices_out_is_not_state():
    """With no indices and a separate `out`, the state is copied into `out` and returned."""
    zero_state = cirq.to_valid_state_vector(0, 3)
    buffer = np.zeros_like(zero_state)
    bits, returned = cirq.measure_state_vector(zero_state, [], out=buffer)
    assert bits == []
    np.testing.assert_almost_equal(returned, zero_state)
    assert returned is buffer
    assert buffer is not zero_state
323
324
def test_measure_state_empty_state():
    """A zero-qubit state measured on no indices returns no bits and the same amplitudes."""
    empty_vector = np.array([1.0])
    bits, post_state = cirq.measure_state_vector(empty_vector, [])
    assert bits == []
    np.testing.assert_almost_equal(post_state, empty_vector)
330
331
class BasicStateVector(cirq.StateVectorMixin):
    """Minimal StateVectorMixin whose state is always the two-qubit basis vector |01⟩."""

    def state_vector(self) -> np.ndarray:
        """Return the fixed amplitudes: 1 at index 1, 0 elsewhere."""
        amplitudes = [0, 1, 0, 0]
        return np.array(amplitudes)
335
336
def test_step_result_pretty_state():
    """dirac_notation renders the stored amplitudes even without a qubit map."""
    assert BasicStateVector().dirac_notation() == '|01⟩'
340
341
def test_step_result_density_matrix():
    """density_matrix_of honours qubit selection and ordering for the |01⟩ state."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0, q1: 1})

    # |01⟩⟨01|: only the (1, 1) entry is populated.
    rho = np.zeros((4, 4))
    rho[1, 1] = 1
    np.testing.assert_array_almost_equal(rho, step_result.density_matrix_of([q0, q1]))
    # Calling with no qubits gives the same full two-qubit matrix.
    np.testing.assert_array_almost_equal(rho, step_result.density_matrix_of())

    # Reversing the qubit order moves the population to |10⟩⟨10|.
    rho_reversed = np.zeros((4, 4))
    rho_reversed[2, 2] = 1
    np.testing.assert_array_almost_equal(rho_reversed, step_result.density_matrix_of([q1, q0]))

    # q1 on its own is |1⟩⟨1|.
    np.testing.assert_array_almost_equal(np.diag([0, 1]), step_result.density_matrix_of([q1]))
356
357
def test_step_result_density_matrix_invalid():
    """Unmapped qubits raise KeyError; an integer qubit argument raises TypeError."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0})

    for unmapped in ([q1], 'junk'):
        with pytest.raises(KeyError):
            step_result.density_matrix_of(unmapped)
    with pytest.raises(TypeError):
        step_result.density_matrix_of(0)
369
370
def test_step_result_bloch_vector():
    """In |01⟩, q0 points along +z and q1 along -z on the Bloch sphere."""
    q0, q1 = cirq.LineQubit.range(2)
    step_result = BasicStateVector({q0: 0, q1: 1})
    expected_vectors = {q1: np.array([0, 0, -1]), q0: np.array([0, 0, 1])}
    for qubit, bloch in expected_vectors.items():
        np.testing.assert_array_almost_equal(bloch, step_result.bloch_vector_of(qubit))
378
379
def test_factor_validation():
    """factor_state_vector accepts a product state and rejects an entangled one."""
    factor = cirq.linalg.transformations.factor_state_vector
    qubits = cirq.LineQubit.range(2)
    q0, q1 = qubits
    args = cirq.Simulator()._create_act_on_args(0, qubits=qubits)

    # H on q0 alone leaves a product state, so either axis factors out.
    args.apply_operation(cirq.H(q0))
    product = args.create_merged_state().target_tensor
    factor(product, [0])
    factor(product, [1], atol=1e-2)

    # Following with CNOT entangles the pair, which must not factor on either axis.
    args.apply_operation(cirq.CNOT(q0, q1))
    entangled = args.create_merged_state().target_tensor
    for axis in (0, 1):
        with pytest.raises(ValueError, match='factor'):
            factor(entangled, [axis])
392