1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import Optional, Dict, Sequence, Union, cast
15import random
16
17import numpy as np
18import pytest
19
20import cirq
21import cirq.testing
22
23
def test_random_circuit_errors():
    """random_circuit raises ValueError for each kind of invalid argument.

    Covers a negative op density, an empty gate domain, zero qubits (given as
    an int or as an empty sequence), and a gate domain whose only gates need
    more qubits than the circuit has.
    """
    invalid_calls = [
        ('but was -1', dict(qubits=5, n_moments=5, op_density=-1)),
        ('empty', dict(qubits=5, n_moments=5, op_density=0.5, gate_domain={})),
        ('At least one', dict(qubits=0, n_moments=5, op_density=0.5)),
        ('At least one', dict(qubits=(), n_moments=5, op_density=0.5)),
        (
            'After removing gates that act on less than 1 qubits, gate_domain had no gates',
            dict(qubits=1, n_moments=5, op_density=0.5, gate_domain={cirq.CNOT: 2}),
        ),
    ]
    for expected_message, kwargs in invalid_calls:
        with pytest.raises(ValueError, match=expected_message):
            _ = cirq.testing.random_circuit(**kwargs)
44
45
def _cases_for_random_circuit():
    """Yield 10 random parameter tuples for ``test_random_circuit``.

    Each tuple is ``(n_qubits, n_moments, op_density, gate_domain, pass_qubits)``.
    ``gate_domain`` is either ``None`` (meaning the default domain) or a random,
    non-empty subset of ``cirq.testing.DEFAULT_GATE_DOMAIN``.
    """
    domain_items = tuple(cirq.testing.DEFAULT_GATE_DOMAIN.items())
    emitted = 0
    while emitted < 10:
        n_qubits = random.randint(1, 20)
        n_moments = random.randint(1, 10)
        op_density = random.random()
        gate_domain = None
        if random.randint(0, 1):
            subset_size = random.randint(1, len(domain_items))
            gate_domain = dict(random.sample(domain_items, subset_size))
            # A sampled domain may contain only gates that act on more qubits
            # than the circuit has; random_circuit rejects such a domain, so
            # draw a fresh case instead.
            if not any(arity <= n_qubits for arity in gate_domain.values()):
                # coverage: ignore
                continue
        pass_qubits = random.choice((True, False))
        yield n_qubits, n_moments, op_density, gate_domain, pass_qubits
        emitted += 1
70
71
@pytest.mark.parametrize(
    'n_qubits,n_moments,op_density,gate_domain,pass_qubits', _cases_for_random_circuit()
)
def test_random_circuit(
    n_qubits: Union[int, Sequence[cirq.Qid]],
    n_moments: int,
    op_density: float,
    gate_domain: Optional[Dict[cirq.Gate, int]],
    pass_qubits: bool,
):
    """random_circuit respects the qubit budget, moment count and gate domain.

    Args:
        n_qubits: Number of qubits. ``_cases_for_random_circuit`` always
            supplies an int here.
        n_moments: Expected number of moments in the circuit.
        op_density: Operation density forwarded to ``random_circuit``.
        gate_domain: Gate domain forwarded to ``random_circuit``; ``None``
            means ``cirq.testing.DEFAULT_GATE_DOMAIN``.
        pass_qubits: If True, pass explicit ``LineQubit``s instead of the int.
    """
    qubit_set = cirq.LineQubit.range(n_qubits)
    qubit_arg = qubit_set if pass_qubits else n_qubits
    circuit = cirq.testing.random_circuit(qubit_arg, n_moments, op_density, gate_domain)
    used_qubits = circuit.all_qubits()
    if pass_qubits:
        # Explicit qubits were given: only those may appear in the circuit.
        assert used_qubits.issubset(qubit_set)
    # Whether the qubits were passed as a list or as an int, the circuit may
    # never touch more distinct qubits than were requested.
    assert len(used_qubits) <= len(qubit_set)
    assert len(circuit) == n_moments
    allowed_gates = cirq.testing.DEFAULT_GATE_DOMAIN if gate_domain is None else gate_domain
    used_gates = {cast(cirq.GateOperation, op).gate for op in circuit.all_operations()}
    assert used_gates.issubset(allowed_gates)
93
94
# random.randint includes its upper bound, while np.random.RandomState only
# accepts seeds in [0, 2**32 - 1]; drawing up to 2**32 could yield an invalid seed.
@pytest.mark.parametrize('seed', [random.randint(0, 2 ** 32 - 1) for _ in range(10)])
def test_random_circuit_reproducible_with_seed(seed):
    """Equal seeds give equal circuits, whether passed raw or as a RandomState.

    Four circuits are built: two seeded with the int directly and two with a
    fresh ``np.random.RandomState(seed)``. All four must compare equal.
    """
    wrappers = (lambda s: s, np.random.RandomState)
    circuits = [
        cirq.testing.random_circuit(
            qubits=10, n_moments=10, op_density=0.7, random_state=wrapper(seed)
        )
        for wrapper in wrappers
        for _ in range(2)
    ]
    eq = cirq.testing.EqualsTester()
    eq.add_equality_group(*circuits)
107
108
def test_random_circuit_not_expected_number_of_qubits():
    """Full op density does not guarantee every qubit is used.

    With three qubits, a single moment, and CNOT as the only available gate,
    exactly two qubits end up being acted on.
    """
    cnot_only_domain = {cirq.CNOT: 2}
    circuit = cirq.testing.random_circuit(
        qubits=3, n_moments=1, op_density=1.0, gate_domain=cnot_only_domain
    )
    touched_qubits = circuit.all_qubits()
    assert len(touched_qubits) == 2
117
118
def test_random_circuit_reproducible_between_runs():
    """Pins the exact circuit random_circuit produces for a fixed integer seed.

    The whole text diagram is compared verbatim, so any change to how
    random_circuit consumes its random state will show up here.
    """
    # Positional arguments: 5 qubits, 8 moments, op density 0.5.
    circuit = cirq.testing.random_circuit(5, 8, 0.5, random_state=77)
    expected_diagram = """
                  ┌──┐
0: ────────────────S─────iSwap───────Y───X───
                         │
1: ───────────Y──────────iSwap───────Y───────

2: ─────────────────X────T───────────S───S───
                    │
3: ───────@────────S┼────H───────────────Z───
          │         │
4: ───────@─────────@────────────────────X───
                  └──┘
    """
    cirq.testing.assert_has_diagram(circuit, expected_diagram)
135
136
137def test_random_two_qubit_circuit_with_czs():
138    num_czs = lambda circuit: len(
139        [o for o in circuit.all_operations() if isinstance(o.gate, cirq.CZPowGate)]
140    )
141
142    c = cirq.testing.random_two_qubit_circuit_with_czs()
143    assert num_czs(c) == 3
144    assert {cirq.NamedQubit('q0'), cirq.NamedQubit('q1')} == c.all_qubits()
145    assert all(isinstance(op.gate, cirq.PhasedXPowGate) for op in c[0].operations)
146    assert c[0].qubits == c.all_qubits()
147
148    c = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=0)
149    assert num_czs(c) == 0
150    assert {cirq.NamedQubit('q0'), cirq.NamedQubit('q1')} == c.all_qubits()
151    assert all(isinstance(op.gate, cirq.PhasedXPowGate) for op in c[0].operations)
152    assert c[0].qubits == c.all_qubits()
153
154    a, b = cirq.LineQubit.range(2)
155    c = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=1, q1=b)
156    assert num_czs(c) == 1
157    assert {b, cirq.NamedQubit('q0')} == c.all_qubits()
158    assert all(isinstance(op.gate, cirq.PhasedXPowGate) for op in c[0].operations)
159    assert c[0].qubits == c.all_qubits()
160
161    c = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=2, q0=a)
162    assert num_czs(c) == 2
163    assert {a, cirq.NamedQubit('q1')} == c.all_qubits()
164    assert all(isinstance(op.gate, cirq.PhasedXPowGate) for op in c[0].operations)
165    assert c[0].qubits == c.all_qubits()
166
167    c = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=3, q0=a, q1=b)
168    assert num_czs(c) == 3
169    assert c.all_qubits() == {a, b}
170    assert all(isinstance(op.gate, cirq.PhasedXPowGate) for op in c[0].operations)
171    assert c[0].qubits == c.all_qubits()
172
173    seed = 77
174
175    c1 = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=4, q0=a, q1=b, random_state=seed)
176    assert num_czs(c1) == 4
177    assert c1.all_qubits() == {a, b}
178
179    c2 = cirq.testing.random_two_qubit_circuit_with_czs(num_czs=4, q0=a, q1=b, random_state=seed)
180
181    assert c1 == c2
182