# Copyright 2020 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
import sympy

import cirq


def test_run_sweep():
    """ZerosSampler.run_sweep returns all-zero measurements of the right shape."""
    a, b, c = [cirq.NamedQubit(s) for s in ['a', 'b', 'c']]
    circuit = cirq.Circuit([cirq.measure(a)], [cirq.measure(b, c)])
    sampler = cirq.ZerosSampler()

    result = sampler.run_sweep(circuit, None, 3)

    # A `None` sweep produces exactly one Result.
    assert len(result) == 1
    assert result[0].measurements.keys() == {'a', 'b,c'}
    # Each measurement array is (repetitions, number of measured qubits).
    assert result[0].measurements['a'].shape == (3, 1)
    assert np.all(result[0].measurements['a'] == 0)
    assert result[0].measurements['b,c'].shape == (3, 2)
    assert np.all(result[0].measurements['b,c'] == 0)


def test_sample():
    """ZerosSampler.sample matches a real simulation of an always-zero circuit.

    The circuit below is built so that every measured qubit ends in |0>, so
    results from ZerosSampler must be identical to those of cirq.Simulator.
    """
    qs = cirq.LineQubit.range(6)
    # CNOT on |00> leaves |00>; two X gates on qs[2] cancel out.
    c = cirq.Circuit([cirq.CNOT(qs[0], qs[1]), cirq.X(qs[2]), cirq.X(qs[2])])
    # Z to an even power is the identity, so sweeping `p` over even integers
    # (below) cannot change the state. qs[3] is also never measured.
    c += cirq.Z(qs[3]) ** sympy.Symbol('p')
    c += [cirq.measure(q) for q in qs[0:3]]
    # qs[4] and qs[5] are untouched, so this joint measurement is always zeros.
    c += cirq.measure(qs[4], qs[5])
    params = cirq.Points(sympy.Symbol('p'), [0, 2, 4, 6])

    result1 = cirq.ZerosSampler().sample(c, repetitions=10, params=params)
    result2 = cirq.Simulator().sample(c, repetitions=10, params=params)

    assert np.all(result1 == result2)


class OnlyMeasurementsDevice(cirq.Device):
    """Test-only device that accepts measurement operations and nothing else."""

    def validate_operation(self, operation: 'cirq.Operation') -> None:
        """Reject any operation that is not a measurement.

        Args:
            operation: The operation to validate.

        Raises:
            ValueError: If `operation` is not a measurement.
        """
        if not cirq.is_measurement(operation):
            raise ValueError(f'{operation} is not a measurement and this device only measures!')


def test_validate_device():
    """ZerosSampler rejects circuits that its device does not accept."""
    device = OnlyMeasurementsDevice()
    sampler = cirq.ZerosSampler(device)

    a, b, c = [cirq.NamedQubit(s) for s in ['a', 'b', 'c']]

    # A measurement-only circuit passes validation and still yields zeros.
    circuit = cirq.Circuit(cirq.measure(a), cirq.measure(b, c))
    result = sampler.run_sweep(circuit, None, 3)
    assert len(result) == 1
    assert result[0].measurements['a'].shape == (3, 1)
    assert np.all(result[0].measurements['a'] == 0)
    assert result[0].measurements['b,c'].shape == (3, 2)
    assert np.all(result[0].measurements['b,c'] == 0)

    # Adding a non-measurement gate makes the device raise during run_sweep.
    circuit = cirq.Circuit(cirq.measure(a), cirq.X(b))
    with pytest.raises(ValueError, match=r'X\(b\) is not a measurement'):
        _ = sampler.run_sweep(circuit, None, 3)