1import random
2
3import numpy as np
4import pytest
5
6import cirq
7
8
def _operations_to_matrix(operations, qubits):
    """Compute the unitary implemented by ``operations`` over ``qubits``.

    The matrix uses the explicit ordering given by ``qubits``. Every qubit in
    ``qubits`` is forced into the result even if no operation acts on it.
    """
    circuit = cirq.Circuit(operations)
    ordering = cirq.QubitOrder.explicit(qubits)
    return circuit.unitary(
        qubit_order=ordering,
        qubits_that_should_be_present=qubits,
    )
13
14
def _ms_interaction_matrix(t):
    """Return the two-qubit matrix ``cos(t) * I - 1j * sin(t) * (X ⊗ X)``.

    Args:
        t: Interaction angle in radians.

    Returns:
        A 4x4 complex numpy array. This is ``exp(-1j * t * X⊗X)``, the core
        Mølmer–Sørensen interaction, before any single-qubit dressing.
    """
    s = np.sin(t)
    c = np.cos(t)
    return np.array(
        [[c, 0, 0, -1j * s], [0, c, -1j * s, 0], [0, -1j * s, c, 0], [-1j * s, 0, 0, c]]
    )


def _random_local_unitary():
    """Return ``U1 ⊗ U2`` for two independently drawn random 2x2 unitaries."""
    return cirq.kron(cirq.testing.random_unitary(2), cirq.testing.random_unitary(2))


def _random_single_MS_effect():
    """Return a random unitary made of one MS interaction between local layers.

    The interaction angle is drawn from ``random.random()``, i.e. ``[0, 1)``.
    It is sandwiched between two random single-qubit (local) layers.
    """
    t = random.random()
    return cirq.dot(_random_local_unitary(), _ms_interaction_matrix(t), _random_local_unitary())


def _random_double_MS_effect():
    """Return a random unitary made of two MS interactions and three local layers.

    Both interaction angles are drawn from ``random.random()``, i.e. ``[0, 1)``.
    """
    # Draw both angles before any numpy randomness, matching the original draw order.
    t1 = random.random()
    t2 = random.random()
    return cirq.dot(
        _random_local_unitary(),
        _ms_interaction_matrix(t1),
        _random_local_unitary(),
        _ms_interaction_matrix(t2),
        _random_local_unitary(),
    )
45
46
def assert_ops_implement_unitary(q0, q1, operations, intended_effect, atol=0.01):
    """Assert that ``operations`` on ``(q0, q1)`` realize ``intended_effect``.

    The two matrices are compared up to a global phase, within tolerance ``atol``.
    """
    produced = _operations_to_matrix(operations, (q0, q1))
    matches = cirq.allclose_up_to_global_phase(produced, intended_effect, atol=atol)
    assert matches
50
51
def assert_ms_depth_below(operations, threshold):
    """Assert that the total two-qubit (MS) exponent is at most ``threshold``.

    Every operation must act on at most two qubits. Two-qubit operations must
    be ``cirq.GateOperation`` instances wrapping a ``cirq.XXPowGate``. Their
    absolute exponents are summed and compared against ``threshold``.
    """

    def _ms_weight(op):
        # Single-qubit (or zero-qubit) ops contribute nothing to the MS budget.
        arity = len(op.qubits)
        assert arity <= 2
        if arity != 2:
            return 0
        assert isinstance(op, cirq.GateOperation)
        assert isinstance(op.gate, cirq.XXPowGate)
        return abs(op.gate.exponent)

    accumulated = sum(_ms_weight(op) for op in operations)
    assert accumulated <= threshold
62
63
# yapf: disable
@pytest.mark.parametrize('max_ms_depth,effect', [
    (0, np.eye(4)),
    # This matrix is exactly X ⊗ X, a product of single-qubit gates, so no MS gate is needed.
    (0, np.array([
        [0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0j]
    ])),
    (1, cirq.unitary(cirq.ms(np.pi/4))),

    # Near-identity interaction: expected to decompose with no MS gate at all.
    (0, cirq.unitary(cirq.CZ ** 0.00000001)),
    # Fractional budgets: the depth is the summed |exponent| of XX gates.
    (0.5, cirq.unitary(cirq.CZ ** 0.5)),

    (1, cirq.unitary(cirq.CZ)),
    (1, cirq.unitary(cirq.CNOT)),
    (1, np.array([
        [1, 0, 0, 1j],
        [0, 1, 1j, 0],
        [0, 1j, 1, 0],
        [1j, 0, 0, 1],
    ]) * np.sqrt(0.5)),
    (1, np.array([
        [1, 0, 0, -1j],
        [0, 1, -1j, 0],
        [0, -1j, 1, 0],
        [-1j, 0, 0, 1],
    ]) * np.sqrt(0.5)),
    (1, np.array([
        [1, 0, 0, 1j],
        [0, 1, -1j, 0],
        [0, -1j, 1, 0],
        [1j, 0, 0, 1],
    ]) * np.sqrt(0.5)),

    (1.5, cirq.map_eigenvalues(cirq.unitary(cirq.SWAP),
                               lambda e: e ** 0.5)),

    (2, cirq.unitary(cirq.SWAP).dot(cirq.unitary(cirq.CZ))),

    (3, cirq.unitary(cirq.SWAP)),
    # Permutation exchanging |00> and |11> while fixing |01> and |10>.
    (3, np.array([
        [0, 0, 0, 1],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0j],
    ])),
] + [
    (1, _random_single_MS_effect()) for _ in range(10)
] + [
    (3, cirq.testing.random_unitary(4)) for _ in range(10)
] + [
    (2, _random_double_MS_effect()) for _ in range(10)
])
# yapf: enable
def test_two_to_ops(max_ms_depth: float, effect: np.ndarray):
    """Check ``cirq.two_qubit_matrix_to_ion_operations`` on assorted two-qubit unitaries.

    The produced operations must reproduce ``effect`` up to global phase. The
    summed absolute exponent of their two-qubit ``XXPowGate`` operations must
    not exceed ``max_ms_depth``, which may be fractional (e.g. 0.5, 1.5).
    """
    q0 = cirq.NamedQubit('q0')
    q1 = cirq.NamedQubit('q1')

    operations = cirq.two_qubit_matrix_to_ion_operations(q0, q1, effect)
    assert_ops_implement_unitary(q0, q1, operations, effect)
    assert_ms_depth_below(operations, max_ms_depth)
126