1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import itertools
15import random
16from typing import Type
17from unittest import mock
18import numpy as np
19import pytest
20import sympy
21
22import cirq
23
24
def test_invalid_dtype():
    """Constructing a simulator with a non-complex dtype is rejected."""
    non_complex_dtype = np.int32
    with pytest.raises(ValueError, match='complex'):
        cirq.Simulator(dtype=non_complex_dtype)
28
29
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_no_measurements(dtype: Type[np.number], split: bool):
    """Running a circuit that never measures raises ValueError."""
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    qubits = cirq.LineQubit.range(2)
    flip_all = cirq.Circuit([cirq.X(q) for q in qubits])
    with pytest.raises(ValueError, match="no measurements"):
        sim.run(flip_all)
39
40
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_no_results(dtype: Type[np.number], split: bool):
    """A circuit without measurements produces no results, so ``run`` refuses it.

    NOTE(review): this body currently mirrors test_run_no_measurements;
    confirm whether a distinct scenario was intended here.
    """
    a, b = cirq.LineQubit.range(2)
    flips = cirq.Circuit([cirq.X(a), cirq.X(b)])
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with pytest.raises(ValueError, match="no measurements"):
        sim.run(flips)
50
51
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_empty_circuit(dtype: Type[np.number], split: bool):
    """An empty circuit has nothing to measure, so ``run`` raises."""
    empty_circuit = cirq.Circuit()
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with pytest.raises(ValueError, match="no measurements"):
        sim.run(empty_circuit)
58
59
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_reset(dtype: Type[np.number], split: bool):
    """Reset returns both a qubit and a qutrit to |0> regardless of prior state."""
    qubit, qutrit = cirq.LineQid.for_qid_shape((2, 3))
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    ops = [
        cirq.H(qubit),
        PlusGate(3, 2)(qutrit),  # moves the qutrit from |0> to |2>
        cirq.reset(qubit),
        cirq.measure(qubit, key='m0'),
        cirq.measure(qutrit, key='m1a'),
        cirq.reset(qutrit),
        cirq.measure(qutrit, key='m1b'),
    ]
    reps = 100
    measurements = sim.run(cirq.Circuit(*ops), repetitions=reps).measurements
    expected = {
        'm0': np.zeros((reps, 1)),
        'm1a': np.full((reps, 1), 2),
        'm1b': np.zeros((reps, 1)),
    }
    for key, want in expected.items():
        assert np.array_equal(measurements[key], want)
78
79
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_bit_flips(dtype: Type[np.number], split: bool):
    """Deterministic X flips are reported exactly by a single-shot run."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product([0, 1], repeat=2):
        circuit = cirq.Circuit(
            (cirq.X ** b0)(q0), (cirq.X ** b1)(q1), cirq.measure(q0), cirq.measure(q1)
        )
        measured = sim.run(circuit).measurements
        np.testing.assert_equal(measured, {'0': [[b0]], '1': [[b1]]})
92
93
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_measure_at_end_no_repetitions(dtype: Type[np.number], split: bool):
    """Zero repetitions give empty results without invoking the core iterator."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    no_shots = {'0': np.empty([0, 1]), '1': np.empty([0, 1])}
    with mock.patch.object(sim, '_core_iterator', wraps=sim._core_iterator) as core:
        for b0, b1 in itertools.product((0, 1), repeat=2):
            circuit = cirq.Circuit(
                (cirq.X ** b0)(q0), (cirq.X ** b1)(q1), cirq.measure(q0), cirq.measure(q1)
            )
            result = sim.run(circuit, repetitions=0)
            np.testing.assert_equal(result.measurements, no_shots)
            assert result.repetitions == 0
        assert core.call_count == 0
111
112
def test_run_repetitions_terminal_measurement_stochastic():
    """Terminally measuring |+> many times yields a mix of 0s and 1s."""
    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.H(qubit), cirq.measure(qubit, key='q'))
    shots = cirq.Simulator().run(circuit, repetitions=10000).measurements['q']
    ones = sum(row[0] for row in shots)
    assert 1000 <= ones < 9000
118
119
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_repetitions_measure_at_end(dtype: Type[np.number], split: bool):
    """Terminal measurements do not force one simulation per repetition.

    Each circuit is run for 3 repetitions, yet ``_core_iterator`` is invoked
    fewer times than that per circuit (see the call-count assertion).
    """
    q0, q1 = cirq.LineQubit.range(2)
    simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with mock.patch.object(simulator, '_core_iterator', wraps=simulator._core_iterator) as mock_sim:
        for b0 in [0, 1]:
            for b1 in [0, 1]:
                circuit = cirq.Circuit(
                    (cirq.X ** b0)(q0), (cirq.X ** b1)(q1), cirq.measure(q0), cirq.measure(q1)
                )
                result = simulator.run(circuit, repetitions=3)
                np.testing.assert_equal(result.measurements, {'0': [[b0]] * 3, '1': [[b1]] * 3})
                assert result.repetitions == 3
        # Four (b0, b1) circuits -> 8 calls, i.e. two per circuit. That is less
        # than one call per repetition, which would give 4 * 3 = 12.
        # NOTE(review): confirm why each run() makes two _core_iterator calls.
        assert mock_sim.call_count == 8
136
137
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_invert_mask_measure_not_terminal(dtype: Type[np.number], split: bool):
    """invert_mask flips recorded bits even for a mid-circuit measurement."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with mock.patch.object(sim, '_core_iterator', wraps=sim._core_iterator) as core:
        for b0, b1 in itertools.product([0, 1], repeat=2):
            measure = cirq.measure(q0, q1, key='m', invert_mask=(True, False))
            # The trailing X acts after the measurement, so it is not terminal.
            circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1), measure, cirq.X(q0))
            result = sim.run(circuit, repetitions=3)
            np.testing.assert_equal(result.measurements, {'m': [[1 - b0, b1]] * 3})
            assert result.repetitions == 3
        # Non-terminal measurement: more core calls than one per (b0, b1) circuit.
        assert core.call_count > 4
157
158
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_partial_invert_mask_measure_not_terminal(dtype: Type[np.number], split: bool):
    """A short invert_mask inverts only the leading measured qubits."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with mock.patch.object(sim, '_core_iterator', wraps=sim._core_iterator) as core:
        for b0, b1 in itertools.product([0, 1], repeat=2):
            ops = [
                (cirq.X ** b0)(q0),
                (cirq.X ** b1)(q1),
                # The mask covers only q0; q1 is reported uninverted.
                cirq.measure(q0, q1, key='m', invert_mask=(True,)),
                # Acting after the measurement keeps it from being terminal.
                cirq.X(q0),
            ]
            result = sim.run(cirq.Circuit(ops), repetitions=3)
            expected = {'m': [[1 - b0, b1]] * 3}
            np.testing.assert_equal(result.measurements, expected)
            assert result.repetitions == 3
        # Non-terminal measurement: more core calls than one per (b0, b1) circuit.
        assert core.call_count > 4
178
179
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_measurement_not_terminal_no_repetitions(dtype: Type[np.number], split: bool):
    """Zero repetitions skip simulation even when measurements are mid-circuit."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    no_shots = {'0': np.empty([0, 1]), '1': np.empty([0, 1])}
    with mock.patch.object(sim, '_core_iterator', wraps=sim._core_iterator) as core:
        for b0 in (0, 1):
            for b1 in (0, 1):
                prepare = [(cirq.X ** b0)(q0), (cirq.X ** b1)(q1)]
                measure = [cirq.measure(q0), cirq.measure(q1)]
                # Gates after the measurements make them non-terminal.
                scramble = [cirq.H(q0), cirq.H(q1)]
                circuit = cirq.Circuit(*prepare, *measure, *scramble)
                result = sim.run(circuit, repetitions=0)
                np.testing.assert_equal(result.measurements, no_shots)
                assert result.repetitions == 0
        assert core.call_count == 0
202
203
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_repetitions_measurement_not_terminal(dtype: Type[np.number], split: bool):
    """Mid-circuit measurements record the pre-H bits on every repetition."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    with mock.patch.object(sim, '_core_iterator', wraps=sim._core_iterator) as core:
        for b0, b1 in itertools.product([0, 1], repeat=2):
            prepare = [(cirq.X ** b0)(q0), (cirq.X ** b1)(q1)]
            measure = [cirq.measure(q0), cirq.measure(q1)]
            scramble = [cirq.H(q0), cirq.H(q1)]
            result = sim.run(cirq.Circuit(*prepare, *measure, *scramble), repetitions=3)
            np.testing.assert_equal(result.measurements, {'0': [[b0]] * 3, '1': [[b1]] * 3})
            assert result.repetitions == 3
        # Non-terminal measurements: more core calls than one per (b0, b1) circuit.
        assert core.call_count > 4
225
226
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_param_resolver(dtype: Type[np.number], split: bool):
    """Symbolic X exponents are bound from the ParamResolver passed to run."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    # The circuit is the same for every binding; only the resolver changes.
    symbolic = cirq.Circuit(
        (cirq.X ** sympy.Symbol('b0'))(q0),
        (cirq.X ** sympy.Symbol('b1'))(q1),
        cirq.measure(q0),
        cirq.measure(q1),
    )
    for b0, b1 in itertools.product([0, 1], repeat=2):
        resolver = cirq.ParamResolver({'b0': b0, 'b1': b1})
        result = sim.run(symbolic, param_resolver=resolver)
        np.testing.assert_equal(result.measurements, {'0': [[b0]], '1': [[b1]]})
        np.testing.assert_equal(result.params, resolver)
244
245
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_mixture(dtype: Type[np.number], split: bool):
    """A 50% bit flip yields a roughly even split of outcomes over 100 shots."""
    qubit = cirq.LineQubit(0)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    noisy = cirq.Circuit(cirq.bit_flip(0.5)(qubit), cirq.measure(qubit))
    shots = sim.run(noisy, repetitions=100).measurements['0']
    ones = sum(shots)[0]  # type: ignore
    assert 20 < ones < 80
254
255
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_mixture_with_gates(dtype: Type[np.number], split: bool):
    """A 50% phase flip sandwiched between H gates gives roughly even outcomes."""
    qubit = cirq.LineQubit(0)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split, seed=23)
    circuit = cirq.Circuit(
        cirq.H(qubit), cirq.phase_flip(0.5)(qubit), cirq.H(qubit), cirq.measure(qubit)
    )
    shots = sim.run(circuit, repetitions=100).measurements['0']
    ones = sum(shots)[0]  # type: ignore
    assert 20 < ones < 80
265
266
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_correlations(dtype: Type[np.number], split: bool):
    """Both bits of an H+CNOT pair always agree when measured jointly."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    bell = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.measure(q0, q1))
    for _ in range(10):
        first, second = sim.run(bell).measurements['0,1'][0]
        assert first == second
277
278
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_measure_multiple_qubits(dtype: Type[np.number], split: bool):
    """A joint measurement reports both bits under the key '0,1'."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product((0, 1), repeat=2):
        prepare = [(cirq.X ** b0)(q0), (cirq.X ** b1)(q1)]
        circuit = cirq.Circuit(prepare, cirq.measure(q0, q1))
        result = sim.run(circuit, repetitions=3)
        np.testing.assert_equal(result.measurements, {'0,1': [[b0, b1]] * 3})
289
290
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_run_sweeps_param_resolvers(dtype: Type[np.number], split: bool):
    """run_sweep returns one result per resolver, in order, with its params."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    symbolic = cirq.Circuit(
        (cirq.X ** sympy.Symbol('b0'))(q0),
        (cirq.X ** sympy.Symbol('b1'))(q1),
        cirq.measure(q0),
        cirq.measure(q1),
    )
    for b0, b1 in itertools.product([0, 1], repeat=2):
        # The second resolver swaps the two bit values.
        params = [
            cirq.ParamResolver({'b0': b0, 'b1': b1}),
            cirq.ParamResolver({'b0': b1, 'b1': b0}),
        ]
        results = sim.run_sweep(symbolic, params=params)

        assert len(results) == 2
        expectations = [{'0': [[b0]], '1': [[b1]]}, {'0': [[b1]], '1': [[b0]]}]
        for result, expected in zip(results, expectations):
            np.testing.assert_equal(result.measurements, expected)
        for result, resolver in zip(results, params):
            assert result.params == resolver
315
316
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_random_unitary(dtype: Type[np.number], split: bool):
    """Simulating each basis state recovers the columns of the circuit unitary."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for _ in range(10):
        random_circuit = cirq.testing.random_circuit(qubits=[q0, q1], n_moments=8, op_density=0.99)
        # Column k of the unitary is the output state for computational basis state k.
        columns = [
            sim.simulate(
                random_circuit, qubit_order=[q0, q1], initial_state=basis_state
            ).final_state_vector
            for basis_state in range(4)
        ]
        expected_unitary = random_circuit.unitary(qubit_order=[q0, q1])
        np.testing.assert_almost_equal(np.transpose(columns), expected_unitary, decimal=6)
331
332
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_no_circuit(dtype: Type[np.number], split: bool):
    """An empty circuit leaves the register in |00> and records nothing."""
    qubits = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    result = sim.simulate(cirq.Circuit(), qubit_order=qubits)
    ground = np.array([1, 0, 0, 0])
    np.testing.assert_almost_equal(result.final_state_vector, ground)
    assert len(result.measurements) == 0
342
343
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate(dtype: Type[np.number], split: bool):
    """H on both qubits produces the uniform superposition."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    hadamards = cirq.Circuit([cirq.H(q) for q in (q0, q1)])
    result = sim.simulate(hadamards, qubit_order=[q0, q1])
    uniform = np.array([0.5, 0.5, 0.5, 0.5])
    np.testing.assert_almost_equal(result.final_state_vector, uniform)
    assert len(result.measurements) == 0
353
354
class PlusGate(cirq.Gate):
    """Qudit gate mapping |k> to |(k + increment) mod dimension>."""

    def __init__(self, dimension, increment=1):
        self.dimension = dimension
        # Stored reduced mod dimension, so e.g. increment == dimension acts like 0.
        self.increment = increment % dimension

    def _qid_shape_(self):
        return (self.dimension,)

    def _unitary_(self):
        # Shifting the identity's rows down by `increment` (cyclically) puts the
        # 1 of column k at row (k + increment) mod dimension. A shift of 0
        # leaves the identity unchanged.
        return np.roll(np.eye(self.dimension), self.increment, axis=0)
371
372
class _TestMixture(cirq.Gate):
    """Mixture that picks each of ``gate_options`` with equal probability."""

    def __init__(self, gate_options):
        # NOTE(review): all options are assumed to share the first option's
        # qid shape; this is not checked.
        self.gate_options = gate_options

    def _qid_shape_(self):
        first_option = self.gate_options[0]
        return cirq.qid_shape(first_option, ())

    def _mixture_(self):
        weight = 1 / len(self.gate_options)
        return [(weight, cirq.unitary(option)) for option in self.gate_options]
382
383
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_qudits(dtype: Type[np.number], split: bool):
    """Increment gates on a (3, 4) qudit register move |0,0> to |1,3>."""
    qutrit, ququart = cirq.LineQid.for_qid_shape((3, 4))
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    circuit = cirq.Circuit(PlusGate(3)(qutrit), PlusGate(4, increment=3)(ququart))
    result = sim.simulate(circuit, qubit_order=[qutrit, ququart])
    # Build the expected state on the 3x4 grid, then flatten (index 1 * 4 + 3).
    expected = np.zeros((3, 4))
    expected[1, 3] = 1
    np.testing.assert_almost_equal(result.final_state_vector, expected.reshape(-1))
    assert len(result.measurements) == 0
398
399
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_mixtures(dtype: Type[np.number], split: bool):
    """Each simulation ends in the measured state; outcomes split roughly evenly."""
    qubit = cirq.LineQubit(0)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    noisy = cirq.Circuit(cirq.bit_flip(0.5)(qubit), cirq.measure(qubit))
    ones = 0
    for _ in range(100):
        result = sim.simulate(noisy, qubit_order=[qubit])
        measured_one = bool(result.measurements['0'])
        expected = np.array([0, 1]) if measured_one else np.array([1, 0])
        np.testing.assert_almost_equal(result.final_state_vector, expected)
        ones += int(measured_one)
    assert 20 < ones < 80
415
416
@pytest.mark.parametrize(
    'dtype, split', itertools.product([np.complex64, np.complex128], [True, False])
)
def test_simulate_qudit_mixtures(dtype: Type[np.number], split: bool):
    """A uniform mixture of qutrit shifts lands on each level about a third of the time."""
    qutrit = cirq.LineQid(0, 3)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    mixture = _TestMixture([PlusGate(3, shift) for shift in range(3)])
    circuit = cirq.Circuit(mixture(qutrit), cirq.measure(qutrit))
    counts = dict.fromkeys(range(3), 0)
    for _ in range(300):
        result = sim.simulate(circuit, qubit_order=[qutrit])
        level = result.measurements['0 (d=3)'][0]
        counts[level] += 1
        one_hot = np.array([level == candidate for candidate in range(3)])
        np.testing.assert_almost_equal(result.final_state_vector, one_hot)
    for candidate in range(3):
        assert 40 < counts[candidate] < 160
436
437
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_bit_flips(dtype: Type[np.number], split: bool):
    """simulate reports the flipped bits and the matching basis state."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product([0, 1], repeat=2):
        circuit = cirq.Circuit(
            (cirq.X ** b0)(q0), (cirq.X ** b1)(q1), cirq.measure(q0), cirq.measure(q1)
        )
        result = sim.simulate(circuit)
        np.testing.assert_equal(result.measurements, {'0': [b0], '1': [b1]})
        # q0 is the high bit of the flattened 4-dim state index.
        expected = np.zeros(4)
        expected[2 * b0 + b1] = 1.0
        np.testing.assert_equal(result.final_state_vector, expected)
453
454
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_initial_state(dtype: Type[np.number], split: bool):
    """An integer initial_state of 1 starts the register in |01>."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product([0, 1], repeat=2):
        circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1))
        result = sim.simulate(circuit, initial_state=1)
        # Starting from |01>, X**b1 on q1 toggles the low bit away from 1.
        expected = np.zeros(4)
        expected[2 * b0 + (1 - b1)] = 1.0
        np.testing.assert_equal(result.final_state_vector, expected)
467
468
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_act_on_args(dtype: Type[np.number], split: bool):
    """Prebuilt act-on args can stand in for an integer initial state."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product([0, 1], repeat=2):
        circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1))
        # Fresh args every iteration, equivalent to initial_state=1 (|01>).
        start = sim._create_act_on_args(initial_state=1, qubits=(q0, q1))
        result = sim.simulate(circuit, initial_state=start)
        expected = np.zeros(4)
        expected[2 * b0 + (1 - b1)] = 1.0
        np.testing.assert_equal(result.final_state_vector, expected)
482
483
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_qubit_order(dtype: Type[np.number], split: bool):
    """Reversing qubit_order swaps which qubit is the high bit of the state."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product([0, 1], repeat=2):
        circuit = cirq.Circuit((cirq.X ** b0)(q0), (cirq.X ** b1)(q1))
        result = sim.simulate(circuit, qubit_order=[q1, q0])
        # With q1 listed first, q1 is now the high bit.
        expected = np.zeros(4)
        expected[2 * b1 + b0] = 1.0
        np.testing.assert_equal(result.final_state_vector, expected)
496
497
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_param_resolver(dtype: Type[np.number], split: bool):
    """simulate accepts a plain dict as the parameter resolver."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    symbolic = cirq.Circuit(
        (cirq.X ** sympy.Symbol('b0'))(q0), (cirq.X ** sympy.Symbol('b1'))(q1)
    )
    for b0, b1 in itertools.product([0, 1], repeat=2):
        bindings = {'b0': b0, 'b1': b1}
        result = sim.simulate(symbolic, param_resolver=bindings)  # type: ignore
        expected = np.zeros(4)
        expected[2 * b0 + b1] = 1.0
        np.testing.assert_equal(result.final_state_vector, expected)
        assert result.params == cirq.ParamResolver(bindings)  # type: ignore
        assert len(result.measurements) == 0
515
516
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_measure_multiple_qubits(dtype: Type[np.number], split: bool):
    """A joint measurement in simulate reports both bits under '0,1'."""
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for b0, b1 in itertools.product((0, 1), repeat=2):
        prepare = [(cirq.X ** b0)(q0), (cirq.X ** b1)(q1)]
        result = sim.simulate(cirq.Circuit(prepare, cirq.measure(q0, q1)))
        np.testing.assert_equal(result.measurements, {'0,1': [b0, b1]})
527
528
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_sweeps_param_resolver(dtype: Type[np.number], split: bool):
    """simulate_sweep gives one final state per resolver, in order."""

    def one_hot(index):
        # Basis vector of the 2-qubit space; q0 is the high bit of the index.
        vector = np.zeros(4)
        vector[index] = 1.0
        return vector

    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    symbolic = cirq.Circuit(
        (cirq.X ** sympy.Symbol('b0'))(q0), (cirq.X ** sympy.Symbol('b1'))(q1)
    )
    for b0, b1 in itertools.product([0, 1], repeat=2):
        # The second resolver swaps the two bit values.
        params = [
            cirq.ParamResolver({'b0': b0, 'b1': b1}),
            cirq.ParamResolver({'b0': b1, 'b1': b0}),
        ]
        results = sim.simulate_sweep(symbolic, params=params)
        np.testing.assert_equal(results[0].final_state_vector, one_hot(2 * b0 + b1))
        np.testing.assert_equal(results[1].final_state_vector, one_hot(2 * b1 + b0))

        assert results[0].params == params[0]
        assert results[1].params == params[1]
554
555
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_moment_steps(dtype: Type[np.number], split: bool):
    """The first H layer gives |++>; the second layer returns to |00>."""
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1))
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    uniform = np.array([0.5] * 4)
    ground = np.array([1, 0, 0, 0])
    for moment_index, step in enumerate(sim.simulate_moment_steps(circuit)):
        expected = uniform if moment_index == 0 else ground
        np.testing.assert_almost_equal(step.state_vector(), expected)
567
568
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_moment_steps_empty_circuit(dtype: Type[np.number], split: bool):
    """An empty circuit still yields a final step holding the trivial state [1]."""
    circuit = cirq.Circuit()
    simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    step = None
    for step in simulator.simulate_moment_steps(circuit):
        pass
    # If no step is produced at all, fail with a clear message here rather
    # than with an AttributeError on None in the next assertion.
    assert step is not None, 'simulate_moment_steps yielded no steps for an empty circuit'
    assert step._simulator_state() == cirq.StateVectorSimulatorState(
        state_vector=np.array([1]), qubit_map={}
    )
580
581
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
def test_simulate_moment_steps_set_state(dtype):
    """Overwriting the state after step 0 changes what the next step sees."""
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q0), cirq.H(q1))
    sim = cirq.Simulator(dtype=dtype)
    uniform = np.array([0.5] * 4)
    for moment_index, step in enumerate(sim.simulate_moment_steps(circuit)):
        # Without the reset the second step would be |00>; resetting to |00>
        # after step 0 makes the second H layer produce |++> again.
        np.testing.assert_almost_equal(step.state_vector(), uniform)
        if moment_index == 0:
            step.set_state_vector(np.array([1, 0, 0, 0], dtype=dtype))
591
592
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_moment_steps_sample(dtype: Type[np.number], split: bool):
    """Sampling mid-simulation reflects the state after each moment."""
    q0, q1 = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1))
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    # After H alone only q0 can be 1; after CNOT the two bits always agree.
    after_h = ([True, False], [False, False])
    after_cnot = ([True, True], [False, False])
    for moment_index, step in enumerate(sim.simulate_moment_steps(circuit)):
        allowed = after_h if moment_index == 0 else after_cnot
        samples = step.sample([q0, q1], repetitions=10)
        for sample in samples:
            assert any(np.array_equal(sample, outcome) for outcome in allowed)
612
613
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_moment_steps_intermediate_measurement(dtype: Type[np.number], split: bool):
    """A mid-circuit measurement collapses the state seen by later steps."""
    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.H(qubit), cirq.measure(qubit), cirq.H(qubit))
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    for moment_index, step in enumerate(sim.simulate_moment_steps(circuit)):
        if moment_index == 1:
            # Right after measuring, the state is the measured basis state.
            outcome = int(step.measurements['0'][0])
            collapsed = np.zeros(2)
            collapsed[outcome] = 1
            np.testing.assert_almost_equal(step.state_vector(), collapsed)
        elif moment_index == 2:
            # The final H maps |0> to |+> and |1> to |->.
            sign = (-1) ** outcome
            expected = np.array([np.sqrt(0.5), np.sqrt(0.5) * sign])
            np.testing.assert_almost_equal(step.state_vector(), expected)
629
630
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_expectation_values(dtype: Type[np.number], split: bool):
    """Expectation values of Pauli sums match hand-computed values."""
    # Compare with test_expectation_from_state_vector_two_qubit_states
    # in file: cirq/ops/linear_combinations_test.py
    q0, q1 = cirq.LineQubit.range(2)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    psum1 = cirq.Z(q0) + 3.2 * cirq.Z(q1)
    psum2 = -1 * cirq.X(q0) + 2 * cirq.X(q1)

    # Each case pairs a circuit with the expected values of [psum1, psum2].
    cases = [
        (cirq.Circuit(cirq.I(q0), cirq.X(q1)), (-2.2, 0)),
        (cirq.Circuit(cirq.H(q0), cirq.H(q1)), (0, 1)),
    ]
    for circuit, expected in cases:
        values = sim.simulate_expectation_values(circuit, [psum1, psum2])
        for index, target in enumerate(expected):
            assert cirq.approx_eq(values[index], target, atol=1e-6)

    # A single observable, not wrapped in a list, is also accepted.
    psum3 = cirq.Z(q0) + cirq.X(q1)
    values = sim.simulate_expectation_values(cirq.Circuit(cirq.I(q0), cirq.H(q1)), psum3)
    assert cirq.approx_eq(values[0], 2, atol=1e-6)
654
655
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_expectation_values_terminal_measure(dtype: Type[np.number], split: bool):
    """Terminal measurements are rejected unless explicitly permitted."""
    qubit = cirq.LineQubit(0)
    obs = cirq.Z(qubit)
    sim = cirq.Simulator(dtype=dtype, split_untangled_states=split)

    measured = cirq.Circuit(cirq.H(qubit), cirq.measure(qubit))
    with pytest.raises(ValueError):
        _ = sim.simulate_expectation_values(measured, obs)

    tallies = {-1: 0, 1: 0}
    for _ in range(100):
        value = sim.simulate_expectation_values(
            measured, obs, permit_terminal_measurements=True
        )[0]
        for eigenvalue in tallies:
            if cirq.approx_eq(value, eigenvalue, atol=1e-6):
                tallies[eigenvalue] += 1

    # The measurement collapses the state, so every run gives exactly +1 or -1
    # and both outcomes occur.
    assert tallies[-1] > 0
    assert tallies[1] > 0
    assert tallies[-1] + tallies[1] == 100

    unmeasured = cirq.Circuit(cirq.H(qubit))
    zero_hits = 0
    for _ in range(100):
        value = sim.simulate_expectation_values(
            unmeasured, obs, permit_terminal_measurements=True
        )[0]
        if cirq.approx_eq(value, 0, atol=1e-6):
            zero_hits += 1

    # Without a measurement, <Z> after H is 0 on every run.
    assert zero_hits == 100
692
693
@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('split', [True, False])
def test_simulate_expectation_values_qubit_order(dtype: Type[np.number], split: bool):
    """Expectation values do not depend on the qubit_order given to the simulator."""
    q0, q1, q2 = cirq.LineQubit.range(3)
    simulator = cirq.Simulator(dtype=dtype, split_untangled_states=split)
    circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.X(q2))
    # <X> = 1 on each |+> qubit and <Z> = -1 on |1>, so the sum is 1 + 1 - (-1) = 3.
    observable = cirq.X(q0) + cirq.X(q1) - cirq.Z(q2)

    default_order = simulator.simulate_expectation_values(circuit, observable)
    assert cirq.approx_eq(default_order[0], 3, atol=1e-6)

    # Permuting the qubit order must leave the expectation value unchanged.
    permuted_order = simulator.simulate_expectation_values(
        circuit, observable, qubit_order=[q1, q2, q0]
    )
    assert cirq.approx_eq(permuted_order[0], 3, atol=1e-6)
708
709
def test_invalid_run_no_unitary():
    """Running a circuit containing a gate with no unitary raises TypeError."""

    class NoUnitary(cirq.SingleQubitGate):
        """Single-qubit gate that defines no unitary-related protocol methods."""

    qubit = cirq.LineQubit(0)
    circuit = cirq.Circuit(NoUnitary()(qubit), cirq.measure(qubit, key='meas'))
    with pytest.raises(TypeError, match='unitary'):
        cirq.Simulator().run(circuit)
720
721
def test_allocates_new_state():
    """The final state vector is a new array, never the caller's initial_state.

    The gate below acts as the identity but returns a fresh copy from
    _apply_unitary_, so the final amplitudes equal the initial ones. The
    simulator must still not hand back (or alias) the caller's buffer.
    """

    class CopyingIdentity(cirq.SingleQubitGate):
        """Identity gate whose _apply_unitary_ returns a copy of the target tensor."""

        def _has_unitary_(self):
            return True

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return np.copy(args.target_tensor)

    q0 = cirq.LineQubit(0)
    simulator = cirq.Simulator()
    circuit = cirq.Circuit(CopyingIdentity()(q0))

    initial_state = np.array([np.sqrt(0.5), np.sqrt(0.5)], dtype=np.complex64)
    result = simulator.simulate(circuit, initial_state=initial_state)
    final_state = result.state_vector()
    np.testing.assert_array_almost_equal(final_state, initial_state)
    # Equal values are expected, but the result must be a distinct, non-aliasing array.
    assert final_state is not initial_state
    assert not np.shares_memory(final_state, initial_state)
738
739
def test_does_not_modify_initial_state():
    """A gate that mutates its target tensor in place must not clobber initial_state."""

    class InPlaceUnitary(cirq.SingleQubitGate):
        """X-like gate that swaps the two amplitudes directly inside target_tensor."""

        def _has_unitary_(self):
            return True

        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            tensor = args.target_tensor
            tensor[0], tensor[1] = tensor[1], tensor[0]
            return tensor

    q0 = cirq.LineQubit(0)
    circuit = cirq.Circuit(InPlaceUnitary()(q0))
    simulator = cirq.Simulator()

    initial_state = np.array([1, 0], dtype=np.complex64)
    result = simulator.simulate(circuit, initial_state=initial_state)

    # The caller's array keeps its original contents ...
    np.testing.assert_array_almost_equal(np.array([1, 0], dtype=np.complex64), initial_state)
    # ... while the simulated state reflects the swap.
    np.testing.assert_array_almost_equal(
        result.state_vector(), np.array([0, 1], dtype=np.complex64)
    )
763
764
def test_simulator_step_state_mixin():
    """SparseSimulatorStep exposes density matrix, Bloch vector and Dirac notation."""
    qubits = cirq.LineQubit.range(2)
    # Basis state |01>: the single nonzero amplitude sits at flat index 1.
    state_tensor = np.array([0, 1, 0, 0]).reshape((2, 2))
    buffer_tensor = np.array([0, 1, 0, 0]).reshape((2, 2))
    args = cirq.ActOnStateVectorArgs(
        qubits=qubits,
        target_tensor=state_tensor,
        available_buffer=buffer_tensor,
        prng=cirq.value.parse_random_state(0),
        log_of_measurement_results={'m': np.array([1, 2])},
    )
    step = cirq.SparseSimulatorStep(
        sim_state=args,
        dtype=np.complex64,
        simulator=None,  # type: ignore
    )

    expected_rho = np.zeros((4, 4))
    expected_rho[1, 1] = 1
    np.testing.assert_array_almost_equal(expected_rho, step.density_matrix_of(qubits))

    # The second qubit is |1>, which sits at the south pole of the Bloch sphere.
    np.testing.assert_array_almost_equal(np.array([0, 0, -1]), step.bloch_vector_of(qubits[1]))

    assert step.dirac_notation() == '|01⟩'
785
786
class MultiHTestGate(cirq.testing.TwoQubitGate):
    """Two-qubit test gate defined only by its decomposition: H on every qubit."""

    def _decompose_(self, qubits):
        return [cirq.H(qubit) for qubit in qubits]
790
791
def test_simulates_composite():
    """A gate known only through _decompose_ is simulated via its decomposition."""
    circuit = cirq.Circuit(MultiHTestGate().on(*cirq.LineQubit.range(2)))
    # H on both qubits of |00> gives the uniform superposition, amplitude 1/2 each.
    uniform = np.full(4, 0.5)
    np.testing.assert_allclose(circuit.final_state_vector(), uniform)
    np.testing.assert_allclose(cirq.Simulator().simulate(circuit).state_vector(), uniform)
797
798
def test_simulate_measurement_inversions():
    """invert_mask flips the bit recorded by simulate() for a measurement of |0>."""
    q = cirq.NamedQubit('q')

    # Comparing dicts with `==` relies on each ndarray value having exactly one
    # element (otherwise its truth value is ambiguous); assert_equal compares
    # dict keys and array contents element-wise instead.
    c = cirq.Circuit(cirq.measure(q, key='q', invert_mask=(True,)))
    np.testing.assert_equal(
        cirq.Simulator().simulate(c).measurements, {'q': np.array([True])}
    )

    c = cirq.Circuit(cirq.measure(q, key='q', invert_mask=(False,)))
    np.testing.assert_equal(
        cirq.Simulator().simulate(c).measurements, {'q': np.array([False])}
    )
807
808
def test_works_on_pauli_string_phasor():
    """The exponential of a Pauli string can be simulated directly."""
    a, b = cirq.LineQubit.range(2)
    # exp(i*pi/2 * XX) = i*XX, which maps |00> to i|11>.
    phasor = np.exp(0.5j * np.pi * cirq.X(a) * cirq.X(b))
    final = cirq.Simulator().simulate(cirq.Circuit(phasor)).state_vector()
    np.testing.assert_allclose(final.reshape(4), np.array([0, 0, 0, 1j]), atol=1e-8)
815
816
def test_works_on_pauli_string():
    """A bare Pauli string can be used directly as a circuit operation."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.X(a) * cirq.X(b))
    final = cirq.Simulator().simulate(circuit).state_vector()
    # X on both qubits takes |00> to |11>.
    expected = np.zeros(4)
    expected[3] = 1
    np.testing.assert_allclose(final.reshape(4), expected, atol=1e-8)
823
824
def test_measure_at_end_invert_mask():
    """An inverted terminal measurement of |0> reads 1 on every repetition."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(cirq.measure(a, key='a', invert_mask=(True,)))
    repetitions = 4
    result = cirq.Simulator().run(circuit, repetitions=repetitions)
    np.testing.assert_equal(result.measurements['a'], np.ones((repetitions, 1)))
831
832
def test_measure_at_end_invert_mask_multiple_qubits():
    """invert_mask is applied per qubit across several terminal measurements."""
    a, b, c = cirq.LineQubit.range(3)
    circuit = cirq.Circuit(
        cirq.measure(a, key='a', invert_mask=(True,)),
        cirq.measure(b, c, key='bc', invert_mask=(False, True)),
    )
    measurements = cirq.Simulator().run(circuit, repetitions=4).measurements

    # One expected row per key; every one of the 4 repetitions must match it.
    expected_rows = {'a': [[True]], 'bc': [[0, 1]]}
    for key, row in expected_rows.items():
        np.testing.assert_equal(measurements[key], np.array(row * 4))
843
844
def test_measure_at_end_invert_mask_partial():
    """A mask shorter than the measured qubits only flips the leading qubits."""
    a, _, c = cirq.LineQubit.range(3)
    circuit = cirq.Circuit(cirq.measure(a, c, key='ac', invert_mask=(True,)))
    result = cirq.Simulator().run(circuit, repetitions=4)
    # Only `a` is inverted, so each row is [1, 0].
    np.testing.assert_equal(result.measurements['ac'], np.tile([1, 0], (4, 1)))
851
852
def test_qudit_invert_mask():
    """invert_mask on a qudit measurement covers only the first len(mask) qudits."""
    q0, q1, q2, q3, q4 = cirq.LineQid.for_qid_shape((2, 3, 3, 3, 4))
    preparation = [
        PlusGate(2, 1)(q0),
        PlusGate(3, 1)(q2),
        PlusGate(3, 2)(q3),
        PlusGate(4, 3)(q4),
    ]
    measurement = cirq.measure(q0, q1, q2, q3, q4, key='a', invert_mask=(True,) * 4)
    circuit = cirq.Circuit(*preparation, measurement)

    measured = cirq.Simulator().run(circuit).measurements['a']
    # Assuming PlusGate(d, k) prepares |k> from |0> (it is defined elsewhere in this
    # file), the raw values are [1, 0, 1, 2, 3]. The mask flips the 0/1 outcomes of
    # the first four qudits, leaves the 2 on q3 as-is, and does not reach q4.
    assert np.all(measured == [[0, 1, 0, 2, 3]])
863
864
def test_compute_amplitudes():
    """compute_amplitudes returns the requested amplitudes and honors qubit_order."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.X(a), cirq.H(a), cirq.H(b))
    simulator = cirq.Simulator()

    # Final state is |-> (x) |+> = 0.5 * (|00> + |01> - |10> - |11>) in (a, b) order;
    # with (b, a) ordering the roles of the two bits in each index swap.
    checks = [
        ([0], {}, [0.5]),
        ([1, 2, 3], {}, [0.5, -0.5, -0.5]),
        ((1, 2, 3), {'qubit_order': (b, a)}, [-0.5, 0.5, -0.5]),
    ]
    for bitstrings, extra_kwargs, expected in checks:
        amplitudes = simulator.compute_amplitudes(circuit, bitstrings, **extra_kwargs)
        np.testing.assert_allclose(np.array(amplitudes), np.array(expected))
878
879
def test_compute_amplitudes_bad_input():
    """compute_amplitudes rejects bitstrings that are not a 1-dimensional array."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(cirq.X(a), cirq.H(a), cirq.H(b))
    two_dimensional = np.array([[0, 0]])
    with pytest.raises(ValueError, match='1-dimensional'):
        _ = cirq.Simulator().compute_amplitudes(circuit, two_dimensional)
887
888
def test_run_sweep_parameters_not_resolved():
    """run_sweep with an empty resolver fails on the unresolved symbol."""
    qubit = cirq.LineQubit(0)
    exponent = sympy.Symbol('a')
    circuit = cirq.Circuit(cirq.XPowGate(exponent=exponent)(qubit), cirq.measure(qubit))
    empty_resolver = cirq.ParamResolver({})
    with pytest.raises(ValueError, match='symbols were not specified'):
        _ = cirq.Simulator().run_sweep(circuit, empty_resolver)
895
896
def test_simulate_sweep_parameters_not_resolved():
    """simulate_sweep with an empty resolver fails on the unresolved symbol."""
    qubit = cirq.LineQubit(0)
    exponent = sympy.Symbol('a')
    circuit = cirq.Circuit(cirq.XPowGate(exponent=exponent)(qubit), cirq.measure(qubit))
    empty_resolver = cirq.ParamResolver({})
    with pytest.raises(ValueError, match='symbols were not specified'):
        _ = cirq.Simulator().simulate_sweep(circuit, empty_resolver)
903
904
def test_random_seed():
    """An int seed and the equivalent RandomState produce identical samples."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a))
    expected = [[False], [True], [False], [True], [True], [False], [False], [True], [True], [True]]

    for seed in (1234, np.random.RandomState(1234)):
        result = cirq.Simulator(seed=seed).run(circuit, repetitions=10)
        assert np.all(result.measurements['a'] == expected)
922
923
def test_random_seed_does_not_modify_global_state_terminal_measurements():
    """Seeded terminal-measurement runs are unaffected by draws from global RNGs."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a))

    baseline = cirq.Simulator(seed=1234).run(circuit, repetitions=50)

    second_simulator = cirq.Simulator(seed=1234)
    # Advance both global generators; a properly seeded simulator must not notice.
    np.random.random()
    random.random()
    repeated = second_simulator.run(circuit, repetitions=50)

    assert baseline == repeated
937
938
def test_random_seed_does_not_modify_global_state_non_terminal_measurements():
    """Seeded mid-circuit-measurement runs are unaffected by draws from global RNGs."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(
        cirq.X(a) ** 0.5, cirq.measure(a, key='a0'), cirq.X(a) ** 0.5, cirq.measure(a, key='a1')
    )

    baseline = cirq.Simulator(seed=1234).run(circuit, repetitions=50)

    second_simulator = cirq.Simulator(seed=1234)
    # Advance both global generators; a properly seeded simulator must not notice.
    np.random.random()
    random.random()
    repeated = second_simulator.run(circuit, repetitions=50)

    assert baseline == repeated
954
955
def test_random_seed_does_not_modify_global_state_mixture():
    """Seeded runs with a mixture channel are unaffected by draws from global RNGs."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(cirq.depolarize(0.5).on(a), cirq.measure(a))

    baseline = cirq.Simulator(seed=1234).run(circuit, repetitions=50)

    second_simulator = cirq.Simulator(seed=1234)
    # Advance both global generators; a properly seeded simulator must not notice.
    np.random.random()
    random.random()
    repeated = second_simulator.run(circuit, repetitions=50)

    assert baseline == repeated
969
970
def test_random_seed_terminal_measurements_deterministic():
    """A seeded simulator yields fixed sample sequences for terminal measurements.

    Both runs share one simulator, so the second continues from the RNG state the
    first left behind and produces a different, but still fixed, sequence.
    """
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(cirq.X(a) ** 0.5, cirq.measure(a, key='a'))
    sim = cirq.Simulator(seed=1234)
    first_run = sim.run(circuit, repetitions=30)
    second_run = sim.run(circuit, repetitions=30)

    # Expected outcomes, one character per repetition, grouped in tens.
    expected_first = [[int(bit)] for bit in '0101100111' '0111011011' '0100110101']
    expected_second = [[int(bit)] for bit in '1010110101' '0001110101' '0110111111']
    assert np.all(first_run.measurements['a'] == expected_first)
    assert np.all(second_run.measurements['a'] == expected_second)
1047
1048
def test_random_seed_non_terminal_measurements_deterministic():
    """A seeded simulator yields fixed samples when measurements occur mid-circuit."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(
        cirq.X(a) ** 0.5, cirq.measure(a, key='a'), cirq.X(a) ** 0.5, cirq.measure(a, key='b')
    )
    measurements = cirq.Simulator(seed=1234).run(circuit, repetitions=30).measurements

    # Expected outcomes per key, one character per repetition, grouped in tens.
    expected = {
        'a': '0010101011' '0010011100' '0010001111',
        'b': '1101111101' '1011100111' '0111110111',
    }
    for key, bits in expected.items():
        assert np.all(measurements[key] == [[int(bit)] for bit in bits])
1126
1127
def test_random_seed_mixture_deterministic():
    """A seeded simulator samples a chain of depolarizing channels reproducibly."""
    a = cirq.NamedQubit('a')
    circuit = cirq.Circuit(
        *[cirq.depolarize(0.9).on(a) for _ in range(5)],
        cirq.measure(a, key='a'),
    )
    result = cirq.Simulator(seed=1234).run(circuit, repetitions=30)

    # Expected outcomes, one character per repetition, grouped in tens.
    expected = [[int(bit)] for bit in '1000100111' '1101000001' '0110111110']
    assert np.all(result.measurements['a'] == expected)
1175
1176
def test_entangled_reset_does_not_break_randomness():
    """Resetting one half of a Bell pair leaves the other half's measurement random.

    A previous version of cirq made the mistake of assuming that it was okay to
    cache the wavefunction produced by general channels on unrelated qubits
    before repeatedly sampling measurements. This test checks for that mistake.
    """
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(
        cirq.H(a), cirq.CNOT(a, b), cirq.ResetChannel().on(a), cirq.measure(b, key='out')
    )
    samples = cirq.Simulator().sample(circuit, repetitions=100)['out']
    counts = samples.value_counts()
    # Both outcomes must appear, each well away from either extreme of 100 shots.
    assert len(counts) == 2
    assert 10 <= counts[0] <= 90
    assert 10 <= counts[1] <= 90


def test_overlapping_measurements_at_end():
    """Commuting terminal measurements that share qubits stay mutually consistent."""
    a, b = cirq.LineQubit.range(2)
    circuit = cirq.Circuit(
        cirq.H(a),
        cirq.CNOT(a, b),
        # These measurements are not on independent qubits but they commute.
        cirq.measure(a, key='a'),
        cirq.measure(a, key='not a', invert_mask=(True,)),
        cirq.measure(b, key='b'),
        cirq.measure(a, b, key='ab'),
    )
    samples = cirq.Simulator().sample(circuit, repetitions=100)
    bit_a = samples['a'].values
    bit_b = samples['b'].values

    # 'not a' is the inverted copy of 'a', and 'ab' packs (a, b) with a as the high bit.
    np.testing.assert_array_equal(bit_a, samples['not a'].values ^ 1)
    np.testing.assert_array_equal(bit_a * 2 + bit_b, samples['ab'].values)

    # b is half of a Bell pair, so both outcomes should appear in 100 shots.
    tally = samples['b'].value_counts()
    assert len(tally) == 2
    assert 10 <= tally[0] <= 90
    assert 10 <= tally[1] <= 90
1220
1221
def test_separated_measurements():
    """A mid-circuit measurement does not spoil a later deterministic outcome."""
    a, b = cirq.LineQubit.range(2)
    operations = [
        cirq.H(a),
        cirq.H(b),
        cirq.CZ(a, b),
        cirq.measure(a, key='a'),
        cirq.CZ(a, b),
        cirq.H(b),
        cirq.measure(b, key='zero'),
    ]
    circuit = cirq.Circuit(operations)
    # The Z-basis measurement of `a` commutes with CZ, so the two CZs cancel and
    # b is returned to |0> by the second H.
    repetitions = 10
    samples = cirq.Simulator().sample(circuit, repetitions=repetitions)
    np.testing.assert_array_equal(samples['zero'].values, [0] * repetitions)
1237
1238
def test_state_vector_copy():
    """step.state_vector(copy=...) controls whether the result aliases live state."""
    sim = cirq.Simulator(split_untangled_states=False)

    class InplaceGate(cirq.SingleQubitGate):
        """Single-qubit gate that negates the target tensor in place."""

        def _apply_unitary_(self, args):
            args.target_tensor *= -1.0
            return args.target_tensor

    q = cirq.LineQubit(0)
    circuit = cirq.Circuit(InplaceGate()(q), InplaceGate()(q))

    # With copy=True every step hands back an independent buffer.
    copied = [step.state_vector(copy=True) for step in sim.simulate_moment_steps(circuit)]
    assert all(not np.shares_memory(x, y) for x, y in itertools.combinations(copied, 2))

    # With copy=False the returned arrays alias the live state, so the second
    # in-place gate mutates the vector captured at the first step.
    live_views = []
    snapshots = []
    for step in sim.simulate_moment_steps(circuit):
        view = step.state_vector(copy=False)
        live_views.append(view)
        snapshots.append(view.copy())
    assert any(not np.array_equal(x, y) for x, y in zip(live_views, snapshots))
1267
1268
def test_final_state_vector_is_not_last_object():
    """The final state vector never aliases the caller's initial_state array."""
    q = cirq.LineQubit(0)
    initial_state = np.array([1, 0], dtype=np.complex64)
    circuit = cirq.Circuit(cirq.wait(q))
    final = cirq.Simulator().simulate(circuit, initial_state=initial_state).state_vector()

    assert final is not initial_state
    assert not np.shares_memory(final, initial_state)
    # A wait is a no-op, so the values themselves are unchanged.
    np.testing.assert_equal(final, initial_state)
1279
1280
def test_deterministic_gate_noise():
    """A unitary gate used as noise is applied deterministically."""
    q = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.I(q), cirq.measure(q))

    def run_with_noise(noise):
        return cirq.Simulator(noise=noise).run(circuit, repetitions=10)

    x_noise_first = run_with_noise(cirq.X)
    x_noise_second = run_with_noise(cirq.X)
    z_noise = run_with_noise(cirq.Z)

    # Independent X-noise simulators agree; Z leaves |0> unflipped, so it differs.
    assert x_noise_first == x_noise_second
    assert x_noise_first != z_noise
1297
1298
def test_nondeterministic_mixture_noise():
    """Mixture noise is resampled on each run of an unseeded simulator."""
    q = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.I(q), cirq.measure(q))
    depolarizing = cirq.ConstantQubitNoiseModel(cirq.depolarize(0.5))
    simulator = cirq.Simulator(noise=depolarizing)

    # Two 50-shot runs matching exactly is vanishingly unlikely.
    first, second = (simulator.run(circuit, repetitions=50) for _ in range(2))
    assert first != second
1308
1309
def test_act_on_args_pure_state_creation():
    """_create_act_on_args for integer state 1 rebuilds the matching basis state."""
    sim = cirq.Simulator()
    qids = cirq.LineQubit.range(3)
    shape = cirq.qid_shape(qids)
    args = sim._create_act_on_args(1, qids)

    # Tensor the first three per-qubit states together, then restore qubit order.
    per_qubit = list(args.values())[:3]
    combined = per_qubit[0]
    for partial in per_qubit[1:]:
        combined = combined.kronecker_product(partial)
    combined = combined.transpose_to_qubit_order(qids)

    expected = cirq.to_valid_state_vector(1, len(qids), qid_shape=shape)
    np.testing.assert_allclose(combined.target_tensor, expected.reshape(shape))
1324
1325
def test_noise_model():
    """A noise model built from a channel is accepted and the simulation still runs.

    H prepares an equal superposition, so each of the 100 shots reads 1 with
    probability about 1/2, giving a count with a standard deviation of about 5.
    """
    q = cirq.LineQubit(0)
    circuit = cirq.Circuit(cirq.H(q), cirq.measure(q))

    noise_model = cirq.NoiseModel.from_noise_model_like(cirq.depolarize(p=0.01))
    simulator = cirq.Simulator(noise=noise_model)
    result = simulator.run(circuit, repetitions=100)

    ones = int(np.sum(result.measurements['0']))
    # The simulator is unseeded, so the bound must tolerate sampling noise: a
    # 40..59 window is only about +/-2 sigma and fails roughly 5% of runs,
    # whereas 20..79 is about +/-6 sigma.
    assert 20 <= ones < 80
1335