1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import pytest
15
16import cirq
17from cirq.testing.devices import ValidatingTestDevice
18
19
def test_validating_types_and_qubits():
    """Check the qubit-type, qubit-membership and gate-type validation errors.

    The device under test allows only ``cirq.GridQubit`` qubits, only
    ``cirq.XPowGate`` gates, and contains the single qubit ``GridQubit(0, 0)``.
    """
    on_device = cirq.GridQubit(0, 0)
    device = ValidatingTestDevice(
        name='test',
        qubits={on_device},
        allowed_gates=(cirq.XPowGate,),
        allowed_qubit_types=(cirq.GridQubit,),
    )

    # An allowed gate on the device's only qubit must validate without raising.
    device.validate_operation(cirq.X(on_device))

    # Same gate, but on a qubit whose type is not in allowed_qubit_types.
    wrong_qubit_type = cirq.X(cirq.NamedQubit("a"))
    with pytest.raises(ValueError, match="Unsupported qubit type"):
        device.validate_operation(wrong_qubit_type)

    # Allowed qubit type, but this particular qubit is not in the device's set.
    off_device = cirq.X(cirq.GridQubit(1, 0))
    with pytest.raises(ValueError, match="Qubit not on device"):
        device.validate_operation(off_device)

    # Valid qubit, but the Y gate is not among the allowed gates.
    wrong_gate = cirq.Y(on_device)
    with pytest.raises(ValueError, match="Unsupported gate type"):
        device.validate_operation(wrong_gate)
38
39
def test_validating_locality():
    """Exercise ``validate_locality=True`` on a device built from a 3x3 grid.

    Also checks that constructing a locality-validating device whose allowed
    qubit types do not include ``cirq.GridQubit`` is rejected.
    """
    grid_qubits = cirq.GridQubit.rect(3, 3)
    gates = (cirq.CZPowGate, cirq.MeasurementGate)
    device = ValidatingTestDevice(
        name='test',
        qubits=set(grid_qubits),
        allowed_gates=gates,
        allowed_qubit_types=(cirq.GridQubit,),
        validate_locality=True,
    )

    origin = cirq.GridQubit(0, 0)
    next_to_origin = cirq.GridQubit(0, 1)
    two_away = cirq.GridQubit(0, 2)

    # Expected to validate: a CZ between (0, 0) and (0, 1), and a measurement
    # that spans (0, 0) and (0, 2).
    device.validate_operation(cirq.CZ(origin, next_to_origin))
    device.validate_operation(cirq.measure(origin, two_away))

    # The same (0, 0)-(0, 2) pair is expected to be rejected for a CZ.
    non_local_cz = cirq.CZ(origin, two_away)
    with pytest.raises(ValueError, match="Non-local interaction"):
        device.validate_operation(non_local_cz)

    # Locality validation requested while only NamedQubit is an allowed type.
    with pytest.raises(ValueError, match="GridQubit must be an allowed qubit type"):
        ValidatingTestDevice(
            name='test',
            qubits=set(grid_qubits),
            allowed_gates=gates,
            allowed_qubit_types=(cirq.NamedQubit,),
            validate_locality=True,
        )
63
64
def test_autodecompose():
    """Check the effect of ``auto_decompose_gates`` when building a circuit.

    With ``cirq.CCXPowGate`` listed in ``auto_decompose_gates``, a circuit
    built from a CCX on the device is expected to have the same moments as a
    circuit built from ``cirq.decompose`` of that CCX. With an empty
    ``auto_decompose_gates``, building the same circuit is expected to raise
    ``ValueError``.
    """
    a, b, c = cirq.LineQubit.range(3)
    # Both devices share the same gate set; CCXPowGate is deliberately absent.
    allowed_gates = (
        cirq.XPowGate,
        cirq.ZPowGate,
        cirq.CZPowGate,
        cirq.YPowGate,
        cirq.MeasurementGate,
    )

    decomposing_dev = ValidatingTestDevice(
        allowed_qubit_types=(cirq.LineQubit,),
        allowed_gates=allowed_gates,
        qubits=set(cirq.LineQubit.range(3)),
        name='test',
        validate_locality=False,
        auto_decompose_gates=(cirq.CCXPowGate,),
    )

    circuit = cirq.Circuit(cirq.CCX(a, b, c), device=decomposing_dev)
    decomposed = cirq.decompose(cirq.CCX(a, b, c))
    assert circuit.moments == cirq.Circuit(decomposed).moments

    # Build the second device OUTSIDE pytest.raises: a ValueError from the
    # constructor must fail the test, not be mistaken for the expected error.
    strict_dev = ValidatingTestDevice(
        allowed_qubit_types=(cirq.LineQubit,),
        allowed_gates=allowed_gates,
        qubits=set(cirq.LineQubit.range(3)),
        name='test',
        validate_locality=False,
        auto_decompose_gates=tuple(),
    )

    # Only the statement expected to raise lives inside the context manager.
    with pytest.raises(ValueError, match="Unsupported gate type: cirq.TOFFOLI"):
        cirq.Circuit(cirq.CCX(a, b, c), device=strict_dev)
104
105
def test_repr():
    """The ``repr`` of a device built with ``name='test'`` should be that name."""
    device_name = 'test'
    grid = set(cirq.GridQubit.rect(3, 3))
    device = ValidatingTestDevice(
        name=device_name,
        qubits=grid,
        allowed_gates=(cirq.CZPowGate, cirq.MeasurementGate),
        allowed_qubit_types=(cirq.GridQubit,),
        validate_locality=True,
    )
    assert repr(device) == device_name
115
116
def test_defaults():
    """Check the attributes of a device constructed with only ``qubits``."""
    only_qubit = cirq.GridQubit(0, 0)
    device = ValidatingTestDevice(qubits={only_qubit})

    # With no name given, the repr is expected to be the class name.
    assert repr(device) == 'ValidatingTestDevice'
    # Expected default qubit types: GridQubit only.
    assert device.allowed_qubit_types == (cirq.GridQubit,)
    # Locality validation and auto-decomposition are expected to be falsy.
    assert not device.validate_locality
    assert not device.auto_decompose_gates
123