# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the measurement-key protocols.

Covers `cirq.measurement_key_name`, `cirq.measurement_key_obj`,
`cirq.measurement_key_names`, `cirq.measurement_key_objs`,
`cirq.is_measurement`, `cirq.with_measurement_key_mapping` and
`cirq.with_key_path`.
"""

import pytest

import cirq


class ReturnsStr:
    """Test double whose only key hook is `_measurement_key_name_` (returns a str)."""

    def _measurement_key_name_(self):
        return 'door locker'


class ReturnsObj:
    """Test double whose only key hook is `_measurement_key_obj_` (returns a MeasurementKey)."""

    def _measurement_key_obj_(self):
        return cirq.MeasurementKey(name='door locker')


@pytest.mark.parametrize('gate', [ReturnsStr(), ReturnsObj()])
def test_measurement_key_name(gate):
    """`measurement_key_name` yields a str whichever single-key hook is implemented."""
    assert isinstance(cirq.measurement_key_name(gate), str)
    assert cirq.measurement_key_name(gate) == 'door locker'
    assert cirq.measurement_key_obj(gate) == cirq.MeasurementKey(name='door locker')

    # A default is ignored when the object actually provides a key.
    assert cirq.measurement_key_name(gate, None) == 'door locker'
    assert cirq.measurement_key_name(gate, NotImplemented) == 'door locker'
    assert cirq.measurement_key_name(gate, 'a') == 'door locker'


@pytest.mark.parametrize('gate', [ReturnsStr(), ReturnsObj()])
def test_measurement_key_obj(gate):
    """`measurement_key_obj` yields a MeasurementKey whichever single-key hook is implemented."""
    assert isinstance(cirq.measurement_key_obj(gate), cirq.MeasurementKey)
    assert cirq.measurement_key_obj(gate) == cirq.MeasurementKey(name='door locker')
    # MeasurementKey is expected to compare equal to its plain string name.
    assert cirq.measurement_key_obj(gate) == 'door locker'

    # A default is ignored when the object actually provides a key.
    assert cirq.measurement_key_obj(gate, None) == 'door locker'
    assert cirq.measurement_key_obj(gate, NotImplemented) == 'door locker'
    assert cirq.measurement_key_obj(gate, 'a') == 'door locker'


@pytest.mark.parametrize('key_method', [cirq.measurement_key_name, cirq.measurement_key_obj])
def test_measurement_key_no_method(key_method):
    """Objects with zero or several keys raise unless a default is supplied."""

    class NoMethod:
        pass

    with pytest.raises(TypeError, match='no measurement keys'):
        key_method(NoMethod())

    with pytest.raises(ValueError, match='multiple measurement keys'):
        key_method(
            cirq.Circuit(
                cirq.measure(cirq.LineQubit(0), key='a'), cirq.measure(cirq.LineQubit(0), key='b')
            )
        )

    # With a default, the default is returned verbatim (including by identity).
    assert key_method(NoMethod(), None) is None
    assert key_method(NoMethod(), NotImplemented) is NotImplemented
    assert key_method(NoMethod(), 'a') == 'a'

    # Non-measuring gates and operations fall back to the default.
    assert key_method(cirq.X, None) is None
    assert key_method(cirq.X(cirq.LineQubit(0)), None) is None


@pytest.mark.parametrize('key_method', [cirq.measurement_key_name, cirq.measurement_key_obj])
def test_measurement_key_not_implemented_default_behavior(key_method):
    """Hooks returning NotImplemented are treated like missing hooks."""

    class ReturnsNotImplemented:
        def _measurement_key_name_(self):
            return NotImplemented

        def _measurement_key_obj_(self):
            return NotImplemented

    with pytest.raises(TypeError, match='NotImplemented'):
        key_method(ReturnsNotImplemented())

    assert key_method(ReturnsNotImplemented(), None) is None
    assert key_method(ReturnsNotImplemented(), NotImplemented) is NotImplemented
    assert key_method(ReturnsNotImplemented(), 'a') == 'a'


def test_is_measurement():
    """`is_measurement` is true for measurement gates/ops and false otherwise."""
    q = cirq.NamedQubit('q')
    assert cirq.is_measurement(cirq.measure(q))
    assert cirq.is_measurement(cirq.MeasurementGate(num_qubits=1, key='b'))

    assert not cirq.is_measurement(cirq.X(q))
    assert not cirq.is_measurement(cirq.X)
    assert not cirq.is_measurement(cirq.bit_flip(1))

    class NotImplementedOperation(cirq.Operation):
        def with_qubits(self, *new_qubits) -> 'NotImplementedOperation':
            raise NotImplementedError()

        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

    # An operation with no measurement-related hooks is not a measurement.
    assert not cirq.is_measurement(NotImplementedOperation())


def test_measurement_without_key():
    """An object may report itself as a measurement without exposing any key."""

    class MeasurementWithoutKey:
        def _is_measurement_(self):
            return True

    with pytest.raises(TypeError, match='no measurement keys'):
        _ = cirq.measurement_key_name(MeasurementWithoutKey())

    assert cirq.is_measurement(MeasurementWithoutKey())


def test_non_measurement_with_key():
    """An explicit `_is_measurement_` answer short-circuits every other hook."""

    class NonMeasurementGate(cirq.Gate):
        def _is_measurement_(self):
            return False

        def _decompose_(self, qubits):
            # `_decompose_` should not be called by `is_measurement`
            assert False

        def _measurement_key_name_(self):
            # `measurement_key_name` should not be called by `is_measurement`
            assert False

        def _measurement_key_names_(self):
            # `measurement_key_names` should not be called by `is_measurement`
            assert False

        def _measurement_key_obj_(self):
            # `measurement_key_obj` should not be called by `is_measurement`
            assert False

        def _measurement_key_objs_(self):
            # `measurement_key_objs` should not be called by `is_measurement`
            assert False

        def num_qubits(self) -> int:
            return 2  # coverage: ignore

    assert not cirq.is_measurement(NonMeasurementGate())


@pytest.mark.parametrize(
    ('key_method', 'keys'),
    [(cirq.measurement_key_names, {'a', 'b'}), (cirq.measurement_key_objs, {'c', 'd'})],
)
def test_measurement_keys(key_method, keys):
    """The plural protocols collect every key, dispatching to the matching hook."""

    class MeasurementKeysGate(cirq.Gate):
        def _measurement_key_names_(self):
            return ['a', 'b']

        def _measurement_key_objs_(self):
            return [cirq.MeasurementKey('c'), cirq.MeasurementKey('d')]

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)
    assert key_method(None) == set()
    assert key_method([]) == set()
    assert key_method(cirq.X) == set()
    assert key_method(cirq.X(a)) == set()
    assert key_method(cirq.measure(a, key='out')) == {'out'}
    assert key_method(cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))) == {
        'a',
        '2',
    }
    # The names-hook and the objs-hook deliberately disagree, so each protocol
    # must use its own hook on the gate and on an operation built from it.
    assert key_method(MeasurementKeysGate()) == keys
    assert key_method(MeasurementKeysGate().on(a)) == keys


def test_measurement_keys_allow_decompose_deprecated():
    """Passing `allow_decompose` emits a deprecation warning but still works."""
    a = cirq.LineQubit(0)
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(None, allow_decompose=False) == set()
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names([], allow_decompose=False) == set()
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set()
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == {
            'out'
        }


def test_measurement_key_mapping():
    """`with_measurement_key_mapping` delegates to `_with_measurement_key_mapping_`."""

    class MultiKeyGate:
        def __init__(self, keys):
            self._keys = set(keys)

        def _measurement_key_names_(self):
            return self._keys

        def _with_measurement_key_mapping_(self, key_map):
            if not all(key in key_map for key in self._keys):
                raise ValueError('missing keys')
            return MultiKeyGate([key_map[key] for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    mkg_cd = cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c', 'b': 'd'})
    assert cirq.measurement_key_names(mkg_cd) == {'c', 'd'}

    mkg_ac = cirq.with_measurement_key_mapping(mkg_ab, {'a': 'a', 'b': 'c'})
    assert cirq.measurement_key_names(mkg_ac) == {'a', 'c'}

    mkg_ba = cirq.with_measurement_key_mapping(mkg_ab, {'a': 'b', 'b': 'a'})
    assert cirq.measurement_key_names(mkg_ba) == {'a', 'b'}

    # The test double raises when a key it owns is absent from the map.
    with pytest.raises(ValueError):
        cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c'})

    # Objects without the hook report NotImplemented.
    assert cirq.with_measurement_key_mapping(cirq.X, {'a': 'c'}) is NotImplemented

    # Extra entries in the map that match no key are ignored.
    mkg_cdx = cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c', 'b': 'd', 'x': 'y'})
    assert cirq.measurement_key_names(mkg_cdx) == {'c', 'd'}


def test_measurement_key_path():
    """`with_key_path` delegates to `_with_key_path_` and prefixes every key."""

    class MultiKeyGate:
        def __init__(self, keys):
            self._keys = {cirq.MeasurementKey.parse_serialized(key) for key in keys}

        def _measurement_key_names_(self):
            return {str(key) for key in self._keys}

        def _with_key_path_(self, path):
            return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    # Path components are serialized as a ':'-joined prefix of each key name.
    mkg_cd = cirq.with_key_path(mkg_ab, ('c', 'd'))
    assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'}

    # Objects without the hook report NotImplemented.
    assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented