1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Sequence
16
17import numpy as np
18import pytest
19
20import cirq
21from cirq import ops
22from cirq.devices.noise_model import validate_all_measurements
23from cirq.testing import assert_equivalent_op_tree
24
25
def assert_equivalent_op_tree_sequence(x: Sequence[cirq.OP_TREE], y: Sequence[cirq.OP_TREE]):
    """Asserts that two sequences of op trees agree entry by entry.

    The sequences must have the same length, and each pair of corresponding
    entries must pass `assert_equivalent_op_tree`.
    """
    assert len(x) == len(y)
    for pair in zip(x, y):
        assert_equivalent_op_tree(*pair)
30
31
def test_requires_one_override():
    """A NoiseModel subclass that overrides nothing cannot be instantiated."""

    class NoOverrides(cirq.NoiseModel):
        pass

    with pytest.raises(TypeError, match='abstract'):
        _ = NoOverrides()
38
39
def test_infers_other_methods():
    """Overriding any single noisy_* method makes all three entry points usable.

    Three models are defined, each overriding exactly one of `noisy_moments`,
    `noisy_moment` and `noisy_operation`. Each model is then driven through all
    three entry points and the results are compared against the virtual-tagged
    gate that the model emits.
    """
    q = cirq.LineQubit(0)

    def virtual(op):
        # Tag an operation the same way every model below tags its output.
        return op.with_tags(ops.VirtualTag())

    def check_all_entry_points(model, gate):
        # Feed a lone H on `q` through each entry point; every one of them is
        # expected to produce `gate` on `q` carrying a VirtualTag.
        assert_equivalent_op_tree(model.noisy_operation(cirq.H(q)), virtual(gate(q)))
        assert_equivalent_op_tree(
            model.noisy_moment(cirq.Moment([cirq.H(q)]), [q]), virtual(gate(q))
        )
        assert_equivalent_op_tree_sequence(
            model.noisy_moments([cirq.Moment(), cirq.Moment([cirq.H(q)])], [q]),
            [[], virtual(gate(q))],
        )

    class NoiseModelWithNoisyMomentListMethod(cirq.NoiseModel):
        def noisy_moments(self, moments, system_qubits):
            # Non-empty moments become a tagged X on the first operation's
            # first qubit; empty moments become an empty op tree.
            return [
                virtual(cirq.X(moment.operations[0].qubits[0])) if moment.operations else []
                for moment in moments
            ]

    check_all_entry_points(NoiseModelWithNoisyMomentListMethod(), cirq.X)

    class NoiseModelWithNoisyMomentMethod(cirq.NoiseModel):
        def noisy_moment(self, moment, system_qubits):
            return list(map(virtual, cirq.Y.on_each(*moment.qubits)))

    check_all_entry_points(NoiseModelWithNoisyMomentMethod(), cirq.Y)

    class NoiseModelWithNoisyOperationMethod(cirq.NoiseModel):
        def noisy_operation(self, operation: 'cirq.Operation'):
            target = operation.qubits[0]
            return virtual(cirq.Z(target))

    check_all_entry_points(NoiseModelWithNoisyOperationMethod(), cirq.Z)
92
93
def test_no_noise():
    """cirq.NO_NOISE passes operations and moments through unchanged.

    Also checks its self-equality, its string form and its repr round trip.
    """
    qubit = cirq.LineQubit(0)
    moment = cirq.Moment([cirq.X(qubit)])
    no_noise = cirq.NO_NOISE

    assert no_noise.noisy_operation(cirq.X(qubit)) == cirq.X(qubit)
    # The very same moment object is expected back, not merely an equal one.
    assert no_noise.noisy_moment(moment, [qubit]) is moment
    assert no_noise.noisy_moments([moment, moment], [qubit]) == [moment, moment]

    assert cirq.NO_NOISE == cirq.NO_NOISE
    assert str(no_noise) == '(no noise)'
    cirq.testing.assert_equivalent_repr(no_noise)
103
104
def test_constant_qubit_noise():
    """ConstantQubitNoiseModel follows each moment with noise on all system qubits.

    Every input moment, including an empty one, is expected to come back paired
    with a moment that applies the virtual-tagged gate to each system qubit.
    Building the model from a two-qubit gate must raise a ValueError.
    """
    a, b, c = cirq.LineQubit.range(3)
    damp = cirq.amplitude_damp(0.5)
    damp_all = cirq.ConstantQubitNoiseModel(damp)

    input_moments = [cirq.Moment([cirq.X(a)]), cirq.Moment()]
    actual = damp_all.noisy_moments(input_moments, [a, b, c])

    # Same noise layer is expected after both the X moment and the empty one.
    noise_layer = cirq.Moment([damp(q).with_tags(ops.VirtualTag()) for q in (a, b, c)])
    expected = [
        [cirq.Moment([cirq.X(a)]), noise_layer],
        [cirq.Moment(), noise_layer],
    ]
    assert actual == expected
    cirq.testing.assert_equivalent_repr(damp_all)

    two_qubit_gate = cirq.CNOT ** 0.01
    with pytest.raises(ValueError, match='num_qubits'):
        _ = cirq.ConstantQubitNoiseModel(two_qubit_gate)
125
126
def test_noise_composition():
    """Commuting noise models give the same circuit in either composition order."""
    # Verify that noise models can be composed without regard to ordering, as
    # long as the noise operators commute with one another.
    qubits = cirq.LineQubit.range(3)
    a, b, c = qubits
    noise_z = cirq.ConstantQubitNoiseModel(cirq.Z)
    noise_inv_s = cirq.ConstantQubitNoiseModel(cirq.S ** -1)
    merge = cirq.optimizers.merge_single_qubit_gates_into_phased_x_z

    base_ops = [cirq.X(a), cirq.Y(b), cirq.H(c)]
    base_moments = [cirq.Moment([op]) for op in base_ops]

    # Apply one model first, then run the other over the resulting moments.
    circuit_z = cirq.Circuit(noise_z.noisy_moments(base_moments, qubits))
    circuit_s = cirq.Circuit(noise_inv_s.noisy_moments(base_moments, qubits))
    actual_zs = cirq.Circuit(noise_inv_s.noisy_moments(circuit_z.moments, qubits))
    actual_sz = cirq.Circuit(noise_z.noisy_moments(circuit_s.moments, qubits))

    # Each base operation is followed by a layer of S on every qubit.
    expected_moments = []
    for op in base_ops:
        expected_moments.append(cirq.Moment([op]))
        expected_moments.append(cirq.Moment([cirq.S(q) for q in qubits]))
    expected_circuit = cirq.Circuit(*expected_moments)

    # All of the gates will be the same, just out of order. Merging fixes this.
    for circuit in (actual_zs, actual_sz, expected_circuit):
        merge(circuit)
    assert_equivalent_op_tree(actual_zs, actual_sz)
    assert_equivalent_op_tree(actual_zs, expected_circuit)
155
156
def test_constant_qubit_noise_repr():
    """The repr of a ConstantQubitNoiseModel evaluates back to an equal model."""
    model = cirq.ConstantQubitNoiseModel(cirq.X ** 0.01)
    cirq.testing.assert_equivalent_repr(model)
159
160
def test_wrap():
    """Checks what `NoiseModel.from_noise_model_like` does with each kind of input.

    None maps to NO_NOISE, a gate is wrapped in a ConstantQubitNoiseModel, an
    existing model is returned as is, and unsupported inputs raise.
    """

    class Forget(cirq.NoiseModel):
        def noisy_operation(self, operation):
            raise NotImplementedError()

    wrap = cirq.NoiseModel.from_noise_model_like
    forget = Forget()

    assert wrap(None) is cirq.NO_NOISE

    depolarized = wrap(cirq.depolarize(0.1))
    assert depolarized == cirq.ConstantQubitNoiseModel(cirq.depolarize(0.1))

    small_z = wrap(cirq.Z ** 0.01)
    assert small_z == cirq.ConstantQubitNoiseModel(cirq.Z ** 0.01)

    # An existing noise model must come back as the same object.
    assert wrap(forget) is forget

    with pytest.raises(TypeError, match='Expected a NOISE_MODEL_LIKE'):
        _ = wrap('test')

    with pytest.raises(ValueError, match='Multi-qubit gate'):
        _ = wrap(cirq.CZ ** 0.01)
182
183
def test_gate_substitution_noise_model():
    """GateSubstitutionNoiseModel applies the substitution function during simulation.

    Simulating X**0.5 with a substitution that adds 0.1 to every XPowGate
    exponent should give the same density matrix as simulating X**0.6 directly.
    """

    def _overrotation(op):
        # Leave anything that is not an XPowGate untouched.
        if not isinstance(op.gate, cirq.XPowGate):
            return op
        boosted = cirq.XPowGate(exponent=op.gate.exponent + 0.1)
        return boosted.on(*op.qubits)

    noise = cirq.devices.noise_model.GateSubstitutionNoiseModel(_overrotation)

    q0 = cirq.LineQubit(0)
    noisy_input = cirq.Circuit(cirq.X(q0) ** 0.5, cirq.Y(q0))
    reference = cirq.Circuit(cirq.X(q0) ** 0.6, cirq.Y(q0))

    rho_noisy = cirq.final_density_matrix(noisy_input, noise=noise)
    rho_reference = cirq.final_density_matrix(reference)
    np.testing.assert_allclose(rho_noisy, rho_reference)
198
199
def test_moment_is_measurements():
    """validate_all_measurements is falsy for a gate moment, truthy for a measurement moment."""
    q0, q1 = cirq.LineQubit.range(2)
    circ = cirq.Circuit([cirq.X(q0), cirq.X(q1), cirq.measure(q0, q1, key='z')])
    gate_moment = circ[0]
    measure_moment = circ[1]
    assert not validate_all_measurements(gate_moment)
    assert validate_all_measurements(measure_moment)
205
206
def test_moment_is_measurements_mixed1():
    """A moment with a measurement on qubit 0 and a Z on qubit 1 is rejected."""
    q0, q1 = cirq.LineQubit.range(2)
    operations = [
        cirq.X(q0),
        cirq.X(q1),
        cirq.measure(q0, key='z'),
        cirq.Z(q1),
    ]
    circ = cirq.Circuit(operations)
    assert not validate_all_measurements(circ[0])
    with pytest.raises(ValueError, match=".*must be homogeneous: all measurements.*"):
        validate_all_measurements(circ[1])
221
222
def test_moment_is_measurements_mixed2():
    """A moment with a Z on qubit 0 and a measurement on qubit 1 is rejected."""
    q0, q1 = cirq.LineQubit.range(2)
    operations = [
        cirq.X(q0),
        cirq.X(q1),
        cirq.Z(q0),
        cirq.measure(q1, key='z'),
    ]
    circ = cirq.Circuit(operations)
    assert not validate_all_measurements(circ[0])
    with pytest.raises(ValueError, match=".*must be homogeneous: all measurements.*"):
        validate_all_measurements(circ[1])
237