1# Copyright 2020 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17import sympy
18
19import cirq
20
21
def test_run_sweep():
    """ZerosSampler.run_sweep yields all-zero measurements with correct shapes.

    With no sweep parameters a single result is produced, and each measurement
    key maps to a (repetitions, number of measured qubits) array of zeros.
    """
    qubit_a, qubit_b, qubit_c = (cirq.NamedQubit(name) for name in ('a', 'b', 'c'))
    circuit = cirq.Circuit([cirq.measure(qubit_a)], [cirq.measure(qubit_b, qubit_c)])

    results = cirq.ZerosSampler().run_sweep(circuit, None, 3)

    assert len(results) == 1
    measurements = results[0].measurements
    assert measurements.keys() == {'a', 'b,c'}

    # Key -> expected (repetitions, qubit count) shape; checked in insertion order.
    expected_shapes = {'a': (3, 1), 'b,c': (3, 2)}
    for key, shape in expected_shapes.items():
        assert measurements[key].shape == shape
        assert np.all(measurements[key] == 0)
35
36
def test_sample():
    """ZerosSampler.sample agrees with a real simulator on an all-zeros circuit.

    The circuit is built so every measured qubit deterministically reads 0.
    The samples produced by ZerosSampler must therefore equal those produced
    by cirq.Simulator for the same sweep.
    """
    p = sympy.Symbol('p')
    q0, q1, q2, q3, q4, q5 = cirq.LineQubit.range(6)

    # CNOT on |00> and two consecutive X gates both leave the qubits in |0>.
    circuit = cirq.Circuit([cirq.CNOT(q0, q1), cirq.X(q2), cirq.X(q2)])
    circuit += cirq.Z(q3) ** p
    circuit += [cirq.measure(q) for q in (q0, q1, q2)]
    circuit += cirq.measure(q4, q5)

    # Z raised to an even power is the identity, so every sweep point is a no-op.
    sweep = cirq.Points(p, [0, 2, 4, 6])

    sample_kwargs = dict(repetitions=10, params=sweep)
    zeros_samples = cirq.ZerosSampler().sample(circuit, **sample_kwargs)
    simulated_samples = cirq.Simulator().sample(circuit, **sample_kwargs)

    assert np.all(zeros_samples == simulated_samples)
53
54
class OnlyMeasurementsDevice(cirq.Device):
    """Test device that accepts measurement operations and rejects all others."""

    def validate_operation(self, operation: 'cirq.Operation') -> None:
        """Raise ValueError unless ``operation`` is a measurement."""
        if cirq.is_measurement(operation):
            return
        raise ValueError(f'{operation} is not a measurement and this device only measures!')
59
60
def test_validate_device():
    """ZerosSampler checks circuits against the device it was constructed with."""
    sampler = cirq.ZerosSampler(OnlyMeasurementsDevice())
    qubit_a, qubit_b, qubit_c = (cirq.NamedQubit(name) for name in ('a', 'b', 'c'))

    # A circuit containing only measurements passes device validation.
    measure_only = cirq.Circuit(cirq.measure(qubit_a), cirq.measure(qubit_b, qubit_c))
    _ = sampler.run_sweep(measure_only, None, 3)

    # Introducing a non-measurement gate triggers the device's ValueError.
    with_x_gate = cirq.Circuit(cirq.measure(qubit_a), cirq.X(qubit_b))
    with pytest.raises(ValueError, match=r'X\(b\) is not a measurement'):
        _ = sampler.run_sweep(with_x_gate, None, 3)
73