1# Copyright 2021 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import functools
15import itertools
16import math
17import random
18
19import numpy as np
20import pytest
21import sympy.parsing.sympy_parser as sympy_parser
22
23import cirq
24import cirq.ops.boolean_hamiltonian as bh
25
26
@pytest.mark.parametrize(
    'boolean_str',
    [
        'x0',
        '~x0',
        'x0 ^ x1',
        'x0 & x1',
        'x0 | x1',
        'x0 & x1 & x2',
        'x0 & x1 & ~x2',
        'x0 & ~x1 & x2',
        'x0 & ~x1 & ~x2',
        '~x0 & x1 & x2',
        '~x0 & x1 & ~x2',
        '~x0 & ~x1 & x2',
        '~x0 & ~x1 & ~x2',
        'x0 ^ x1 ^ x2',
        'x0 | (x1 & x2)',
        'x0 & (x1 | x2)',
        '(x0 ^ x1 ^ x2) | (x2 ^ x3 ^ x4)',
        '(x0 ^ x2 ^ x4) | (x1 ^ x2 ^ x3)',
        'x0 & x1 & (x2 | x3)',
        'x0 & ~x2',
        '~x0 & x2',
        'x2 & ~x0',
        '~x2 & x0',
        '(x2 | x1) ^ x0',
    ],
)
def test_circuit(boolean_str):
    """Check BooleanHamiltonian's output phases against a SymPy truth table.

    The expression is evaluated by SymPy on every assignment of 0/1 to its
    variables. The same expression is then applied as a BooleanHamiltonian to
    a uniform superposition. The thresholded phase of each output amplitude
    must reproduce that truth table.
    """
    boolean_expr = sympy_parser.parse_expr(boolean_str)
    # cirq.parameter_names returns a set, whose iteration order is not fixed
    # across interpreter runs. Sort it so the qubit order, and therefore the
    # amplitude order, is reproducible.
    var_names = sorted(cirq.parameter_names(boolean_expr))
    qubits = [cirq.NamedQubit(name) for name in var_names]
    n = len(var_names)

    # Truth table via SymPy. itertools.product varies the last variable
    # fastest, which matches the state-vector ordering for qubit_order=qubits
    # below.
    expected = [
        bool(boolean_expr.subs(dict(zip(var_names, binary_inputs))))
        for binary_inputs in itertools.product([0, 1], repeat=n)
    ]

    # Build the circuit: uniform superposition, then the Hamiltonian gate.
    circuit = cirq.Circuit()
    circuit.append(cirq.H.on_each(*qubits))

    hamiltonian_gate = cirq.BooleanHamiltonian(
        {q.name: q for q in qubits}, [boolean_str], 0.1 * math.pi
    )

    assert hamiltonian_gate.num_qubits() == n

    circuit.append(hamiltonian_gate)

    phi = cirq.Simulator().simulate(circuit, qubit_order=qubits, initial_state=0).state_vector()
    # Recover each amplitude's phase with arctan2 and threshold it to a bool.
    actual = np.arctan2(phi.real, phi.imag) - math.pi / 2.0 > 0.0

    np.testing.assert_array_equal(actual, expected)
88
89
def test_with_custom_names():
    """Decomposition follows the qubits after with_qubits renames them.

    Also checks that supplying the wrong number of replacement qubits is
    rejected with a ValueError.
    """
    a_qubit, b_qubit, new_a, new_b = cirq.LineQubit.range(4)
    op = cirq.BooleanHamiltonian({'a': a_qubit, 'b': b_qubit}, ['a'], 0.1)
    expected_gate = cirq.Rz(rads=-0.05)

    assert cirq.decompose(op) == [expected_gate.on(a_qubit)]

    moved = op.with_qubits(new_a, new_b)
    assert cirq.decompose(moved) == [expected_gate.on(new_a)]

    # Only one replacement qubit for a two-qubit operation.
    with pytest.raises(ValueError, match='Length of replacement qubits must be the same'):
        op.with_qubits(new_a)
104
105
@pytest.mark.parametrize(
    'n_bits,expected_hs',
    [
        (1, [(), (0,)]),
        (2, [(), (0,), (0, 1), (1,)]),
        (3, [(), (0,), (0, 1), (1,), (1, 2), (0, 1, 2), (0, 2), (2,)]),
    ],
)
def test_gray_code_sorting(n_bits, expected_hs):
    """Shuffled index subsets must sort into `expected_hs` under the comparator.

    Every subset of range(n_bits) is shuffled with several fixed seeds and
    sorted with bh._gray_code_comparator. The result must equal the expected
    order regardless of the shuffle.
    """
    # Subset for each x in [0, 2**n_bits): the ascending indices of x's set bits.
    hs_template = [tuple(i for i in range(n_bits) if (x >> i) & 1) for x in range(2**n_bits)]

    for seed in range(10):
        # A private RNG gives the same shuffle as seeding the global one, but
        # does not leak random-module state into other tests.
        rng = random.Random(seed)

        hs = hs_template.copy()
        rng.shuffle(hs)

        sorted_hs = sorted(hs, key=functools.cmp_to_key(bh._gray_code_comparator))

        # Compare as plain lists. The tuples have different lengths, and NumPy
        # (>= 1.24) refuses to build an array from such ragged sequences, so
        # np.testing.assert_array_equal would raise instead of comparing.
        assert sorted_hs == expected_hs
134
135
@pytest.mark.parametrize(
    'seq_a,seq_b,expected',
    [
        ((), (), 0),
        ((), (0,), -1),
        ((0,), (), 1),
        ((0,), (0,), 0),
    ],
)
def test_gray_code_comparison(seq_a, seq_b, expected):
    """Pin the comparator's return value on empty and single-index tuples."""
    result = bh._gray_code_comparator(seq_a, seq_b)
    assert result == expected
147
148
@pytest.mark.parametrize(
    'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
    [
        # Empty inputs are left unchanged.
        ([], False, False, []),
        ([], True, False, []),
        # A lone CNOT is left unchanged.
        ([(0, 1)], False, False, [(0, 1)]),
        ([(0, 1)], True, False, [(0, 1)]),
        # Two identical CNOTs cancel out.
        ([(0, 1), (0, 1)], False, True, []),
        ([(0, 1), (0, 1)], True, True, []),
        # They still cancel when a commuting CNOT sits between them.
        ([(0, 1), (2, 1), (0, 1)], False, True, [(2, 1)]),
        ([(0, 1), (0, 2), (0, 1)], True, True, [(0, 2)]),
        # The CNOT in between must share the same target (resp. control).
        ([(0, 1), (0, 2), (0, 1)], False, False, [(0, 1), (0, 2), (0, 1)]),
        ([(0, 1), (2, 1), (0, 1)], True, False, [(0, 1), (2, 1), (0, 1)]),
        # Would be simplifiable, but violates the CNOT ordering assumption.
        ([(0, 1), (2, 3), (0, 1)], False, False, [(0, 1), (2, 3), (0, 1)]),
    ],
)
def test_simplify_commuting_cnots(
    input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
):
    """Check the (simplified flag, output CNOT list) returned for each case."""
    result = bh._simplify_commuting_cnots(input_cnots, input_flip_control_and_target)
    simplified, output_cnots = result
    assert simplified == expected_simplified
    assert output_cnots == expected_output_cnots
179
180
@pytest.mark.parametrize(
    'input_cnots,input_flip_control_and_target,expected_simplified,expected_output_cnots',
    [
        # Empty inputs don't get simplified.
        ([], False, False, []),
        ([], True, False, []),
        # Single CNOTs don't get simplified.
        ([(0, 1)], False, False, [(0, 1)]),
        ([(0, 1)], True, False, [(0, 1)]),
        # Simplify according to equation 11 of [4].
        ([(2, 1), (2, 0), (1, 0)], False, True, [(1, 0), (2, 1)]),
        ([(1, 2), (0, 2), (0, 1)], True, True, [(0, 1), (1, 2)]),
        # Same as above, but with intervening CNOTs that prevent simplifications.
        ([(2, 1), (2, 0), (100, 101), (1, 0)], False, False, [(2, 1), (2, 0), (100, 101), (1, 0)]),
        ([(2, 1), (100, 101), (2, 0), (1, 0)], False, False, [(2, 1), (100, 101), (2, 0), (1, 0)]),
        # swap (2, 1) and (1, 0) around (2, 0)
        ([(2, 1), (2, 3), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
        ([(2, 1), (2, 0), (2, 3), (3, 0), (1, 0)], False, True, [(1, 0), (2, 1), (2, 3), (3, 0)]),
        ([(2, 3), (2, 1), (2, 0), (3, 0), (1, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
        ([(2, 1), (2, 3), (3, 0), (2, 0), (1, 0)], False, True, [(2, 3), (3, 0), (1, 0), (2, 1)]),
        ([(2, 1), (2, 3), (2, 0), (1, 0), (3, 0)], False, True, [(2, 3), (1, 0), (2, 1), (3, 0)]),
    ],
)
def test_simplify_cnots_triplets(
    input_cnots, input_flip_control_and_target, expected_simplified, expected_output_cnots
):
    """Check the triplet simplification's result and that it preserves the unitary.

    Beyond comparing against the expected CNOT list, the input and output lists
    are each turned into a circuit. The two circuits' unitaries must match.
    """
    actual_simplified, actual_output_cnots = bh._simplify_cnots_triplets(
        input_cnots, input_flip_control_and_target
    )
    assert actual_simplified == expected_simplified
    assert actual_output_cnots == expected_output_cnots

    # Check that the unitaries are the same.
    qubit_ids = sorted(set(itertools.chain.from_iterable(input_cnots)))
    qubits = {qubit_id: cirq.NamedQubit(f"{qubit_id}") for qubit_id in qubit_ids}
    # One explicit qubit order shared by both circuits. Both unitaries then act
    # on the same basis, even if the simplified list stops touching a qubit.
    qubit_order = [qubits[qubit_id] for qubit_id in qubit_ids]

    # Positions within each tuple that feed cirq.CNOT's control and target
    # arguments, respectively.
    control_pos, target_pos = (0, 1) if input_flip_control_and_target else (1, 0)

    def to_circuit(cnots):
        # Build a circuit of CNOTs from (int, int) tuples using the mapping above.
        return cirq.Circuit(
            cirq.CNOT(qubits[cnot[control_pos]], qubits[cnot[target_pos]]) for cnot in cnots
        )

    np.testing.assert_allclose(
        to_circuit(input_cnots).unitary(qubit_order=qubit_order),
        to_circuit(actual_output_cnots).unitary(qubit_order=qubit_order),
        atol=1e-6,
    )
227