1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from datetime import timedelta
16
17import numpy as np
18import pytest
19
20import cirq
21import cirq.ion as ci
22
23
def ion_device(chain_length: int, use_timedelta=False) -> ci.IonDevice:
    """Build an IonDevice over `chain_length` LineQubits for use in these tests.

    Durations are: measurement 100us, two-qubit gates 200us, one-qubit gates 10us.
    When `use_timedelta` is true, the durations are passed as `datetime.timedelta`
    values instead of `cirq.Duration` values.
    """
    # One microsecond, expressed in whichever duration type was requested.
    if use_timedelta:
        micros = timedelta(microseconds=1)
    else:
        micros = 1000 * cirq.Duration(nanos=1)
    return ci.IonDevice(  # type: ignore
        measurement_duration=100 * micros,  # type: ignore
        twoq_gates_duration=200 * micros,  # type: ignore
        oneq_gates_duration=10 * micros,  # type: ignore
        qubits=cirq.LineQubit.range(chain_length),
    )
32
33
class NotImplementedOperation(cirq.Operation):
    """Operation stub whose members all raise NotImplementedError.

    Passed to `validate_operation` to check that such an operation is rejected.
    """

    def with_qubits(self, *new_qubits) -> 'NotImplementedOperation':
        raise NotImplementedError

    @property
    def qubits(self):
        raise NotImplementedError
41
42
def test_init():
    """A Duration-configured device exposes its qubits and per-operation durations."""
    device = ion_device(3)
    micros = 1000 * cirq.Duration(nanos=1)
    q0, q1, q2 = cirq.LineQubit.range(3)

    assert device.qubits == {q0, q1, q2}
    assert device.duration_of(cirq.Z(q0)) == 10 * micros
    # Measurement time does not depend on how many qubits are measured.
    assert device.duration_of(cirq.measure(q0)) == 100 * micros
    assert device.duration_of(cirq.measure(q0, q1)) == 100 * micros
    assert device.duration_of(cirq.ops.XX(q0, q1)) == 200 * micros
    # A gate the device does not recognise raises ValueError.
    with pytest.raises(ValueError):
        _ = device.duration_of(cirq.SingleQubitGate().on(q0))
57
58
def test_init_timedelta():
    """A timedelta-configured device reports the same durations as cirq.Duration values."""
    device = ion_device(3, use_timedelta=True)
    micros = 1000 * cirq.Duration(nanos=1)
    q0, q1, q2 = cirq.LineQubit.range(3)

    assert device.qubits == {q0, q1, q2}
    assert device.duration_of(cirq.Z(q0)) == 10 * micros
    # Measurement time does not depend on how many qubits are measured.
    assert device.duration_of(cirq.measure(q0)) == 100 * micros
    assert device.duration_of(cirq.measure(q0, q1)) == 100 * micros
    assert device.duration_of(cirq.ops.XX(q0, q1)) == 200 * micros
    # A gate the device does not recognise raises ValueError.
    with pytest.raises(ValueError):
        _ = device.duration_of(cirq.SingleQubitGate().on(q0))
73
74
def test_decomposition():
    """H decomposes into native rotations; decomposed circuits validate and stay equivalent."""
    device = ion_device(3)
    q0, q1 = cirq.LineQubit.range(2)

    expected_h = [
        cirq.rx(np.pi * 1.0).on(q0),
        cirq.ry(np.pi * -0.5).on(q0),
    ]
    assert device.decompose_operation(cirq.H(q0)) == expected_h

    circuit = cirq.Circuit()
    circuit.append([cirq.X(q0), cirq.CNOT(q0, q1)])
    ion_circuit = device.decompose_circuit(circuit)
    # The decomposed circuit must be accepted by the device...
    device.validate_circuit(ion_circuit)
    # ...and implement the same unitary (up to tolerance) as the original.
    cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
        circuit, ion_circuit, atol=1e-6
    )
90
91
def test_repr():
    """repr() spells out each duration followed by the qubit list."""
    expected = (
        "IonDevice("
        "measurement_duration=cirq.Duration(micros=100), "
        "twoq_gates_duration=cirq.Duration(micros=200), "
        # NOTE(review): no comma precedes `qubits=` here; this pins whatever
        # IonDevice.__repr__ currently emits -- confirm whether that is intended.
        "oneq_gates_duration=cirq.Duration(micros=10) "
        "qubits=[cirq.LineQubit(0), cirq.LineQubit(1), "
        "cirq.LineQubit(2)])"
    )
    assert repr(ion_device(3)) == expected
103
104
def test_validate_measurement_non_adjacent_qubits_ok():
    """A measurement spanning qubits that are not chain neighbours is accepted."""
    d = ion_device(3)

    # Qubits 0 and 2 are separated by qubit 1, so they are genuinely non-adjacent
    # (the previous version measured 0 and 1, which are adjacent).
    d.validate_operation(
        cirq.GateOperation(cirq.MeasurementGate(2, 'key'), (cirq.LineQubit(0), cirq.LineQubit(2)))
    )
111
112
def test_validate_operation_existing_qubits():
    """validate_operation accepts ops on device qubits and rejects ops on other qubits."""
    d = ion_device(3)

    d.validate_operation(cirq.GateOperation(cirq.XX, (cirq.LineQubit(0), cirq.LineQubit(1))))
    d.validate_operation(cirq.Z(cirq.LineQubit(0)))
    d.validate_operation(
        cirq.PhasedXPowGate(phase_exponent=0.75, exponent=0.25, global_shift=0.1).on(
            cirq.LineQubit(1)
        )
    )

    # XX is accepted above, so this rejection can only come from the off-device
    # qubit. This isolates the qubit-existence check for two-qubit operations.
    with pytest.raises(ValueError):
        d.validate_operation(cirq.XX(cirq.LineQubit(0), cirq.LineQubit(-1)))
    # NOTE(review): CZ may also be rejected as an unsupported gate, so this case
    # alone does not prove the qubit check fires.
    with pytest.raises(ValueError):
        d.validate_operation(cirq.CZ(cirq.LineQubit(0), cirq.LineQubit(-1)))
    with pytest.raises(ValueError):
        d.validate_operation(cirq.Z(cirq.LineQubit(-1)))
    # NOTE(review): cirq likely raises on duplicate qubits while constructing the
    # operation, before validate_operation is reached -- confirm.
    with pytest.raises(ValueError):
        d.validate_operation(cirq.CZ(cirq.LineQubit(1), cirq.LineQubit(1)))
    # Qubits that are not LineQubits are rejected.
    with pytest.raises(ValueError):
        d.validate_operation(cirq.X(cirq.NamedQubit("q1")))
132
133
def test_validate_operation_supported_gate():
    """Supported gates pass validation; unknown gates and non-gate operations do not."""
    device = ion_device(3)
    q0 = cirq.LineQubit(0)

    class MyGate(cirq.Gate):
        """A one-qubit gate the device has no support for."""

        def num_qubits(self):
            return 1

    device.validate_operation(cirq.GateOperation(cirq.Z, [q0]))

    # Calls num_qubits directly (presumably so the stub method is covered).
    assert MyGate().num_qubits() == 1
    with pytest.raises(ValueError):
        device.validate_operation(cirq.GateOperation(MyGate(), [q0]))
    with pytest.raises(ValueError):
        device.validate_operation(NotImplementedOperation())
148
149
def test_can_add_operation_into_moment():
    """XX ops overlapping an existing XX cannot share its moment; disjoint ops can."""
    device = ion_device(3)
    q0, q1, q2, q3 = cirq.LineQubit.range(4)

    xx_circuit = cirq.Circuit()
    xx_circuit.append(cirq.XX(q0, q1))
    for moment in xx_circuit:
        # XX ops sharing a qubit with XX(q0, q1) are refused.
        assert not device.can_add_operation_into_moment(cirq.XX(q2, q0), moment)
        assert not device.can_add_operation_into_moment(cirq.XX(q1, q2), moment)
        # Disjoint XX and single-qubit ops are allowed.
        assert device.can_add_operation_into_moment(cirq.XX(q2, q3), moment)
        assert device.can_add_operation_into_moment(cirq.Z(q3), moment)

    single_x_moment = cirq.Circuit([cirq.X(q0)])[0]
    assert device.can_add_operation_into_moment(cirq.XX(q1, q2), single_x_moment)
165
166
def test_ion_device_eq():
    """Devices built with the same settings are equal; different chain lengths are not."""
    eq = cirq.testing.EqualsTester()
    # Durations supplied as datetime.timedelta should produce a device equal to
    # one built from the equivalent cirq.Duration values.
    eq.make_equality_group(lambda: ion_device(3), lambda: ion_device(3, use_timedelta=True))
    eq.make_equality_group(lambda: ion_device(4))
171
172
def test_validate_circuit_repeat_measurement_keys():
    """A circuit that reuses a measurement key fails validation."""
    device = ion_device(3)
    q0, q1 = cirq.LineQubit.range(2)

    circuit = cirq.Circuit()
    circuit.append([cirq.measure(q0, key='a'), cirq.measure(q1, key='a')])

    with pytest.raises(ValueError, match='Measurement key a repeated'):
        device.validate_circuit(circuit)
183
184
def test_ion_device_str():
    """str() draws the qubit chain as indices joined by horizontal lines."""
    expected = "0───1───2"
    rendered = str(ion_device(3)).strip()
    assert rendered == expected
192
193
def test_at():
    """at() returns the LineQubit at a position, or None outside the chain."""
    d = ion_device(3)
    # Both ends of the valid range [0, 2] must fall back to None.
    assert d.at(-1) is None
    assert d.at(3) is None
    assert d.at(0) == cirq.LineQubit(0)
    assert d.at(2) == cirq.LineQubit(2)
199
200
def test_qubit_set():
    """qubit_set() is the frozenset of the device's LineQubits."""
    expected = frozenset(cirq.LineQubit(i) for i in range(3))
    assert ion_device(3).qubit_set() == expected
203
204
def test_qid_pairs():
    """qid_pairs() has one entry per unordered pair of qubits."""
    num_qubits = 10
    # 10 choose 2 == 45.
    assert len(ion_device(num_qubits).qid_pairs()) == num_qubits * (num_qubits - 1) // 2
207