1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import itertools
16
17import numpy as np
18import pytest
19import scipy.linalg
20
21import cirq
22
# Single-qubit matrices used as test inputs throughout this file.
I = np.eye(2)
X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Z = np.array([[1, 0], [0, -1]])
# Hadamard matrix.
H = np.array([[1, 1], [1, -1]]) * np.sqrt(0.5)
# Square roots of the Paulis: SQRT_X @ SQRT_X == X, SQRT_Y @ SQRT_Y == Y and
# SQRT_Z @ SQRT_Z == Z (up to floating-point rounding for SQRT_X / SQRT_Y).
# Some tests compare inner products of these matrices with exact `==`, so keep
# the expressions as written: an algebraically equal rewrite may round
# differently.
SQRT_X = np.array([[np.sqrt(1j), np.sqrt(-1j)], [np.sqrt(-1j), np.sqrt(1j)]]) * np.sqrt(0.5)
SQRT_Y = np.array([[np.sqrt(1j), -np.sqrt(1j)], [np.sqrt(1j), np.sqrt(1j)]]) * np.sqrt(0.5)
SQRT_Z = np.diag([1, 1j])
# Matrix units: Eij has a single 1 at row i, column j and zeros elsewhere.
E00 = np.diag([1, 0])
E01 = np.array([[0, 1], [0, 0]])
E10 = np.array([[0, 0], [1, 0]])
E11 = np.diag([0, 1])
PAULI_BASIS = cirq.PAULI_BASIS
# The 2x2 matrix-unit basis. Single-letter keys keep product-basis names short
# (e.g. 'ab' for the Kronecker product of E00 and E01).
STANDARD_BASIS = {'a': E00, 'b': E01, 'c': E10, 'd': E11}
37
38
def _one_hot_matrix(size: int, i: int, j: int) -> np.ndarray:
    """Build a float ``size x size`` matrix that is zero except for a 1 at ``[i, j]``.

    Args:
        size: Number of rows (and columns) of the square result.
        i: Row of the single nonzero entry.
        j: Column of the single nonzero entry.

    Returns:
        A float64 array with exactly one entry equal to 1.
    """
    matrix = np.zeros(shape=(size, size), dtype=np.float64)
    matrix[i][j] = 1.0
    return matrix
43
44
@pytest.mark.parametrize(
    'basis1, basis2, expected_kron_basis',
    (
        (
            PAULI_BASIS,
            PAULI_BASIS,
            {
                'II': np.eye(4),
                'IX': scipy.linalg.block_diag(X, X),
                'IY': scipy.linalg.block_diag(Y, Y),
                'IZ': np.diag([1, -1, 1, -1]),
                'XI': np.array([[0, 0, 1, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 1, 0, 0]]),
                'XX': np.rot90(np.eye(4)),
                'XY': np.rot90(np.diag([1j, -1j, 1j, -1j])),
                'XZ': np.array([[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], [0, -1, 0, 0]]),
                'YI': np.array([[0, 0, -1j, 0], [0, 0, 0, -1j], [1j, 0, 0, 0], [0, 1j, 0, 0]]),
                'YX': np.rot90(np.diag([1j, 1j, -1j, -1j])),
                'YY': np.rot90(np.diag([-1, 1, 1, -1])),
                'YZ': np.array([[0, 0, -1j, 0], [0, 0, 0, 1j], [1j, 0, 0, 0], [0, -1j, 0, 0]]),
                'ZI': np.diag([1, 1, -1, -1]),
                'ZX': scipy.linalg.block_diag(X, -X),
                'ZY': scipy.linalg.block_diag(Y, -Y),
                'ZZ': np.diag([1, -1, -1, 1]),
            },
        ),
        (
            STANDARD_BASIS,
            STANDARD_BASIS,
            {
                'abcd'[2 * row_outer + col_outer]
                + 'abcd'[2 * row_inner + col_inner]: _one_hot_matrix(
                    4, 2 * row_outer + row_inner, 2 * col_outer + col_inner
                )
                for row_outer in range(2)
                for row_inner in range(2)
                for col_outer in range(2)
                for col_inner in range(2)
            },
        ),
    ),
)
def test_kron_bases(basis1, basis2, expected_kron_basis):
    """Check ``cirq.kron_bases`` against hand-computed Kronecker-product bases.

    Product elements are keyed by concatenating the factor names (e.g. 'IX').
    The expected matrices contain only small integers and multiples of 1j, so
    they are compared exactly.
    """
    kron_basis = cirq.kron_bases(basis1, basis2)
    # One element per (name1, name2) pair, i.e. 4 * 4 == 16 for these bases.
    assert len(kron_basis) == len(basis1) * len(basis2)
    assert set(kron_basis) == set(expected_kron_basis)
    for name, matrix in kron_basis.items():
        # array_equal also checks shapes. `np.all(a == b)` would broadcast, so
        # for example a (4, 4) matrix could "match" a (1, 4) row.
        assert np.array_equal(matrix, expected_kron_basis[name])
92
93
@pytest.mark.parametrize(
    'basis1,basis2',
    (
        (PAULI_BASIS, cirq.kron_bases(PAULI_BASIS)),
        (STANDARD_BASIS, cirq.kron_bases(STANDARD_BASIS, repeat=1)),
        (cirq.kron_bases(PAULI_BASIS, PAULI_BASIS), cirq.kron_bases(PAULI_BASIS, repeat=2)),
        (
            cirq.kron_bases(
                cirq.kron_bases(PAULI_BASIS, repeat=2),
                cirq.kron_bases(PAULI_BASIS, repeat=3),
                PAULI_BASIS,
            ),
            cirq.kron_bases(PAULI_BASIS, repeat=6),
        ),
        (
            cirq.kron_bases(
                cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS),
                cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS),
            ),
            cirq.kron_bases(PAULI_BASIS, STANDARD_BASIS, repeat=2),
        ),
    ),
)
def test_kron_bases_consistency(basis1, basis2):
    """Different ways of building the same product basis must agree exactly.

    Each case pairs an explicitly nested or multi-argument ``cirq.kron_bases``
    call with an equivalent ``repeat=`` call.
    """
    assert set(basis1) == set(basis2)
    for name, matrix in basis1.items():
        # array_equal also checks shapes; `np.all(a == b)` would broadcast.
        assert np.array_equal(matrix, basis2[name])
121
122
@pytest.mark.parametrize(
    'basis,repeat', itertools.product((PAULI_BASIS, STANDARD_BASIS), range(1, 5))
)
def test_kron_bases_repeat_sanity_checks(basis, repeat):
    """A repeated product basis has 4**repeat mutually orthogonal elements.

    Distinct elements have zero Hilbert-Schmidt inner product, and every element
    has an inner product with itself of magnitude at least 1.
    """
    product_basis = cirq.kron_bases(basis, repeat=repeat)
    assert len(product_basis) == 4**repeat
    element_pairs = itertools.product(product_basis.items(), repeat=2)
    for (left_name, left_matrix), (right_name, right_matrix) in element_pairs:
        overlap = cirq.hilbert_schmidt_inner_product(left_matrix, right_matrix)
        if left_name == right_name:
            assert abs(overlap) >= 1
        else:
            assert overlap == 0
136
137
@pytest.mark.parametrize(
    'm1,m2,expect_real',
    (
        (X, X, True),
        (X, Y, True),
        (X, H, True),
        (X, SQRT_X, False),
        (I, SQRT_Z, False),
    ),
)
def test_hilbert_schmidt_inner_product_is_conjugate_symmetric(m1, m2, expect_real):
    """<m1, m2> must equal the complex conjugate of <m2, m1>.

    When the inner product is not real, swapping the arguments must change it.
    """
    forward = cirq.hilbert_schmidt_inner_product(m1, m2)
    backward = cirq.hilbert_schmidt_inner_product(m2, m1)
    assert forward == backward.conjugate()

    assert np.isreal(forward) == expect_real
    # Only a non-real value is guaranteed to differ from its conjugate.
    assert expect_real or forward != backward
156
157
@pytest.mark.parametrize(
    'a,m1,b,m2',
    (
        (1, X, 1, Z),
        (2, X, 3, Y),
        (2j, X, 3, I),
        (2, X, 3, X),
    ),
)
def test_hilbert_schmidt_inner_product_is_linear(a, m1, b, m2):
    """<H, a*m1 + b*m2> == a*<H, m1> + b*<H, m2> (linear in the second argument)."""
    v1 = cirq.hilbert_schmidt_inner_product(H, (a * m1 + b * m2))
    v2 = a * cirq.hilbert_schmidt_inner_product(H, m1) + b * cirq.hilbert_schmidt_inner_product(
        H, m2
    )
    # The two sides are mathematically equal, but H has irrational entries and
    # the sides are evaluated in different orders, so they can round
    # differently. Compare with a tolerance rather than bit-for-bit.
    assert np.isclose(v1, v2)
173
174
@pytest.mark.parametrize('m', (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z))
def test_hilbert_schmidt_inner_product_is_positive_definite(m):
    """A nonzero matrix has a real, strictly positive inner product with itself."""
    self_overlap = cirq.hilbert_schmidt_inner_product(m, m)
    assert np.isreal(self_overlap)
    assert np.real(self_overlap) > 0
180
181
@pytest.mark.parametrize(
    'm1,m2,expected_value',
    (
        (X, I, 0),
        (X, X, 2),
        (X, Y, 0),
        (X, Z, 0),
        (H, X, np.sqrt(2)),
        (H, Y, 0),
        (H, Z, np.sqrt(2)),
        (Z, E00, 1),
        (Z, E01, 0),
        (Z, E10, 0),
        (Z, E11, -1),
        (SQRT_X, E00, np.sqrt(-0.5j)),
        (SQRT_X, E01, np.sqrt(0.5j)),
        (SQRT_X, E10, np.sqrt(0.5j)),
        (SQRT_X, E11, np.sqrt(-0.5j)),
    ),
)
def test_hilbert_schmidt_inner_product_values(m1, m2, expected_value):
    """Spot-check the inner product against hand-computed values.

    The SQRT_X cases show that the first argument is the conjugated one.
    """
    assert np.isclose(cirq.hilbert_schmidt_inner_product(m1, m2), expected_value)
205
206
@pytest.mark.parametrize(
    'm,basis',
    itertools.product(
        (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z),
        (PAULI_BASIS, STANDARD_BASIS),
    ),
)
def test_expand_matrix_in_orthogonal_basis(m, basis):
    """Summing coefficient * basis element over the expansion gives back ``m``."""
    expansion = cirq.expand_matrix_in_orthogonal_basis(m, basis)

    # Start from a complex zero matrix so complex coefficients are kept.
    reconstructed = sum(
        (weight * basis[label] for label, weight in expansion.items()),
        np.zeros(m.shape, dtype=complex),
    )
    assert np.allclose(m, reconstructed)
221
222
@pytest.mark.parametrize(
    'expansion',
    (
        {'I': 1},
        {'X': 1},
        {'Y': 1},
        {'Z': 1},
        {'X': 1, 'Z': 1},
        {'I': 0.5, 'X': 0.4, 'Y': 0.3, 'Z': 0.2},
        {'I': 1, 'X': 2, 'Y': 3, 'Z': 4},
        # Complex coefficients: these distinguish a coefficient from its conjugate.
        {'I': 1j},
        {'I': 0.5j, 'X': 1 - 2j, 'Y': -0.3j, 'Z': 0.25 + 0.5j},
    ),
)
def test_matrix_from_basis_coefficients(expansion):
    """Projecting the built matrix onto each Pauli recovers that Pauli's coefficient."""
    m = cirq.matrix_from_basis_coefficients(expansion, PAULI_BASIS)

    for name, coefficient in expansion.items():
        element = PAULI_BASIS[name]
        # hilbert_schmidt_inner_product conjugates its FIRST argument, so the
        # basis element must come first: <element, m> / <element, element> is
        # the coefficient, whereas <m, element> would give its complex conjugate.
        expected_coefficient = cirq.hilbert_schmidt_inner_product(
            element, m
        ) / cirq.hilbert_schmidt_inner_product(element, element)
        assert np.isclose(coefficient, expected_coefficient)
244
245
@pytest.mark.parametrize(
    'm1,basis',
    (
        itertools.product(
            (I, X, Y, Z, H, SQRT_X, SQRT_Y, SQRT_Z, E00, E01, E10, E11),
            (PAULI_BASIS, STANDARD_BASIS),
        )
    ),
)
def test_expand_is_inverse_of_reconstruct(m1, basis):
    """Expand, reconstruct, and re-expand; matrices and coefficients must round-trip."""
    c1 = cirq.expand_matrix_in_orthogonal_basis(m1, basis)
    m2 = cirq.matrix_from_basis_coefficients(c1, basis)
    c2 = cirq.expand_matrix_in_orthogonal_basis(m2, basis)
    assert np.allclose(m1, m2)
    # m2 is rebuilt with floating-point arithmetic, so c2 need not equal c1
    # bit-for-bit. Compare coefficients with a tolerance, treating a name that
    # is missing from one side as a zero coefficient.
    for name in set(c1.keys()) | set(c2.keys()):
        assert np.isclose(c1.get(name, 0), c2.get(name, 0))
261
262
@pytest.mark.parametrize(
    'coefficients,exponent',
    itertools.product(
        (
            (0, 0, 0, 0),
            (-1, 0, 0, 0),
            (0.5, 0, 0, 0),
            (0.5j, 0, 0, 0),
            (1, 0, 0, 0),
            (2, 0, 0, 0),
            (0, -1, 0, 0),
            (0, 0.5, 0, 0),
            (0, 0.5j, 0, 0),
            (0, 1, 0, 0),
            (0, 2, 0, 0),
            (0, 0, -1, 0),
            (0, 0, 0.5, 0),
            (0, 0, 0.5j, 0),
            (0, 0, 1, 0),
            (0, 0, 2, 0),
            (0, 0, 0, -1),
            (0, 0, 0, 0.5),
            (0, 0, 0, 0.5j),
            (0, 0, 0, 1),
            (0, 0, 0, 2),
            (0, -1, 0, -1),
            (0, 1, 0, 1j),
            (0, 0.5, 0, 0.5),
            (0, 0.5j, 0, 0.5j),
            (0, 0.5, 0, 0.5j),
            (0, 1, 0, 1),
            (0, 2, 0, 2),
            (0, 0.5, 0.5, 0.5),
            (0, 1, 1, 1),
            (0, 1.1j, 0.5 - 0.4j, 0.9),
            (0.7j, 1.1j, 0.5 - 0.4j, 0.9),
            (0.25, 0.25, 0.25, 0.25),
            (0.25j, 0.25j, 0.25j, 0.25j),
            (0.4, 0, 0.5, 0),
            (1, 2, 3, 4),
            (-1, -2, -3, -4),
            (-1, -2, 3, 4),
            (1j, 2j, 3j, 4j),
            (1j, 2j, 3, 4),
        ),
        (0, 1, 2, 3, 4, 5, 100, 101),
    ),
)
def test_pow_pauli_combination(coefficients, exponent):
    """Raising aI + bX + cY + dZ to a power matches a direct matrix power.

    ``coefficients`` and the tuple returned by ``cirq.pow_pauli_combination``
    are both ordered (I, X, Y, Z).
    """
    paulis = [cirq.PAULI_BASIS[label] for label in 'IXYZ']

    def combine(weights):
        # Linear combination of the Paulis with the given (I, X, Y, Z) weights.
        return sum(weight * pauli for weight, pauli in zip(weights, paulis))

    expected_result = np.linalg.matrix_power(combine(coefficients), exponent)
    result = combine(cirq.pow_pauli_combination(*coefficients, exponent))

    assert np.allclose(result, expected_result)
325