1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import pytest
16
17import cirq
18
19
class ReturnsStr:
    """Test double that exposes its measurement key only as a plain string."""

    def _measurement_key_name_(self):
        """Supply the key through the string-valued protocol hook."""
        key_name = 'door locker'
        return key_name
23
24
class ReturnsObj:
    """Test double that exposes its measurement key only as a MeasurementKey object."""

    def _measurement_key_obj_(self):
        """Supply the key through the object-valued protocol hook."""
        key_name = 'door locker'
        return cirq.MeasurementKey(name=key_name)
28
29
@pytest.mark.parametrize('gate', [ReturnsStr(), ReturnsObj()])
def test_measurement_key_name(gate):
    """`measurement_key_name` yields the string key whichever hook the gate implements."""
    expected_name = 'door locker'

    name = cirq.measurement_key_name(gate)
    assert isinstance(name, str)
    assert name == expected_name
    assert cirq.measurement_key_obj(gate) == cirq.MeasurementKey(name=expected_name)

    # The gate provides a key, so any supplied default is ignored.
    for fallback in (None, NotImplemented, 'a'):
        assert cirq.measurement_key_name(gate, fallback) == expected_name
39
40
@pytest.mark.parametrize('gate', [ReturnsStr(), ReturnsObj()])
def test_measurement_key_obj(gate):
    """`measurement_key_obj` yields a MeasurementKey whichever hook the gate implements."""
    key = cirq.measurement_key_obj(gate)
    assert isinstance(key, cirq.MeasurementKey)
    assert key == cirq.MeasurementKey(name='door locker')
    # A MeasurementKey also compares equal to its plain string name.
    assert key == 'door locker'

    # The gate provides a key, so any supplied default is ignored.
    for fallback in (None, NotImplemented, 'a'):
        assert cirq.measurement_key_obj(gate, fallback) == 'door locker'
50
51
@pytest.mark.parametrize('key_method', [cirq.measurement_key_name, cirq.measurement_key_obj])
def test_measurement_key_no_method(key_method):
    """Values without exactly one key raise, unless a default is supplied."""

    class NoMethod:
        pass

    q0 = cirq.LineQubit(0)

    # No key at all and no default: TypeError.
    with pytest.raises(TypeError, match='no measurement keys'):
        key_method(NoMethod())

    # Two keys where a single key is requested: ValueError.
    with pytest.raises(ValueError, match='multiple measurement keys'):
        key_method(cirq.Circuit(cirq.measure(q0, key='a'), cirq.measure(q0, key='b')))

    # With a default supplied, that default comes back unchanged.
    assert key_method(NoMethod(), None) is None
    assert key_method(NoMethod(), NotImplemented) is NotImplemented
    assert key_method(NoMethod(), 'a') == 'a'

    # A non-measurement gate and operation also fall back to the default.
    assert key_method(cirq.X, None) is None
    assert key_method(cirq.X(q0), None) is None
73
74
@pytest.mark.parametrize('key_method', [cirq.measurement_key_name, cirq.measurement_key_obj])
def test_measurement_key_not_implemented_default_behavior(key_method):
    """Hooks returning NotImplemented raise without a default and yield the default otherwise."""

    class ReturnsNotImplemented:
        def _measurement_key_name_(self):
            return NotImplemented

        def _measurement_key_obj_(self):
            return NotImplemented

    subject = ReturnsNotImplemented()

    with pytest.raises(TypeError, match='NotImplemented'):
        key_method(subject)

    assert key_method(subject, None) is None
    assert key_method(subject, NotImplemented) is NotImplemented
    assert key_method(subject, 'a') == 'a'
90
91
def test_is_measurement():
    """`is_measurement` accepts measurement ops/gates and rejects other values."""
    q = cirq.NamedQubit('q')

    for measurement in (cirq.measure(q), cirq.MeasurementGate(num_qubits=1, key='b')):
        assert cirq.is_measurement(measurement)

    for other in (cirq.X(q), cirq.X, cirq.bit_flip(1)):
        assert not cirq.is_measurement(other)

    class NotImplementedOperation(cirq.Operation):
        # Minimal Operation: reports two qubits and implements nothing else.
        def with_qubits(self, *new_qubits) -> 'NotImplementedOperation':
            raise NotImplementedError()

        @property
        def qubits(self):
            return cirq.LineQubit.range(2)

    assert not cirq.is_measurement(NotImplementedOperation())
110
111
def test_measurement_without_key():
    """An object may declare itself a measurement while exposing no key."""

    class MeasurementWithoutKey:
        def _is_measurement_(self):
            return True

    keyless = MeasurementWithoutKey()

    # Requesting its key fails ...
    with pytest.raises(TypeError, match='no measurement keys'):
        _ = cirq.measurement_key_name(keyless)

    # ... but it still reports itself as a measurement.
    assert cirq.is_measurement(keyless)
121
122
def test_non_measurement_with_key():
    """`is_measurement` must honor `_is_measurement_` without probing other protocols.

    Every other protocol hook on the gate raises. If `is_measurement` consulted any of
    them, the test would fail with a message naming the hook.
    """

    class NonMeasurementGate(cirq.Gate):
        def _is_measurement_(self):
            return False

        # `raise AssertionError` rather than `assert False` so these guards survive
        # `python -O` and report which hook was wrongly called.
        def _decompose_(self, qubits):
            # `decompose` should not be called by `is_measurement`.
            raise AssertionError('is_measurement must not call _decompose_')

        def _measurement_key_name_(self):
            # `measurement_key_name` should not be called by `is_measurement`.
            raise AssertionError('is_measurement must not call _measurement_key_name_')

        def _measurement_key_names_(self):
            # `measurement_key_names` should not be called by `is_measurement`.
            raise AssertionError('is_measurement must not call _measurement_key_names_')

        def _measurement_key_obj_(self):
            # `measurement_key_obj` should not be called by `is_measurement`.
            raise AssertionError('is_measurement must not call _measurement_key_obj_')

        def _measurement_key_objs_(self):
            # `measurement_key_objs` should not be called by `is_measurement`.
            raise AssertionError('is_measurement must not call _measurement_key_objs_')

        def num_qubits(self) -> int:
            return 2  # coverage: ignore

    assert not cirq.is_measurement(NonMeasurementGate())
152
153
@pytest.mark.parametrize(
    ('key_method', 'keys'),
    [(cirq.measurement_key_names, {'a', 'b'}), (cirq.measurement_key_objs, {'c', 'd'})],
)
def test_measurement_keys(key_method, keys):
    """Plural key protocols on keyless values, measurements, circuits and custom gates."""

    class MeasurementKeysGate(cirq.Gate):
        # Each hook returns different keys, so the expected set shows which hook
        # the protocol under test consulted.
        def _measurement_key_names_(self):
            return ['a', 'b']

        def _measurement_key_objs_(self):
            return [cirq.MeasurementKey('c'), cirq.MeasurementKey('d')]

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)

    # Values that carry no measurement keys at all.
    for keyless in (None, [], cirq.X, cirq.X(a)):
        assert key_method(keyless) == set()

    assert key_method(cirq.measure(a, key='out')) == {'out'}
    two_key_circuit = cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))
    assert key_method(two_key_circuit) == {'a', '2'}

    gate = MeasurementKeysGate()
    assert key_method(gate) == keys
    assert key_method(gate.on(a)) == keys
181
182
def test_measurement_keys_allow_decompose_deprecated():
    """Passing `allow_decompose` still works but emits a deprecation warning."""
    a = cirq.LineQubit(0)
    cases = [
        (None, set()),
        ([], set()),
        (cirq.X, set()),
        (cirq.measure(a, key='out'), {'out'}),
    ]
    # Each call gets its own context so every call is checked for its own warning.
    for val, expected in cases:
        with cirq.testing.assert_deprecated(deadline="v0.14"):
            assert cirq.measurement_key_names(val, allow_decompose=False) == expected
195
196
def test_measurement_key_mapping():
    """`with_measurement_key_mapping` renames keys via `_with_measurement_key_mapping_`."""

    class MultiKeyGate:
        def __init__(self, keys):
            self._keys = set(keys)

        def _measurement_key_names_(self):
            return self._keys

        def _with_measurement_key_mapping_(self, key_map):
            # Every current key needs a replacement; extra entries in key_map are fine.
            if any(key not in key_map for key in self._keys):
                raise ValueError('missing keys')
            return MultiKeyGate([key_map[key] for key in self._keys])

    names = cirq.measurement_key_names

    assert names(MultiKeyGate([])) == set()
    assert names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert names(mkg_ab) == {'a', 'b'}

    # (key map, keys expected after remapping mkg_ab)
    remap_cases = [
        ({'a': 'c', 'b': 'd'}, {'c', 'd'}),
        ({'a': 'a', 'b': 'c'}, {'a', 'c'}),  # one key maps to itself
        ({'a': 'b', 'b': 'a'}, {'a', 'b'}),  # keys swap
        ({'a': 'c', 'b': 'd', 'x': 'y'}, {'c', 'd'}),  # unused entry ignored
    ]
    for key_map, expected in remap_cases:
        assert names(cirq.with_measurement_key_mapping(mkg_ab, key_map)) == expected

    # A map that omits one of the gate's keys is rejected by the gate itself.
    with pytest.raises(ValueError):
        cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c'})

    # A gate that does not support key mapping yields NotImplemented.
    assert cirq.with_measurement_key_mapping(cirq.X, {'a': 'c'}) is NotImplemented
232
233
def test_measurement_key_path():
    """`with_key_path` prefixes each key through the `_with_key_path_` hook."""

    class MultiKeyGate:
        def __init__(self, keys):
            # Keys arrive serialized (e.g. 'c:d:a'). Parsing them lets a path added
            # by `_with_key_path_` survive the str() round-trip below.
            self._keys = {cirq.MeasurementKey.parse_serialized(key) for key in keys}

        def _measurement_key_names_(self):
            return {str(key) for key in self._keys}

        def _with_key_path_(self, path):
            return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    mkg_cd = cirq.with_key_path(mkg_ab, ('c', 'd'))
    assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'}

    # A gate without a `_with_key_path_` hook yields NotImplemented.
    assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented
255