1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import cmath
16import numpy as np
17import pytest
18
19import cirq
20from cirq.linalg import matrix_commutes
21
22
def test_is_diagonal():
    """Checks cirq.is_diagonal on empty, 1x1, vector-shaped and square inputs."""
    # Shapes with no entries at all are trivially diagonal.
    for shape in [(0, 0), (1, 0), (0, 1)]:
        assert cirq.is_diagonal(np.empty(shape))

    diagonal = [
        [[1]],
        [[-1]],
        [[5]],
        [[3j]],
        # Non-square inputs only need their off-diagonal entries to vanish.
        [[1, 0]],
        [[1], [0]],
        [[5j, 0], [0, 2]],
        [[1, 0], [0, 1]],
        # Off-diagonal noise below the default tolerance is ignored.
        [[1, 1e-11], [1e-10, 1]],
    ]
    not_diagonal = [
        [[1, 1]],
        [[1], [1]],
        [[1, 0], [1, 1]],
        [[1, 1], [0, 1]],
        [[1, 1], [1, 1]],
        [[1, 0.1], [0.1, 1]],
    ]
    for rows in diagonal:
        assert cirq.is_diagonal(np.array(rows)), rows
    for rows in not_diagonal:
        assert not cirq.is_diagonal(np.array(rows)), rows
46
47
def test_is_diagonal_tolerance():
    """Checks that cirq.is_diagonal honours an explicit atol, entry by entry."""
    tol = 0.5
    cases = [
        # A single off-diagonal entry exactly at the tolerance passes; beyond it fails.
        ([[1, 0], [-0.5, 1]], True),
        ([[1, 0], [-0.6, 1]], False),
        # Errors are judged per entry, not summed across entries.
        ([[1, 0.5], [-0.5, 1]], True),
        ([[1, 0.5], [-0.6, 1]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_diagonal(np.array(rows), atol=tol)) is expected, rows
58
59
def test_is_hermitian():
    """Checks cirq.is_hermitian across empty, 1x1, non-square and 2x2 inputs."""
    # The 0x0 matrix is vacuously Hermitian; the other empty shapes are not square.
    assert cirq.is_hermitian(np.empty((0, 0)))
    for shape in [(1, 0), (0, 1)]:
        assert not cirq.is_hermitian(np.empty(shape))

    hermitian = [
        np.array([[1]]),
        np.array([[-1]]),
        np.array([[5]]),
        np.array([[5, 0], [0, 2]]),
        np.array([[1, 0], [0, 1]]),
        np.array([[1, 1], [1, 1]]),
        np.array([[1, 1j], [-1j, 1]]),
        np.array([[1, 1j], [-1j, 1]]) * np.sqrt(0.5),
        # Tiny deviations from conjugate symmetry fall inside the default tolerance.
        np.array([[1, 1j + 1e-11], [-1j, 1 + 1j * 1e-9]]),
    ]
    not_hermitian = [
        # A purely imaginary diagonal entry is not its own conjugate.
        np.array([[3j]]),
        # Non-square inputs can never equal their conjugate transpose.
        np.array([[0, 0]]),
        np.array([[0], [0]]),
        np.array([[5j, 0], [0, 2]]),
        np.array([[1, 0], [1, 1]]),
        np.array([[1, 1], [0, 1]]),
        np.array([[1, 1j], [1j, 1]]),
        np.array([[1, 0.1], [-0.1, 1]]),
    ]
    for m in hermitian:
        assert cirq.is_hermitian(m), m
    for m in not_hermitian:
        assert not cirq.is_hermitian(m), m
85
86
def test_is_hermitian_tolerance():
    """Checks that cirq.is_hermitian respects atol and compares entries individually."""
    tol = 0.5
    cases = [
        # Asymmetry exactly at the tolerance passes; slightly beyond it fails.
        ([[1, 0], [-0.5, 1]], True),
        ([[1, 0.25], [-0.25, 1]], True),
        ([[1, 0], [-0.6, 1]], False),
        ([[1, 0.25], [-0.35, 1]], False),
        # Error is judged per entry rather than accumulated over the matrix.
        ([[1, 0.5, 0.5], [0, 1, 0], [0, 0, 1]], True),
        ([[1, 0.5, 0.6], [0, 1, 0], [0, 0, 1]], False),
        ([[1, 0, 0.6], [0, 1, 0], [0, 0, 1]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_hermitian(np.array(rows), atol=tol)) is expected, rows
100
101
def test_is_unitary():
    """Checks cirq.is_unitary on empty, scalar, non-square and 2x2 inputs."""
    inv_rt2 = np.sqrt(0.5)

    assert cirq.is_unitary(np.empty((0, 0)))
    for shape in [(1, 0), (0, 1)]:
        assert not cirq.is_unitary(np.empty(shape))

    unitary = [
        np.array([[1]]),
        np.array([[-1]]),
        np.array([[1j]]),
        np.array([[1, 0], [0, -1]]),
        np.array([[1j, 0], [0, 1]]),
        np.array([[1, -1], [1, 1]]) * inv_rt2,
        np.array([[1, 1j], [1j, 1]]) * inv_rt2,
        # Small perturbations fall inside the default tolerance.
        np.array([[1, 1j + 1e-11], [1j, 1 + 1j * 1e-9]]) * inv_rt2,
    ]
    not_unitary = [
        np.array([[5]]),
        np.array([[3j]]),
        np.array([[1, 0]]),
        np.array([[1], [0]]),
        np.array([[1, 0], [0, -2]]),
        np.array([[1, 0], [1, 1]]),
        np.array([[1, 1], [0, 1]]),
        np.array([[1, 1], [1, 1]]),
        # Orthogonal columns, but not normalized.
        np.array([[1, -1], [1, 1]]),
        np.array([[1, -1j], [1j, 1]]) * inv_rt2,
    ]
    for m in unitary:
        assert cirq.is_unitary(m), m
    for m in not_unitary:
        assert not cirq.is_unitary(m), m
128
129
def test_is_unitary_tolerance():
    """Checks that cirq.is_unitary honours atol without accumulating error."""
    tol = 0.5
    cases = [
        # An entry at the tolerance passes; one past it fails.
        ([[1, 0], [-0.5, 1]], True),
        ([[1, 0], [-0.6, 1]], False),
        # Each diagonal entry is judged on its own.
        ([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]], True),
        ([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1.2]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_unitary(np.array(rows), atol=tol)) is expected, rows
140
141
def test_is_orthogonal():
    """Checks cirq.is_orthogonal: real square matrices with orthonormal columns."""
    inv_rt2 = np.sqrt(0.5)

    assert cirq.is_orthogonal(np.empty((0, 0)))
    for shape in [(1, 0), (0, 1)]:
        assert not cirq.is_orthogonal(np.empty(shape))

    orthogonal = [
        np.array([[1]]),
        np.array([[-1]]),
        np.array([[1, 0], [0, -1]]),
        np.array([[1, -1], [1, 1]]) * inv_rt2,
        # Perturbations of ~1e-11 are inside the default tolerance.
        np.array([[1, 1e-11], [0, 1 + 1e-11]]),
    ]
    not_orthogonal = [
        # Complex entries disqualify a matrix even when it is unitary.
        np.array([[1j]]),
        np.array([[5]]),
        np.array([[3j]]),
        np.array([[1, 0]]),
        np.array([[1], [0]]),
        np.array([[1, 0], [0, -2]]),
        np.array([[1j, 0], [0, 1]]),
        np.array([[1, 0], [1, 1]]),
        np.array([[1, 1], [0, 1]]),
        np.array([[1, 1], [1, 1]]),
        np.array([[1, -1], [1, 1]]),
        np.array([[1, 1j], [1j, 1]]) * inv_rt2,
        np.array([[1, -1j], [1j, 1]]) * inv_rt2,
    ]
    for m in orthogonal:
        assert cirq.is_orthogonal(m), m
    for m in not_orthogonal:
        assert not cirq.is_orthogonal(m), m
168
169
def test_is_orthogonal_tolerance():
    """Checks that cirq.is_orthogonal honours atol without accumulating error."""
    tol = 0.5
    cases = [
        # An entry at the tolerance passes; one past it fails.
        ([[1, 0], [-0.5, 1]], True),
        ([[1, 0], [-0.6, 1]], False),
        # Each diagonal entry is judged on its own.
        ([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]], True),
        ([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1.2]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_orthogonal(np.array(rows), atol=tol)) is expected, rows
180
181
def test_is_special_orthogonal():
    """Checks cirq.is_special_orthogonal: orthogonal matrices with determinant +1."""
    inv_rt2 = np.sqrt(0.5)

    assert cirq.is_special_orthogonal(np.empty((0, 0)))
    for shape in [(1, 0), (0, 1)]:
        assert not cirq.is_special_orthogonal(np.empty(shape))

    special = [
        np.array([[1]]),
        # Two reflections compose to a rotation, so det is +1.
        np.array([[-1, 0], [0, -1]]),
        np.array([[1, -1], [1, 1]]) * inv_rt2,
        np.array([[1, 1e-11], [0, 1 + 1e-11]]),
    ]
    not_special = [
        np.array([[-1]]),
        np.array([[1j]]),
        np.array([[5]]),
        np.array([[3j]]),
        np.array([[1, 0]]),
        np.array([[1], [0]]),
        np.array([[1, 0], [0, -2]]),
        # Orthogonal, but a reflection (det -1).
        np.array([[1, 0], [0, -1]]),
        np.array([[1j, 0], [0, 1]]),
        np.array([[1, 0], [1, 1]]),
        np.array([[1, 1], [0, 1]]),
        np.array([[1, 1], [1, 1]]),
        np.array([[1, -1], [1, 1]]),
        np.array([[1, 1], [1, -1]]) * inv_rt2,
        np.array([[1, 1j], [1j, 1]]) * inv_rt2,
        np.array([[1, -1j], [1j, 1]]) * inv_rt2,
    ]
    for m in special:
        assert cirq.is_special_orthogonal(m), m
    for m in not_special:
        assert not cirq.is_special_orthogonal(m), m
210
211
def test_is_special_orthogonal_tolerance():
    """Checks that cirq.is_special_orthogonal honours atol, including on the determinant."""
    tol = 0.5
    cases = [
        # An off-diagonal entry at the tolerance passes; one past it fails.
        ([[1, 0], [-0.5, 1]], True),
        ([[1, 0], [-0.6, 1]], False),
        # Per-entry errors are not summed, but they multiply in the determinant:
        # diag(1.2, 1.2, 1/1.2) has det 1.2, while diag(1.2, 1.2, 1.2) has det ~1.73.
        ([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1 / 1.2]], True),
        ([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]], False),
        ([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1 / 1.2]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_special_orthogonal(np.array(rows), atol=tol)) is expected, rows
229
230
def test_is_special_unitary():
    """Checks cirq.is_special_unitary: unitary matrices with determinant 1."""
    assert cirq.is_special_unitary(np.empty((0, 0)))
    assert not cirq.is_special_unitary(np.empty((1, 0)))
    assert not cirq.is_special_unitary(np.empty((0, 1)))

    assert cirq.is_special_unitary(np.array([[1]]))
    assert not cirq.is_special_unitary(np.array([[-1]]))
    # Unitary, but its determinant is 1j rather than 1.
    assert not cirq.is_special_unitary(np.array([[1j]]))
    assert not cirq.is_special_unitary(np.array([[5]]))
    assert not cirq.is_special_unitary(np.array([[3j]]))

    # Non-square matrices can never be (special) unitary.
    assert not cirq.is_special_unitary(np.array([[1, 0]]))
    assert not cirq.is_special_unitary(np.array([[1], [0]]))

    assert not cirq.is_special_unitary(np.array([[1, 0], [0, -2]]))
    assert not cirq.is_special_unitary(np.array([[1, 0], [0, -1]]))
    assert cirq.is_special_unitary(np.array([[-1, 0], [0, -1]]))
    assert not cirq.is_special_unitary(np.array([[1j, 0], [0, 1]]))
    assert cirq.is_special_unitary(np.array([[1j, 0], [0, -1j]]))
    assert not cirq.is_special_unitary(np.array([[1, 0], [1, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, 1], [0, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, 1], [1, 1]]))
    assert not cirq.is_special_unitary(np.array([[1, -1], [1, 1]]))
    assert cirq.is_special_unitary(np.array([[1, -1], [1, 1]]) * np.sqrt(0.5))
    assert cirq.is_special_unitary(np.array([[1, 1j], [1j, 1]]) * np.sqrt(0.5))
    assert not cirq.is_special_unitary(np.array([[1, -1j], [1j, 1]]) * np.sqrt(0.5))

    # Small perturbations fall inside the default tolerance.
    assert cirq.is_special_unitary(np.array([[1, 1j + 1e-11], [1j, 1 + 1j * 1e-9]]) * np.sqrt(0.5))
255
256
def test_is_special_unitary_tolerance():
    """Checks that cirq.is_special_unitary honours atol, including on the determinant."""
    tol = 0.5
    eye = np.array([[1, 0], [0, 1]])
    cases = [
        # An off-diagonal entry at the tolerance passes; one past it fails.
        (np.array([[1, 0], [-0.5, 1]]), True),
        (np.array([[1, 0], [-0.6, 1]]), False),
        # A global phase e^{i*theta} on a 2x2 matrix multiplies det by e^{2i*theta};
        # theta=0.1 stays within tol of 1, theta=0.3 does not.
        (eye * cmath.exp(1j * 0.1), True),
        (eye * cmath.exp(1j * 0.3), False),
        # Per-entry errors are not summed, but they multiply in the determinant.
        (np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1 / 1.2]]), True),
        (np.array([[1.2, 0, 0], [0, 1.2, 0], [0, 0, 1.2]]), False),
        (np.array([[1.2, 0, 0], [0, 1.3, 0], [0, 0, 1 / 1.2]]), False),
    ]
    for m, expected in cases:
        assert bool(cirq.is_special_unitary(m, atol=tol)) is expected, m
272
273
def test_is_normal():
    """Checks cirq.is_normal on scalars, random states/unitaries and a nilpotent block."""
    # Density matrices are Hermitian and unitaries are unitary, so both are normal.
    normal = [
        np.array([[1]]),
        np.array([[3j]]),
        cirq.testing.random_density_matrix(4),
        cirq.testing.random_unitary(5),
    ]
    for m in normal:
        assert cirq.is_normal(m)

    # A nilpotent Jordan block does not commute with its adjoint.
    assert not cirq.is_normal(np.array([[0, 1], [0, 0]]))
    # Non-square input is rejected.
    assert not cirq.is_normal(np.zeros((1, 0)))
281
282
def test_is_normal_tolerance():
    """Checks that cirq.is_normal honours atol without accumulating error."""
    tol = 0.25
    cases = [
        # For [[0, c], [0, 0]], A A^dagger - A^dagger A = diag(c^2, -c^2):
        # c=0.5 gives 0.25 (at tol), c=0.6 gives 0.36 (beyond it).
        ([[0, 0.5], [0, 0]], True),
        ([[0, 0.6], [0, 0]], False),
        # Errors in separate entries do not add up.
        ([[0, 0.5, 0], [0, 0, 0.5], [0, 0, 0]], True),
        ([[0, 0.5, 0], [0, 0, 0.6], [0, 0, 0]], False),
    ]
    for rows, expected in cases:
        assert bool(cirq.is_normal(np.array(rows), atol=tol)) is expected, rows
293
294
def test_is_cptp():
    """Checks cirq.is_cptp on textbook channels, unnormalized sets and random isometries."""
    inv_rt2 = np.sqrt(0.5)
    paulis = [
        np.array([[1, 0], [0, 1]]),
        np.array([[0, 1], [1, 0]]),
        np.array([[0, -1j], [1j, 0]]),
        np.array([[1, 0], [0, -1]]),
    ]

    # Amplitude damping with gamma=0.5.
    damping = [np.array([[1, 0], [0, inv_rt2]]), np.array([[0, inv_rt2], [0, 0]])]
    assert cirq.is_cptp(kraus_ops=damping)
    # Depolarizing channel with p=0.75: every Pauli weighted by 1/2.
    assert cirq.is_cptp(kraus_ops=[p * 0.5 for p in paulis])

    # Without normalizing weights the Kraus operators sum to more than identity.
    assert not cirq.is_cptp(kraus_ops=[np.array([[1, 0], [0, 1]]), np.array([[0, 1], [0, 0]])])
    assert not cirq.is_cptp(kraus_ops=paulis)

    def random_kraus(total_dim, d):
        # The first d columns of a random unitary form an isometry V with
        # V^dagger V = I; cutting V's rows into d x d blocks yields Kraus operators.
        u = cirq.testing.random_unitary(total_dim)
        return np.reshape(u[:, :d], (-1, d, d))

    # 8 x 2 isometry -> four 2x2 Kraus operators.
    assert cirq.is_cptp(kraus_ops=random_kraus(8, 2))
    # 64 x 4 isometry -> sixteen 4x4 Kraus operators.
    assert cirq.is_cptp(kraus_ops=random_kraus(64, 4))
328
329
def test_is_cptp_tolerance():
    """Checks that cirq.is_cptp honours atol for slightly-off Kraus operators."""
    # Amplitude damping (gamma=0.5) with each sqrt(1/2) entry nudged by -0.01.
    amp = np.sqrt(0.5) - 0.01
    kraus = [np.array([[1, 0], [0, amp]]), np.array([[0, amp], [0, 0]])]

    # A loose tolerance accepts the perturbation; a tight one rejects it.
    assert cirq.is_cptp(kraus_ops=kraus, atol=0.25)
    assert not cirq.is_cptp(kraus_ops=kraus, atol=1e-8)
340
341
def test_commutes():
    """Checks matrix_commutes on empty shapes, scalars and Pauli products."""
    assert matrix_commutes(np.empty((0, 0)), np.empty((0, 0)))
    # Non-square empty shapes never commute, whatever the pairing.
    empty_pairs = [((1, 0), (0, 1)), ((0, 1), (1, 0)), ((1, 0), (1, 0)), ((0, 1), (0, 1))]
    for lhs, rhs in empty_pairs:
        assert not matrix_commutes(np.empty(lhs), np.empty(rhs))

    # 1x1 matrices always commute.
    assert matrix_commutes(np.array([[1]]), np.array([[2]]))
    assert matrix_commutes(np.array([[1]]), np.array([[0]]))

    px = np.array([[0, 1], [1, 0]])
    py = np.array([[0, -1j], [1j, 0]])
    pz = np.array([[1, 0], [0, -1]])

    # Each Pauli commutes with itself and anticommutes with the others.
    for p in (px, py, pz):
        assert matrix_commutes(p, p)
    for lhs, rhs in [(px, py), (px, pz), (py, pz)]:
        assert not matrix_commutes(lhs, rhs)

    # XX and ZZ commute: the two single-qubit anticommutation signs cancel.
    pxx = np.kron(px, px)
    assert matrix_commutes(pxx, np.kron(pz, pz))
    # diag(1, -1, -1, 1) is ZZ; a 1e-9 perturbation is within default tolerance.
    assert matrix_commutes(pxx, np.diag([1, -1, -1, 1 + 1e-9]))
367
368
def test_commutes_tolerance():
    """Checks that matrix_commutes respects an explicit atol."""
    tol = 0.5
    pauli_x = np.array([[0, 1], [1, 0]])
    pauli_z = np.array([[1, 0], [0, -1]])

    # [X, X + c*Z] = c*[X, Z], so the commutator grows linearly with c.
    for c, expected in [(0.1, True), (0.5, False)]:
        assert bool(matrix_commutes(pauli_x, pauli_x + pauli_z * c, atol=tol)) is expected, c
378
379
def test_allclose_up_to_global_phase():
    """Checks cirq.allclose_up_to_global_phase on shapes, phases and magnitudes."""
    # A pure global phase (1j here, -1 below) is ignored.
    assert cirq.allclose_up_to_global_phase(np.array([1]), np.array([1j]))

    # Arrays of different shapes are never close, even with matching entries.
    assert not cirq.allclose_up_to_global_phase(np.array([[[1]]]), np.array([1]))

    assert cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[1]]))
    assert cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[-1]]))

    assert cirq.allclose_up_to_global_phase(np.array([[0]]), np.array([[0]]))

    assert cirq.allclose_up_to_global_phase(np.array([[1, 2]]), np.array([[1j, 2j]]))

    # Differences below the default tolerance are ignored after phase matching.
    assert cirq.allclose_up_to_global_phase(np.array([[1, 2.0000000001]]), np.array([[1j, 2j]]))

    assert not cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[1, 0]]))
    # A magnitude difference cannot be absorbed by a phase.
    assert not cirq.allclose_up_to_global_phase(np.array([[1]]), np.array([[2]]))
    # A relative phase between entries is not a global phase.
    assert not cirq.allclose_up_to_global_phase(np.array([[1, 1]]), np.array([[1, -1]]))
397
398
def test_binary_sub_tensor_slice():
    """Checks little-endian index tuples from cirq.slice_for_qubits_equal_to."""
    sfqet = cirq.slice_for_qubits_equal_to
    full = slice(None)
    dots = Ellipsis

    # Without num_qubits the tuple ends in an Ellipsis covering any remaining axes.
    # The lowest bit of the value goes to the first listed axis.
    cases = [
        ([], 0, (dots,)),
        ([0], 0b0, (0, dots)),
        ([0], 0b1, (1, dots)),
        ([1], 0b0, (full, 0, dots)),
        ([1], 0b1, (full, 1, dots)),
        ([2], 0b0, (full, full, 0, dots)),
        ([2], 0b1, (full, full, 1, dots)),
        ([0, 1], 0b00, (0, 0, dots)),
        ([1, 2], 0b00, (full, 0, 0, dots)),
        ([1, 3], 0b00, (full, 0, full, 0, dots)),
        ([1, 3], 0b10, (full, 0, full, 1, dots)),
        ([3, 1], 0b10, (full, 1, full, 0, dots)),
        ([2, 1, 0], 0b001, (0, 0, 1, dots)),
        ([2, 1, 0], 0b010, (0, 1, 0, dots)),
        ([2, 1, 0], 0b100, (1, 0, 0, dots)),
        ([0, 1, 2], 0b101, (1, 0, 1, dots)),
        ([0, 2, 1], 0b101, (1, 1, 0, dots)),
    ]
    for axes, value, expected in cases:
        assert sfqet(axes, value) == expected, (axes, value)

    # Listing axes most-significant first and writing k at each slice fills the
    # tensor in row-major order.
    grid = np.array([0] * 16).reshape((2, 2, 2, 2))
    for k in range(16):
        grid[sfqet([3, 2, 1, 0], k)] = k
    assert list(grid.reshape(16)) == list(range(16))

    # With num_qubits the tuple has exactly that many entries and no Ellipsis.
    sized_cases = [
        ([0], 0b1, 1, (1,)),
        ([1], 0b0, 2, (full, 0)),
        ([1], 0b0, 3, (full, 0, full)),
        ([2], 0b0, 3, (full, full, 0)),
    ]
    for axes, value, n, expected in sized_cases:
        assert sfqet(axes, value, num_qubits=n) == expected, (axes, value, n)
432
433
def test_binary_sub_tensor_slice_big_endian():
    """Checks big-endian index tuples from cirq.slice_for_qubits_equal_to."""
    sfqet = cirq.slice_for_qubits_equal_to
    full = slice(None)
    dots = Ellipsis

    # The highest bit of the value goes to the first listed axis.
    cases = [
        ([], 0, (dots,)),
        ([0], 0b0, (0, dots)),
        ([0], 0b1, (1, dots)),
        ([1], 0b0, (full, 0, dots)),
        ([1], 0b1, (full, 1, dots)),
        ([2], 0b0, (full, full, 0, dots)),
        ([2], 0b1, (full, full, 1, dots)),
        ([0, 1], 0b00, (0, 0, dots)),
        ([1, 2], 0b00, (full, 0, 0, dots)),
        ([1, 3], 0b00, (full, 0, full, 0, dots)),
        ([1, 3], 0b01, (full, 0, full, 1, dots)),
        ([3, 1], 0b01, (full, 1, full, 0, dots)),
        ([2, 1, 0], 0b100, (0, 0, 1, dots)),
        ([2, 1, 0], 0b010, (0, 1, 0, dots)),
        ([2, 1, 0], 0b001, (1, 0, 0, dots)),
        ([0, 1, 2], 0b101, (1, 0, 1, dots)),
        ([0, 2, 1], 0b101, (1, 1, 0, dots)),
    ]
    for axes, value, expected in cases:
        assert sfqet(axes, big_endian_qureg_value=value) == expected, (axes, value)

    # With big-endian values, listing axes in natural order fills row-major.
    grid = np.array([0] * 16).reshape((2, 2, 2, 2))
    for k in range(16):
        grid[sfqet([0, 1, 2, 3], big_endian_qureg_value=k)] = k
    assert list(grid.reshape(16)) == list(range(16))

    sized_cases = [
        ([0], 0b1, 1, (1,)),
        ([1], 0b0, 2, (full, 0)),
        ([1], 0b0, 3, (full, 0, full)),
        ([2], 0b0, 3, (full, full, 0)),
    ]
    for axes, value, n, expected in sized_cases:
        got = sfqet(axes, big_endian_qureg_value=value, num_qubits=n)
        assert got == expected, (axes, value, n)
468
469
def test_qudit_sub_tensor_slice():
    """Checks cirq.slice_for_qubits_equal_to with mixed-dimension qid_shape values."""
    sfqet = cirq.slice_for_qubits_equal_to
    full = slice(None)

    # (actual, expected) pairs. Values are mixed-radix numbers: each listed axis
    # contributes one digit whose base is that axis's dimension.
    pairs = [
        (sfqet([], 0, qid_shape=()), ()),
        (sfqet([0], 0, qid_shape=(3,)), (0,)),
        (sfqet([0], 1, qid_shape=(3,)), (1,)),
        (sfqet([0], 2, qid_shape=(3,)), (2,)),
        (sfqet([2], 0, qid_shape=(1, 2, 3)), (full, full, 0)),
        (sfqet([2], 2, qid_shape=(1, 2, 3)), (full, full, 2)),
        (sfqet([2], big_endian_qureg_value=2, qid_shape=(1, 2, 3)), (full, full, 2)),
        (sfqet([1, 3], 3 * 2 + 1, qid_shape=(2, 3, 4, 5)), (full, 1, full, 2)),
        (sfqet([3, 1], 5 * 2 + 1, qid_shape=(2, 3, 4, 5)), (full, 2, full, 1)),
        (sfqet([2, 1, 0], 9 * 2 + 3 * 1, qid_shape=(3,) * 3), (2, 1, 0)),
        (
            sfqet([1, 3], big_endian_qureg_value=5 * 1 + 2, qid_shape=(2, 3, 4, 5)),
            (full, 1, full, 2),
        ),
        (
            sfqet([3, 1], big_endian_qureg_value=3 * 1 + 2, qid_shape=(2, 3, 4, 5)),
            (full, 2, full, 1),
        ),
        (sfqet([0], 1, num_qubits=1, qid_shape=(3,)), (1,)),
        (sfqet([1], 0, num_qubits=3, qid_shape=(3, 3, 3)), (full, 0, full)),
    ]
    for got, want in pairs:
        assert got == want, (got, want)

    # Writing k at each slice (axes listed most-significant first) fills row-major.
    tensor = np.array([0] * 24).reshape((1, 2, 3, 4))
    for k in range(24):
        tensor[sfqet([3, 2, 1, 0], k, qid_shape=(1, 2, 3, 4))] = k
    assert list(tensor.reshape(24)) == list(range(24))

    # num_qubits must agree with the length of qid_shape.
    with pytest.raises(ValueError, match='len.* !='):
        sfqet([], num_qubits=2, qid_shape=(1, 2, 3))

    # Supplying both a positional value and big_endian_qureg_value is rejected.
    with pytest.raises(ValueError, match='exactly one'):
        sfqet([0, 1, 2], 0b101, big_endian_qureg_value=0b101)
501